diff --git a/python/tokenspeed/runtime/execution/drafter/eagle.py b/python/tokenspeed/runtime/execution/drafter/eagle.py index 67317f012..ac7892e37 100644 --- a/python/tokenspeed/runtime/execution/drafter/eagle.py +++ b/python/tokenspeed/runtime/execution/drafter/eagle.py @@ -37,6 +37,7 @@ ForwardMode, ) from tokenspeed.runtime.models.llama_eagle3 import LlamaForCausalLMEagle3 +from tokenspeed.runtime.models.qwen3_5_nextn import Qwen3_5ForConditionalGenerationNextN from tokenspeed.runtime.multimodal.inputs import maybe_substitute_mm_pad from tokenspeed.runtime.utils import get_colorful_logger from tokenspeed.runtime.utils.nvtx import nvtx_range @@ -218,10 +219,13 @@ def _run_first_step( draft_input, bs, draft_input.input_num_tokens ) input_ids = maybe_substitute_mm_pad(input_ids, self.mm_pad_substitute_id) - # Llama Eagle3 narrows for any non-idle catch-up (EXTEND/MIXED/ - # TARGET_VERIFY/DECODE); Qwen/DeepSeek keep is_decode() only. + # Llama Eagle3 and Qwen3.5 NextN narrow for any non-idle catch-up + # (EXTEND/MIXED/TARGET_VERIFY/DECODE); DeepSeek keeps is_decode() only. draft_first_step_reduce = forward_mode.is_decode() or ( - isinstance(self.draft_model_runner.model, LlamaForCausalLMEagle3) + isinstance( + self.draft_model_runner.model, + (LlamaForCausalLMEagle3, Qwen3_5ForConditionalGenerationNextN), + ) and not forward_mode.is_idle() ) diff --git a/python/tokenspeed/runtime/execution/model_executor.py b/python/tokenspeed/runtime/execution/model_executor.py index df9276488..39a64a329 100644 --- a/python/tokenspeed/runtime/execution/model_executor.py +++ b/python/tokenspeed/runtime/execution/model_executor.py @@ -46,6 +46,8 @@ from tokenspeed.runtime.execution.types import ModelExecutionResult from tokenspeed.runtime.grammar.capturable_grammar import setup_grammar_step from tokenspeed.runtime.layers.logits_processor import LogitsProcessorOutput +from tokenspeed.runtime.models.llama_eagle3 import LlamaForCausalLMEagle3 +from tokenspeed.runtime.models.qwen3_5_nextn import Qwen3_5ForConditionalGenerationNextN from tokenspeed.runtime.sampling.backends.base import SamplingBackend from tokenspeed.runtime.sampling.dp_sampling_config import ( DpSamplingRuntimeLimits, @@ -1078,7 +1080,22 @@ def execute_idle_forward( global_num_tokens=draft_global_num_tokens, global_bs=global_bs, all_decode_or_idle=all_decode_or_idle, - draft_first_step_reduce=(step_idx == 0 and all_decode_or_idle), + # Mirror the active-rank broaden in eagle.py: Llama Eagle3 + # and Qwen3.5 NextN narrow for any non-idle catch-up, so the + # idle peer must size collectives the same way. + draft_first_step_reduce=( + step_idx == 0 + and ( + all_decode_or_idle + or isinstance( + self.drafter.draft_model_runner.model, + ( + LlamaForCausalLMEagle3, + Qwen3_5ForConditionalGenerationNextN, + ), + ) + ) + ), ) self.drafter.draft_model_runner.forward( draft_ctx, diff --git a/python/tokenspeed/runtime/layers/attention/backends/base.py b/python/tokenspeed/runtime/layers/attention/backends/base.py index 6bc84a1a9..fa95c3025 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/base.py +++ b/python/tokenspeed/runtime/layers/attention/backends/base.py @@ -127,6 +127,35 @@ def configure_runtime(self, **kwargs) -> None: def register_step_counter(self, step_counter: StepCounter): self.step_counter = step_counter + @contextmanager + def record_pd_cache_step( + self, + forward_mode: ForwardMode, + save_kv_cache: bool, + record_kv_cache: bool | None, + ): + """Anchor the PD layerwise cache-step record to the wrapped KV write. + + Records the ``StepCounter`` step before the attention call when the KV + was pre-written (``save_kv_cache=False``) and after it otherwise, so a + layerwise cache transfer always observes a fully written layer. See + ``forward`` for the ``record_kv_cache`` override contract. No-op when no + step counter is registered. Backends that own the record (e.g. the + hybrid wrapper, which counts once per model layer across full-attn + + mamba children) reuse this to avoid duplicating the gate logic. + """ + if record_kv_cache is None: + record_cache = not forward_mode.is_decode() and not forward_mode.is_idle() + else: + record_cache = record_kv_cache + record_cache = record_cache and getattr(self, "step_counter", None) is not None + + if record_cache and not save_kv_cache: + self.step_counter.record_cache() + yield + if record_cache and save_kv_cache: + self.step_counter.record_cache() + def forward( self, q: torch.Tensor, @@ -138,49 +167,42 @@ def forward( forward_mode: ForwardMode, bs: int, save_kv_cache: bool = True, + record_kv_cache: bool | None = None, **kwargs, ): - """Run forward on an attention layer with explicit scheduler metadata.""" - if forward_mode.is_decode(): - return self.forward_decode( - q, - k, - v, - layer, - out_cache_loc, - token_to_kv_pool, - bs, - save_kv_cache=save_kv_cache, - **kwargs, - ) - else: - if ( - not forward_mode.is_idle() - and getattr(self, "step_counter", None) - and not save_kv_cache - ): - self.step_counter.record_cache() - - ret = self.forward_extend( - q, - k, - v, - layer, - out_cache_loc, - token_to_kv_pool, - bs, - save_kv_cache=save_kv_cache, - forward_mode=forward_mode, - **kwargs, - ) - - if ( - not forward_mode.is_idle() - and getattr(self, "step_counter", None) - and save_kv_cache - ): - self.step_counter.record_cache() - return ret + """Run forward on an attention layer with explicit scheduler metadata. + + ``record_kv_cache`` overrides the PD layerwise cache-step recording: + ``None`` keeps the default (record on the EXTEND-side path), an explicit + bool forces it so a DECODE-dispatched draft catch-up can still record. + """ + with self.record_pd_cache_step(forward_mode, save_kv_cache, record_kv_cache): + if forward_mode.is_decode(): + ret = self.forward_decode( + q, + k, + v, + layer, + out_cache_loc, + token_to_kv_pool, + bs, + save_kv_cache=save_kv_cache, + **kwargs, + ) + else: + ret = self.forward_extend( + q, + k, + v, + layer, + out_cache_loc, + token_to_kv_pool, + bs, + save_kv_cache=save_kv_cache, + forward_mode=forward_mode, + **kwargs, + ) + return ret def forward_decode( self, diff --git a/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py b/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py index d0c5af50e..b1158a54a 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py +++ b/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py @@ -1263,6 +1263,7 @@ def forward( forward_mode: ForwardMode, bs: int, save_kv_cache: bool = True, + record_kv_cache: bool | None = None, **kwargs, ): if forward_mode is None: @@ -1276,6 +1277,7 @@ def forward( forward_mode, bs, save_kv_cache, + record_kv_cache=record_kv_cache, **kwargs, ) @@ -1287,45 +1289,36 @@ def forward( layer_id = layer.layer_id if layer else kwargs["layer_id"] backend = self._backend_for_layer(layer_id) - if forward_mode.is_decode(): - return backend.forward_decode( - q, - k, - v, - layer, - out_cache_loc, - token_to_kv_pool, - bs, - save_kv_cache=save_kv_cache, - **kwargs, - ) - else: - step_counter = getattr(self, "step_counter", None) - if ( - not forward_mode.is_idle() - and step_counter is not None - and not save_kv_cache - ): - step_counter.record_cache() - ret = backend.forward_extend( - q, - k, - v, - layer, - out_cache_loc, - token_to_kv_pool, - bs, - save_kv_cache=save_kv_cache, - forward_mode=forward_mode, - **kwargs, - ) - if ( - not forward_mode.is_idle() - and step_counter is not None - and save_kv_cache - ): - step_counter.record_cache() - return ret + # See AttentionBackend.forward for the record_kv_cache contract; the step + # is recorded in this wrapper (not the child backends) to keep one step + # per model layer across full-attn + mamba. Idle already returned above. + with self.record_pd_cache_step(forward_mode, save_kv_cache, record_kv_cache): + if forward_mode.is_decode(): + ret = backend.forward_decode( + q, + k, + v, + layer, + out_cache_loc, + token_to_kv_pool, + bs, + save_kv_cache=save_kv_cache, + **kwargs, + ) + else: + ret = backend.forward_extend( + q, + k, + v, + layer, + out_cache_loc, + token_to_kv_pool, + bs, + save_kv_cache=save_kv_cache, + forward_mode=forward_mode, + **kwargs, + ) + return ret def forward_decode( self, q, k, v, layer, out_cache_loc, token_to_kv_pool, bs, **kwargs diff --git a/python/tokenspeed/runtime/models/llama_eagle3.py b/python/tokenspeed/runtime/models/llama_eagle3.py index 258f7c703..54d3c92b1 100644 --- a/python/tokenspeed/runtime/models/llama_eagle3.py +++ b/python/tokenspeed/runtime/models/llama_eagle3.py @@ -99,6 +99,9 @@ def _attn( q_rope = self._fused_rope_kv_write( positions, q, k, fused_kv_arg ).index_select(0, ctx.gather_ids) + # record_kv_cache (keyed off the real mode) forces the backend's + # PD layerwise cache-step record that the DECODE dispatch would + # otherwise skip on an EXTEND/MIXED catch-up. return ctx.attn_backend.forward( q_rope, None, @@ -109,6 +112,7 @@ def _attn( ForwardMode.DECODE, ctx.bs, save_kv_cache=False, + record_kv_cache=not ctx.forward_mode.is_decode_or_idle(), ) q, k = self.rotary_emb(positions, q, k) return self.attn(q, k, v, ctx=ctx, out_cache_loc=out_cache_loc).index_select( @@ -126,7 +130,7 @@ def _apply_correction(self, ctx: ForwardContext) -> None: correction = ( ctx.attn_backend.spec_num_tokens - ctx.accept_lengths[num_extends:] ).to(seq_lens_buf.dtype) - seq_lens_buf[num_extends : ctx.bs].sub_(correction) + seq_lens_buf[num_extends : ctx.bs].sub_(correction).clamp_(min=1) # --------------------------------------------------------------------------- diff --git a/python/tokenspeed/runtime/models/qwen3_5.py b/python/tokenspeed/runtime/models/qwen3_5.py index 74d1b7b5f..5841e81b2 100644 --- a/python/tokenspeed/runtime/models/qwen3_5.py +++ b/python/tokenspeed/runtime/models/qwen3_5.py @@ -710,16 +710,13 @@ def _apply_qk_norm( self.q_norm.variance_epsilon, ) - def self_attention( + def _project_qkv_rope( self, positions: torch.Tensor, hidden_states: torch.Tensor, - ctx: ForwardContext, - out_cache_loc: torch.Tensor, - ) -> torch.Tensor: - """Full attention forward pass.""" + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]: + """qkv_proj + split + rope (+ optional gate). ``gate`` is ``None`` when ``attn_output_gate=False``.""" qkv, _ = self.qkv_proj(hidden_states) - if self.attn_output_gate: q_gate, k, v = qkv.split( [self.q_size * 2, self.kv_size, self.kv_size], dim=-1 @@ -737,23 +734,48 @@ def self_attention( self.head_dim, self.rotary_emb.rotary_dim, ) - else: - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self._apply_qk_norm(q, k) - q, k = self.rotary_emb(positions, q, k) + return q, k, v, gate + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self._apply_qk_norm(q, k) + q, k = self.rotary_emb(positions, q, k) + return q, k, v, None + def _attn( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + gate: torch.Tensor | None, + ctx: ForwardContext, + out_cache_loc: torch.Tensor, + ) -> torch.Tensor: + """Backend attention call + optional gate apply. Subclasses override.""" attn_output = self.attn(q, k, v, ctx, out_cache_loc) - - if self.attn_output_gate: + if gate is not None: sigmoid_mul(attn_output, gate) + return attn_output - if ctx.draft_first_step_reduce: - # Slice attn_output to [bs, H] so o_proj runs on live rows only. - attn_output = attn_output.index_select(0, ctx.gather_ids) - + def self_attention( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ctx: ForwardContext, + out_cache_loc: torch.Tensor, + ) -> torch.Tensor: + """Full attention forward pass.""" + q, k, v, gate = self._project_qkv_rope(positions, hidden_states) + attn_output = self._attn(q, k, v, gate, ctx, out_cache_loc) output, _ = self.o_proj(attn_output) return output + def _maybe_narrow_residual( + self, + residual: torch.Tensor, + ctx: ForwardContext, + ) -> torch.Tensor: + """Hook: subclasses narrow residual to match a sliced attn output.""" + return residual + def forward( self, positions: torch.Tensor, @@ -778,9 +800,7 @@ def forward( ctx=ctx, out_cache_loc=out_cache_loc, ) - if ctx.draft_first_step_reduce: - # Gather residual to self_attention's [bs, H]. - residual = residual.index_select(0, ctx.gather_ids) + residual = self._maybe_narrow_residual(residual, ctx) hidden_states, residual = self.comm_manager.post_attn_reduce_norm( hidden_states, residual, ctx ) @@ -816,15 +836,12 @@ def forward_mlp( return hidden_states -ALL_DECODER_LAYER_TYPES = { - "attention": Qwen3_5AttentionDecoderLayer, - "linear_attention": Qwen3_5LinearDecoderLayer, -} - - class Qwen3_5ForCausalLM(nn.Module): """Qwen3.5 Model with support for dense variant.""" + ATTENTION_LAYER_CLS: type = Qwen3_5AttentionDecoderLayer + LINEAR_LAYER_CLS: type = Qwen3_5LinearDecoderLayer + def __init__( self, config: Qwen3_5TextConfig, @@ -849,10 +866,14 @@ def __init__( tp_group=self.mapping.attn.tp_group, ) - # Decoder layers + layer_cls_by_type = { + "attention": self.ATTENTION_LAYER_CLS, + "linear_attention": self.LINEAR_LAYER_CLS, + } + def get_layer(idx: int, prefix: str): layer_type = config.layers_block_type[idx] - layer_class = ALL_DECODER_LAYER_TYPES[layer_type] + layer_class = layer_cls_by_type[layer_type] if layer_type == "attention": prefix = add_prefix("self_attn", prefix) else: diff --git a/python/tokenspeed/runtime/models/qwen3_5_nextn.py b/python/tokenspeed/runtime/models/qwen3_5_nextn.py index 52c485e50..9d56a7b38 100644 --- a/python/tokenspeed/runtime/models/qwen3_5_nextn.py +++ b/python/tokenspeed/runtime/models/qwen3_5_nextn.py @@ -22,13 +22,16 @@ import logging from collections.abc import Iterable +from dataclasses import replace import torch +from tokenspeed_kernel.ops.activation.triton import sigmoid_mul from torch import nn from transformers import PretrainedConfig from tokenspeed.runtime.distributed.mapping import Mapping from tokenspeed.runtime.execution.context import ForwardContext +from tokenspeed.runtime.execution.forward_batch_info import ForwardMode from tokenspeed.runtime.layers.layernorm import GemmaRMSNorm from tokenspeed.runtime.layers.linear import ReplicatedLinear from tokenspeed.runtime.layers.logits_processor import LogitsMetadata, LogitsProcessor @@ -38,12 +41,110 @@ ) from tokenspeed.runtime.layers.vocab_parallel_embedding import ParallelLMHead from tokenspeed.runtime.model_loader.weight_utils import default_weight_loader -from tokenspeed.runtime.models.qwen3_5 import Qwen3_5ForCausalLM +from tokenspeed.runtime.models.qwen3_5 import ( + Qwen3_5AttentionDecoderLayer, + Qwen3_5ForCausalLM, +) from tokenspeed.runtime.utils import add_prefix logger = logging.getLogger(__name__) +class Qwen3_5DraftAttentionDecoderLayer(Qwen3_5AttentionDecoderLayer): + """NextN draft variant: skip dead catch-up rows on the first draft step. + + On the first draft step the backend runs in DECODE mode with ``q`` sliced + to ``bs`` while ``self.attn`` still writes the full ``N`` rope-d KV rows + from the just-drafted tokens. Multi-step decode delegates to base. + + MIXED catch-up requires a backend that populates a decode-slot metadata + under EXTEND/MIXED at draft init (e.g. trtllm-mha); MHA-family backends + that assert ``not is_mixed()`` at metadata init are not supported. + """ + + def _attn( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + gate: torch.Tensor | None, + ctx: ForwardContext, + out_cache_loc: torch.Tensor, + ) -> torch.Tensor: + if ctx.accept_lengths is None: + return super()._attn(q, k, v, gate, ctx, out_cache_loc) + + self._apply_correction(ctx) + q = q.index_select(0, ctx.gather_ids) + if gate is not None: + gate = gate.index_select(0, ctx.gather_ids) + # Dispatch as DECODE over the sliced live rows via self.attn (see the + # class docstring), which keeps the standard k/v reshape and KV write. + # A ctx copy overrides only the forward mode; record_kv_cache (keyed off + # the real mode) forces the backend's PD layerwise cache-step record that + # DECODE would otherwise skip on an EXTEND/MIXED catch-up. + decode_ctx = replace(ctx, forward_mode=ForwardMode.DECODE) + attn_output = self.attn( + q, + k, + v, + decode_ctx, + out_cache_loc, + record_kv_cache=not ctx.forward_mode.is_decode_or_idle(), + ) + if gate is not None: + sigmoid_mul(attn_output, gate) + return attn_output + + def _apply_correction(self, ctx: ForwardContext) -> None: + """Trim decode rows' cache_seqlens by ``spec_num_tokens - accept_lengths``.""" + seq_lens_buf = ctx.draft_seq_lens_buf + if seq_lens_buf is None or ctx.accept_lengths is None: + return + num_extends = ctx.num_extends + if num_extends >= ctx.bs: + return + correction = ( + ctx.attn_backend.spec_num_tokens - ctx.accept_lengths[num_extends:] + ).to(seq_lens_buf.dtype) + seq_lens_buf[num_extends : ctx.bs].sub_(correction).clamp_(min=1) + + def _maybe_narrow_residual( + self, + residual: torch.Tensor, + ctx: ForwardContext, + ) -> torch.Tensor: + if ctx.accept_lengths is None or ctx.forward_mode.is_idle(): + return residual + return residual.index_select(0, ctx.gather_ids) + + +class Qwen3_5DraftForCausalLM(Qwen3_5ForCausalLM): + """Causal LM with the draft-variant attention layer injected. + + Restricted to single-layer drafts: ``_apply_correction`` mutates + ``ctx.draft_seq_lens_buf`` in place and is not idempotent across layers. + A multi-layer draft would double-trim cache_seqlens. Lift the correction + out of the per-layer hook (e.g. into the drafter) before relaxing this. + """ + + ATTENTION_LAYER_CLS: type = Qwen3_5DraftAttentionDecoderLayer + + def __init__( + self, + config, + mapping, + quant_config=None, + prefix: str = "", + ) -> None: + assert config.num_hidden_layers == 1, ( + "Qwen3_5DraftForCausalLM requires num_hidden_layers == 1 " + f"(got {config.num_hidden_layers}); _apply_correction is not " + "idempotent across layers." + ) + super().__init__(config, mapping, quant_config=quant_config, prefix=prefix) + + class Qwen3_5ForConditionalGenerationNextN(nn.Module): def __init__( self, @@ -74,7 +175,7 @@ def __init__( self.pre_fc_norm_hidden = RMSNorm_cls(config.hidden_size, config.rms_norm_eps) config.num_hidden_layers = 1 config.full_attention_interval = 1 - self.model = Qwen3_5ForCausalLM( + self.model = Qwen3_5DraftForCausalLM( config, mapping=self.mapping, quant_config=quant_config,