From d7882e7b54fabd1c1412a71b6bb4b743180ef4bc Mon Sep 17 00:00:00 2001 From: rjzhb Date: Tue, 9 Jun 2026 00:24:52 +0000 Subject: [PATCH 01/20] refactor(spec-decode): wrap Eagle3 attention via base llama._attn Signed-off-by: rjzhb --- .../tokenspeed/runtime/execution/context.py | 9 +- .../runtime/execution/drafter/eagle.py | 17 +- .../runtime/models/base/causal_lm.py | 4 - python/tokenspeed/runtime/models/llama.py | 69 +++++- .../tokenspeed/runtime/models/llama_eagle3.py | 227 ++++++------------ 5 files changed, 143 insertions(+), 183 deletions(-) diff --git a/python/tokenspeed/runtime/execution/context.py b/python/tokenspeed/runtime/execution/context.py index 194a9709d..67a889e4d 100644 --- a/python/tokenspeed/runtime/execution/context.py +++ b/python/tokenspeed/runtime/execution/context.py @@ -50,7 +50,8 @@ class ForwardContext: forward_mode: ForwardMode | None req_to_page: torch.Tensor | None = None capture_hidden_mode: CaptureHiddenMode | None = CaptureHiddenMode.NULL - # Spec decode draft head's first step prunes to one live row per request. + # Legacy draft first-step flag; Qwen / DeepSeek NextN still set this until + # their attention subclasses own trim. Llama Eagle3 uses accept_lengths. draft_first_step_reduce: bool = False # Normalized explicit decode input overrides for this forward, if any. decode_input_ids: list[int] | None = None @@ -62,3 +63,9 @@ class ForwardContext: # --- logits processor --- gather_ids: torch.Tensor | None = None + + # --- spec-decode draft (drafter-owned buffers plumbed per forward) --- + # draft_seq_lens_buf: mutable per-request seq_lens alias the draft backend reads. + draft_seq_lens_buf: torch.Tensor | None = None + # accept_lengths: per-request accepted verify width for cache_seqlens correction. + accept_lengths: torch.Tensor | None = None diff --git a/python/tokenspeed/runtime/execution/drafter/eagle.py b/python/tokenspeed/runtime/execution/drafter/eagle.py index 0aab7a1d9..6466c8290 100644 --- a/python/tokenspeed/runtime/execution/drafter/eagle.py +++ b/python/tokenspeed/runtime/execution/drafter/eagle.py @@ -36,7 +36,6 @@ CaptureHiddenMode, ForwardMode, ) -from tokenspeed.runtime.models.base import BaseCausalLM from tokenspeed.runtime.multimodal.inputs import maybe_substitute_mm_pad from tokenspeed.runtime.utils import get_colorful_logger from tokenspeed.runtime.utils.nvtx import nvtx_range @@ -221,20 +220,6 @@ def _run_first_step( input_ids = maybe_substitute_mm_pad(input_ids, self.mm_pad_substitute_id) draft_first_step_reduce = forward_mode.is_decode() - # TODO: remove the isinstance/flag gate together with pre_attention_trim - # once Qwen NextN and DeepSeek V3 NextN also pre-slice q. - draft_model = self.draft_model_runner.model - if ( - draft_first_step_reduce - and self.attn_backend.support_kv_cache_prewrite - and isinstance(draft_model, BaseCausalLM) - and draft_model.pre_attention_trim - ): - correction = (self.spec_num_tokens - draft_input.accept_lengths).to( - self.draft_seq_lens_buf.dtype - ) - self.draft_seq_lens_buf[:bs].sub_(correction) - draft_first_mode = ( ForwardMode.DRAFT_EXTEND if forward_mode.is_target_verify() @@ -255,6 +240,8 @@ def _run_first_step( global_bs=draft_input.global_bs, all_decode_or_idle=draft_input.all_decode_or_idle, draft_first_step_reduce=draft_first_step_reduce, + draft_seq_lens_buf=self.draft_seq_lens_buf, + accept_lengths=draft_input.accept_lengths, ) return self.draft_model_runner.forward( diff --git a/python/tokenspeed/runtime/models/base/causal_lm.py b/python/tokenspeed/runtime/models/base/causal_lm.py index b87e40e96..b83ce4be0 100644 --- a/python/tokenspeed/runtime/models/base/causal_lm.py +++ b/python/tokenspeed/runtime/models/base/causal_lm.py @@ -44,10 +44,6 @@ class BaseCausalLM(nn.Module): model_cls: type[BaseTransformerModel] - # TODO: temporary; remove in the follow-up refactoring that extends - # pre-attn q-slice to Qwen NextN and DeepSeek V3 NextN. - pre_attention_trim: bool = False - def __init__( self, config: PretrainedConfig, diff --git a/python/tokenspeed/runtime/models/llama.py b/python/tokenspeed/runtime/models/llama.py index 1baf785bf..0fc39e1f1 100644 --- a/python/tokenspeed/runtime/models/llama.py +++ b/python/tokenspeed/runtime/models/llama.py @@ -52,7 +52,9 @@ BaseDecoderLayer, BaseTransformerModel, ) +from tokenspeed.runtime.models.utils import create_fused_set_kv_buffer_arg from tokenspeed.runtime.utils import add_prefix +from tokenspeed.runtime.utils.pdl import pdl_enabled class LlamaMLP(nn.Module): @@ -119,6 +121,7 @@ def __init__( layer_id: int = 0, quant_config: QuantizationConfig | None = None, prefix: str = "", + qkv_input_size: int | None = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -151,7 +154,7 @@ def __init__( attention_bias = getattr(config, "attention_bias", False) self.qkv_proj = QKVParallelLinear( - hidden_size, + qkv_input_size or hidden_size, self.head_dim, self.total_num_heads, self.total_num_kv_heads, @@ -199,14 +202,72 @@ def forward( # the batch row is empty (e.g. idle ranks under DP attention). Matches # the short-circuit ``LlamaMLP.forward`` already has. if hidden_states.shape[0] == 0: - return hidden_states + return hidden_states.new_zeros( + (0, self.hidden_size), dtype=hidden_states.dtype + ) qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, ctx=ctx, out_cache_loc=out_cache_loc) + attn_output = self._attn(positions, q, k, v, ctx, out_cache_loc) output, _ = self.o_proj(attn_output) return output + def _attn( + self, + positions: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + ctx: ForwardContext, + out_cache_loc: torch.Tensor, + ) -> torch.Tensor: + """RoPE + attention (pre-o_proj), with optional fused KV pre-write. + + When the backend supports KV pre-write, fused rope writes KV directly + into the cache so the attention call can run with ``save_kv_cache=False`` + (saves one kernel launch). Subclasses (e.g. Eagle3 draft head) override + this hook to insert spec-decode behaviour around the same scaffolding. + """ + if ctx.attn_backend.support_kv_cache_prewrite(ctx.forward_mode): + q_rope = self._fused_rope_kv_write(positions, q, k, v, ctx, out_cache_loc) + return self.attn( + q_rope, + None, + None, + save_kv_cache=False, + ctx=ctx, + out_cache_loc=out_cache_loc, + ) + q, k = self.rotary_emb(positions, q, k) + return self.attn(q, k, v, ctx=ctx, out_cache_loc=out_cache_loc) + + def _fused_rope_kv_write( + self, + positions: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + ctx: ForwardContext, + out_cache_loc: torch.Tensor, + ) -> torch.Tensor: + """Fused RoPE that writes KV into cache and returns the rope'd Q.""" + n = q.shape[0] + fused_kv_arg = create_fused_set_kv_buffer_arg( + value=v.view(n, self.num_kv_heads, self.head_dim), + layer=self.attn, + out_cache_loc=out_cache_loc, + token_to_kv_pool=ctx.token_to_kv_pool, + ) + q_rope = torch.empty((n, self.q_size), dtype=q.dtype, device=q.device) + self.rotary_emb( + positions, + q, + k, + fused_set_kv_buffer_arg=fused_kv_arg, + output_q_rope=q_rope, + enable_pdl=pdl_enabled(), + ) + return q_rope + class LlamaDecoderLayer(BaseDecoderLayer): diff --git a/python/tokenspeed/runtime/models/llama_eagle3.py b/python/tokenspeed/runtime/models/llama_eagle3.py index 9c160f53b..b633bfbb8 100644 --- a/python/tokenspeed/runtime/models/llama_eagle3.py +++ b/python/tokenspeed/runtime/models/llama_eagle3.py @@ -32,7 +32,6 @@ from torch import nn from transformers import LlamaConfig -from tokenspeed.runtime.configs.utils import get_rope_theta from tokenspeed.runtime.distributed.mapping import Mapping from tokenspeed.runtime.execution.context import ForwardContext from tokenspeed.runtime.execution.forward_batch_info import ForwardMode @@ -42,12 +41,9 @@ from tokenspeed.runtime.layers.linear import ( ColumnParallelLinear, MergedColumnParallelLinear, - QKVParallelLinear, RowParallelLinear, ) -from tokenspeed.runtime.layers.paged_attention import PagedAttention from tokenspeed.runtime.layers.quantization.base_config import QuantizationConfig -from tokenspeed.runtime.layers.rotary_embedding import get_rope from tokenspeed.runtime.layers.vocab_parallel_embedding import ParallelLMHead from tokenspeed.runtime.model_loader.weight_utils import default_weight_loader from tokenspeed.runtime.models.base import ( @@ -55,9 +51,8 @@ BaseDecoderLayer, BaseTransformerModel, ) -from tokenspeed.runtime.models.utils import create_fused_set_kv_buffer_arg +from tokenspeed.runtime.models.llama import LlamaAttention as BaseLlamaAttention from tokenspeed.runtime.utils import add_prefix, get_colorful_logger -from tokenspeed.runtime.utils.pdl import pdl_enabled logger = get_colorful_logger(__name__) @@ -67,160 +62,72 @@ # --------------------------------------------------------------------------- -class LlamaAttention(nn.Module): +class LlamaAttention(BaseLlamaAttention): + """Eagle3 draft head attention. - def __init__( - self, - config: LlamaConfig, - mapping: Mapping, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - layer_id: int = 0, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - qkv_input_size: int | None = None, - ) -> None: + Inherits ``__init__`` (with ``qkv_input_size=2*hidden_size`` to accommodate + the [embed || hidden] concat) and ``forward`` (= empty-shape handling + + qkv projection + o_proj scaffolding) from base. - super().__init__() - self.hidden_size = hidden_size - - self.attn_tp_size = mapping.attn.tp_size - self.attn_tp_rank = mapping.attn.tp_rank - attn_tp_group = mapping.attn.tp_group - - self.total_num_heads = num_heads - assert self.total_num_heads % self.attn_tp_size == 0 - self.num_heads = self.total_num_heads // self.attn_tp_size - self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= self.attn_tp_size: - assert self.total_num_kv_heads % self.attn_tp_size == 0 - else: - assert self.attn_tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // self.attn_tp_size) - self.head_dim = getattr( - config, "head_dim", self.hidden_size // self.total_num_heads - ) - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - - rope_theta = get_rope_theta(config) - rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", 8192) - - self.qkv_proj = QKVParallelLinear( - qkv_input_size or hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=False, - quant_config=quant_config, - tp_rank=self.attn_tp_rank, - tp_size=self.attn_tp_size, - tp_group=attn_tp_group, - prefix=add_prefix("qkv_proj", prefix), - ) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=False, - tp_rank=self.attn_tp_rank, - tp_size=self.attn_tp_size, - tp_group=attn_tp_group, - prefix=add_prefix("o_proj", prefix), - ) - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - ) - self.attn = PagedAttention( - self.num_heads, - self.head_dim, - self.head_dim**-0.5, - num_kv_heads=self.num_kv_heads, - layer_id=layer_id, - ) + Overrides ``_attn`` to add the draft-first-step dispatch B: trim + cache_seqlens, slice q to one live row per request, and route the attention + call as ``DECODE`` (fused KV pre-write path) or post-slice attn_output + (non-fused fallback). Inactive draft steps delegate to base. + """ - def forward( + def _attn( self, positions: torch.Tensor, - hidden_states: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, ctx: ForwardContext, out_cache_loc: torch.Tensor, ) -> torch.Tensor: - - if hidden_states.shape[0] == 0: - # Under DP attention the caller concatenates [embeds, hidden_states] - # to width 2*H before attention. Peers with N>0 return an H-wide - # tensor from o_proj; idle ranks must match that invariant so the - # subsequent dense-TP RSAG agrees on the last dim. - return hidden_states.new_zeros( - (0, self.hidden_size), dtype=hidden_states.dtype - ) - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - fused_kv_arg = None + # Active draft first step (= target verify was pure decode + drafter set + # up gather_ids). Other forwards (multi-step decode without gather_ids, + # or first step after a target verify with extends) delegate to base. + if ctx.gather_ids is None or not ctx.forward_mode.is_decode(): + return super()._attn(positions, q, k, v, ctx, out_cache_loc) + + # Active dispatch B: correction + q-slice + DECODE (fused) or post-slice (non-fused). + self._apply_correction(ctx) if ctx.attn_backend.support_kv_cache_prewrite(ctx.forward_mode): - n = q.shape[0] - v_3d = v.view(n, self.num_kv_heads, self.head_dim) - fused_kv_arg = create_fused_set_kv_buffer_arg( - value=v_3d, - layer=self.attn, - out_cache_loc=out_cache_loc, - token_to_kv_pool=ctx.token_to_kv_pool, + q_rope = self._fused_rope_kv_write(positions, q, k, v, ctx, out_cache_loc) + q_rope = q_rope.index_select(0, ctx.gather_ids) + return ctx.attn_backend.forward( + q_rope, + None, + None, + self.attn, + out_cache_loc, + ctx.token_to_kv_pool, + ForwardMode.DECODE, + ctx.bs, + save_kv_cache=False, ) + q, k = self.rotary_emb(positions, q, k) + return self.attn(q, k, v, ctx=ctx, out_cache_loc=out_cache_loc).index_select( + 0, ctx.gather_ids + ) - if fused_kv_arg is not None: - n = q.shape[0] - q_rope = torch.empty((n, self.q_size), dtype=q.dtype, device=q.device) - q, k = self.rotary_emb( - positions, - q, - k, - fused_set_kv_buffer_arg=fused_kv_arg, - output_q_rope=q_rope, - enable_pdl=pdl_enabled(), - ) - if ctx.draft_first_step_reduce: - # KV already written via fused_set_kv_buffer_arg above; slice Q - # to one query per request and route attn as decode. - q_rope = q_rope.index_select(0, ctx.gather_ids) - attn_output = ctx.attn_backend.forward( - q_rope, - None, - None, - self.attn, - out_cache_loc, - ctx.token_to_kv_pool, - ForwardMode.DECODE, - ctx.bs, - save_kv_cache=False, - ) - else: - attn_output = self.attn( - q_rope, - None, - None, - save_kv_cache=False, - ctx=ctx, - out_cache_loc=out_cache_loc, - ) - else: - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, ctx=ctx, out_cache_loc=out_cache_loc) - if ctx.draft_first_step_reduce: - # KV written by self.attn above; slice attn_output so o_proj - # and the rest of the layer only run on the live rows. - attn_output = attn_output.index_select(0, ctx.gather_ids) + def _apply_correction(self, ctx: ForwardContext) -> None: + """Trim cache_seqlens by ``spec_num_tokens - accept_lengths``. - output, _ = self.o_proj(attn_output) - return output + Idempotent across draft layers via ``ctx.accept_lengths = None``. + No-op when drafter did not populate ``accept_lengths`` (e.g. backends + without KV pre-write). + """ + if ctx.accept_lengths is None: + return + seq_lens_buf = ctx.draft_seq_lens_buf + if seq_lens_buf is None: + return + correction = (ctx.attn_backend.spec_num_tokens - ctx.accept_lengths).to( + seq_lens_buf.dtype + ) + seq_lens_buf[: ctx.bs].sub_(correction) + ctx.accept_lengths = None # --------------------------------------------------------------------------- @@ -348,6 +255,17 @@ def resolve_mlp(self, prefix: str) -> nn.Module: prefix=f"{prefix}.mlp", ) + def _maybe_narrow_residual( + self, + residual: torch.Tensor, + ctx: ForwardContext, + ) -> torch.Tensor: + """Wrapper: align residual with attn output narrowed to [bs, H].""" + if ctx.draft_first_step_reduce and not ctx.forward_mode.is_idle(): + # Gather residual to self_attn's [bs, H]; idle has no gather_ids. + return residual.index_select(0, ctx.gather_ids) + return residual + def forward_low_latency( self, positions: torch.Tensor, @@ -383,9 +301,7 @@ def forward_low_latency( ctx=ctx, out_cache_loc=out_cache_loc, ) - if ctx.draft_first_step_reduce and not ctx.forward_mode.is_idle(): - # Gather residual to self_attn's [bs, H]; idle has no gather_ids. - residual = residual.index_select(0, ctx.gather_ids) + residual = self._maybe_narrow_residual(residual, ctx) # Fused post-attn allreduce + norm (uses attn tp group) block_scale = None @@ -451,9 +367,7 @@ def forward( ctx=ctx, out_cache_loc=out_cache_loc, ) - if ctx.draft_first_step_reduce and not ctx.forward_mode.is_idle(): - # Gather residual to self_attn's [bs, H]; idle has no gather_ids. - residual = residual.index_select(0, ctx.gather_ids) + residual = self._maybe_narrow_residual(residual, ctx) hidden_states, residual = self.comm_manager.post_attn_comm( hidden_states, residual, ctx ) @@ -583,11 +497,6 @@ class LlamaForCausalLMEagle3(BaseCausalLM): model_cls = Eagle3LlamaModel - # Eagle3 catch-up pre-slices q to active row before attn; trim must pair. - # TODO: remove together with the base flag once Qwen NextN / DeepSeek V3 - # NextN also pre-slice. - pre_attention_trim: bool = True - def __init__( self, config: LlamaConfig, From 6fbab67283600e17691ca30d047f55560a061328 Mon Sep 17 00:00:00 2001 From: rjzhb Date: Tue, 9 Jun 2026 01:13:11 +0000 Subject: [PATCH 02/20] update Signed-off-by: rjzhb --- python/tokenspeed/runtime/models/llama_eagle3.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tokenspeed/runtime/models/llama_eagle3.py b/python/tokenspeed/runtime/models/llama_eagle3.py index b633bfbb8..2db130705 100644 --- a/python/tokenspeed/runtime/models/llama_eagle3.py +++ b/python/tokenspeed/runtime/models/llama_eagle3.py @@ -93,8 +93,9 @@ def _attn( # Active dispatch B: correction + q-slice + DECODE (fused) or post-slice (non-fused). self._apply_correction(ctx) if ctx.attn_backend.support_kv_cache_prewrite(ctx.forward_mode): - q_rope = self._fused_rope_kv_write(positions, q, k, v, ctx, out_cache_loc) - q_rope = q_rope.index_select(0, ctx.gather_ids) + q_rope = self._fused_rope_kv_write( + positions, q, k, v, ctx, out_cache_loc + ).index_select(0, ctx.gather_ids) return ctx.attn_backend.forward( q_rope, None, From 399b793e959079297cf835023c6d3ddab6c38b9d Mon Sep 17 00:00:00 2001 From: rjzhb Date: Tue, 9 Jun 2026 03:54:14 +0000 Subject: [PATCH 03/20] feat(spec-decode): extend Llama Eagle3 dispatch B to prefill catch-up Signed-off-by: rjzhb --- .../runtime/execution/drafter/eagle.py | 6 +++- .../tokenspeed/runtime/models/llama_eagle3.py | 36 ++++++++----------- 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/python/tokenspeed/runtime/execution/drafter/eagle.py b/python/tokenspeed/runtime/execution/drafter/eagle.py index 6466c8290..9264b1788 100644 --- a/python/tokenspeed/runtime/execution/drafter/eagle.py +++ b/python/tokenspeed/runtime/execution/drafter/eagle.py @@ -218,7 +218,11 @@ def _run_first_step( draft_input, bs, draft_input.input_num_tokens ) input_ids = maybe_substitute_mm_pad(input_ids, self.mm_pad_substitute_id) - draft_first_step_reduce = forward_mode.is_decode() + # Llama Eagle3 narrows for prefill catch-up too; Qwen/DeepSeek do not. + draft_first_step_reduce = forward_mode.is_decode() or ( + isinstance(self.draft_model_runner.model, LlamaForCausalLMEagle3) + and forward_mode.is_target_verify() + ) draft_first_mode = ( ForwardMode.DRAFT_EXTEND diff --git a/python/tokenspeed/runtime/models/llama_eagle3.py b/python/tokenspeed/runtime/models/llama_eagle3.py index 2db130705..e6494ef13 100644 --- a/python/tokenspeed/runtime/models/llama_eagle3.py +++ b/python/tokenspeed/runtime/models/llama_eagle3.py @@ -84,10 +84,10 @@ def _attn( ctx: ForwardContext, out_cache_loc: torch.Tensor, ) -> torch.Tensor: - # Active draft first step (= target verify was pure decode + drafter set - # up gather_ids). Other forwards (multi-step decode without gather_ids, - # or first step after a target verify with extends) delegate to base. - if ctx.gather_ids is None or not ctx.forward_mode.is_decode(): + # Active draft first step (drafter set up gather_ids + accept_lengths). + # Covers both decode catch-up and prefill catch-up; multi-step decode + # delegates to base. + if ctx.accept_lengths is None: return super()._attn(positions, q, k, v, ctx, out_cache_loc) # Active dispatch B: correction + q-slice + DECODE (fused) or post-slice (non-fused). @@ -113,22 +113,17 @@ def _attn( ) def _apply_correction(self, ctx: ForwardContext) -> None: - """Trim cache_seqlens by ``spec_num_tokens - accept_lengths``. - - Idempotent across draft layers via ``ctx.accept_lengths = None``. - No-op when drafter did not populate ``accept_lengths`` (e.g. backends - without KV pre-write). - """ - if ctx.accept_lengths is None: - return + """Trim decode rows' cache_seqlens by ``spec_num_tokens - accept_lengths``.""" seq_lens_buf = ctx.draft_seq_lens_buf - if seq_lens_buf is None: + if seq_lens_buf is None or ctx.accept_lengths is None: return - correction = (ctx.attn_backend.spec_num_tokens - ctx.accept_lengths).to( - seq_lens_buf.dtype - ) - seq_lens_buf[: ctx.bs].sub_(correction) - ctx.accept_lengths = None + num_extends = ctx.num_extends + if num_extends >= ctx.bs: + return + correction = ( + ctx.attn_backend.spec_num_tokens - ctx.accept_lengths[num_extends:] + ).to(seq_lens_buf.dtype) + seq_lens_buf[num_extends : ctx.bs].sub_(correction) # --------------------------------------------------------------------------- @@ -261,9 +256,8 @@ def _maybe_narrow_residual( residual: torch.Tensor, ctx: ForwardContext, ) -> torch.Tensor: - """Wrapper: align residual with attn output narrowed to [bs, H].""" - if ctx.draft_first_step_reduce and not ctx.forward_mode.is_idle(): - # Gather residual to self_attn's [bs, H]; idle has no gather_ids. + """Align residual with attn output narrowed to [bs, H].""" + if ctx.accept_lengths is not None and not ctx.forward_mode.is_idle(): return residual.index_select(0, ctx.gather_ids) return residual From 90ec73a083e66905147e5992382f1829f74c2380 Mon Sep 17 00:00:00 2001 From: rjzhb Date: Tue, 9 Jun 2026 04:04:48 +0000 Subject: [PATCH 04/20] update Signed-off-by: rjzhb --- python/tokenspeed/runtime/execution/drafter/eagle.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tokenspeed/runtime/execution/drafter/eagle.py b/python/tokenspeed/runtime/execution/drafter/eagle.py index 9264b1788..99e8e8b24 100644 --- a/python/tokenspeed/runtime/execution/drafter/eagle.py +++ b/python/tokenspeed/runtime/execution/drafter/eagle.py @@ -27,6 +27,7 @@ from tokenspeed_kernel.ops.sampling.cute_dsl import argmax as cute_argmax from typing_extensions import override +from python.tokenspeed.runtime.models.llama_eagle3 import LlamaForCausalLMEagle3 from tokenspeed.runtime.execution.cache_loc_kernel import ( compute_out_cache_loc_uniform, ) From 2ca7ebe69370ce25e6f24e8f0aadcbc6e1b3ce38 Mon Sep 17 00:00:00 2001 From: rjzhb Date: Tue, 9 Jun 2026 19:51:17 +0000 Subject: [PATCH 05/20] fix(spec-decode): correct LlamaForCausalLMEagle3 import path Signed-off-by: rjzhb --- python/tokenspeed/runtime/execution/drafter/eagle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tokenspeed/runtime/execution/drafter/eagle.py b/python/tokenspeed/runtime/execution/drafter/eagle.py index 7a5494b0a..db79965a1 100644 --- a/python/tokenspeed/runtime/execution/drafter/eagle.py +++ b/python/tokenspeed/runtime/execution/drafter/eagle.py @@ -27,7 +27,6 @@ from tokenspeed_kernel.ops.sampling import argmax as sampling_argmax from typing_extensions import override -from python.tokenspeed.runtime.models.llama_eagle3 import LlamaForCausalLMEagle3 from tokenspeed.runtime.execution.cache_loc_kernel import ( compute_out_cache_loc_uniform, ) @@ -37,6 +36,7 @@ CaptureHiddenMode, ForwardMode, ) +from tokenspeed.runtime.models.llama_eagle3 import LlamaForCausalLMEagle3 from tokenspeed.runtime.multimodal.inputs import maybe_substitute_mm_pad from tokenspeed.runtime.utils import get_colorful_logger from tokenspeed.runtime.utils.nvtx import nvtx_range From cc698bf429e110c1fc36f5ac192da54d3d14c46e Mon Sep 17 00:00:00 2001 From: rjzhb Date: Tue, 9 Jun 2026 20:04:56 +0000 Subject: [PATCH 06/20] fix(spec-decode): cover EXTEND/MIXED catch-up in dispatch B flag broaden Signed-off-by: rjzhb --- python/tokenspeed/runtime/execution/drafter/eagle.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tokenspeed/runtime/execution/drafter/eagle.py b/python/tokenspeed/runtime/execution/drafter/eagle.py index db79965a1..ef96a782b 100644 --- a/python/tokenspeed/runtime/execution/drafter/eagle.py +++ b/python/tokenspeed/runtime/execution/drafter/eagle.py @@ -219,10 +219,11 @@ def _run_first_step( draft_input, bs, draft_input.input_num_tokens ) input_ids = maybe_substitute_mm_pad(input_ids, self.mm_pad_substitute_id) - # Llama Eagle3 narrows for prefill catch-up too; Qwen/DeepSeek do not. + # Llama Eagle3 narrows for any non-idle catch-up (EXTEND/MIXED/ + # TARGET_VERIFY/DECODE); Qwen/DeepSeek keep is_decode() only. draft_first_step_reduce = forward_mode.is_decode() or ( isinstance(self.draft_model_runner.model, LlamaForCausalLMEagle3) - and forward_mode.is_target_verify() + and not forward_mode.is_idle() ) draft_first_mode = ( From 5ac5d7976fb8a3710899b3ff942e7c4683375c76 Mon Sep 17 00:00:00 2001 From: rjzhb Date: Wed, 10 Jun 2026 19:11:01 +0000 Subject: [PATCH 07/20] fix(spec-decode): fall back when fused KV prewrite arg is None Signed-off-by: rjzhb --- python/tokenspeed/runtime/models/llama.py | 57 ++++++++++++------- .../tokenspeed/runtime/models/llama_eagle3.py | 30 +++++----- 2 files changed, 52 insertions(+), 35 deletions(-) diff --git a/python/tokenspeed/runtime/models/llama.py b/python/tokenspeed/runtime/models/llama.py index 0fc39e1f1..3f99c7eb6 100644 --- a/python/tokenspeed/runtime/models/llama.py +++ b/python/tokenspeed/runtime/models/llama.py @@ -222,41 +222,56 @@ def _attn( ) -> torch.Tensor: """RoPE + attention (pre-o_proj), with optional fused KV pre-write. - When the backend supports KV pre-write, fused rope writes KV directly - into the cache so the attention call can run with ``save_kv_cache=False`` - (saves one kernel launch). Subclasses (e.g. Eagle3 draft head) override - this hook to insert spec-decode behaviour around the same scaffolding. + When the backend supports KV pre-write *and* ``create_fused_set_kv_buffer_arg`` + accepts the layer's scales, fused rope writes KV directly into the cache + so the attention call can run with ``save_kv_cache=False`` (saves one + kernel launch). Otherwise we fall back to plain RoPE + ``self.attn(q, k, v)`` + so the backend writes KV the normal way — without this fallback, layers + with non-trivial k/v scales silently lose their KV writes. Subclasses + (e.g. Eagle3 draft head) override this hook to insert spec-decode + behaviour around the same scaffolding. """ if ctx.attn_backend.support_kv_cache_prewrite(ctx.forward_mode): - q_rope = self._fused_rope_kv_write(positions, q, k, v, ctx, out_cache_loc) - return self.attn( - q_rope, - None, - None, - save_kv_cache=False, - ctx=ctx, - out_cache_loc=out_cache_loc, - ) + fused_kv_arg = self._build_fused_kv_arg(v, ctx, out_cache_loc) + if fused_kv_arg is not None: + q_rope = self._fused_rope_kv_write(positions, q, k, fused_kv_arg) + return self.attn( + q_rope, + None, + None, + save_kv_cache=False, + ctx=ctx, + out_cache_loc=out_cache_loc, + ) q, k = self.rotary_emb(positions, q, k) return self.attn(q, k, v, ctx=ctx, out_cache_loc=out_cache_loc) - def _fused_rope_kv_write( + def _build_fused_kv_arg( self, - positions: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, v: torch.Tensor, ctx: ForwardContext, out_cache_loc: torch.Tensor, - ) -> torch.Tensor: - """Fused RoPE that writes KV into cache and returns the rope'd Q.""" - n = q.shape[0] - fused_kv_arg = create_fused_set_kv_buffer_arg( + ): + """Try to build the fused RoPE+KV-write descriptor; returns ``None`` if + the helper rejects the layer (e.g. non-trivial k/v scales).""" + n = v.shape[0] + return create_fused_set_kv_buffer_arg( value=v.view(n, self.num_kv_heads, self.head_dim), layer=self.attn, out_cache_loc=out_cache_loc, token_to_kv_pool=ctx.token_to_kv_pool, ) + + def _fused_rope_kv_write( + self, + positions: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + fused_kv_arg, + ) -> torch.Tensor: + """Fused RoPE that writes KV into cache (via ``fused_kv_arg``) and + returns the rope'd Q.""" + n = q.shape[0] q_rope = torch.empty((n, self.q_size), dtype=q.dtype, device=q.device) self.rotary_emb( positions, diff --git a/python/tokenspeed/runtime/models/llama_eagle3.py b/python/tokenspeed/runtime/models/llama_eagle3.py index e6494ef13..ee8942975 100644 --- a/python/tokenspeed/runtime/models/llama_eagle3.py +++ b/python/tokenspeed/runtime/models/llama_eagle3.py @@ -93,20 +93,22 @@ def _attn( # Active dispatch B: correction + q-slice + DECODE (fused) or post-slice (non-fused). self._apply_correction(ctx) if ctx.attn_backend.support_kv_cache_prewrite(ctx.forward_mode): - q_rope = self._fused_rope_kv_write( - positions, q, k, v, ctx, out_cache_loc - ).index_select(0, ctx.gather_ids) - return ctx.attn_backend.forward( - q_rope, - None, - None, - self.attn, - out_cache_loc, - ctx.token_to_kv_pool, - ForwardMode.DECODE, - ctx.bs, - save_kv_cache=False, - ) + fused_kv_arg = self._build_fused_kv_arg(v, ctx, out_cache_loc) + if fused_kv_arg is not None: + q_rope = self._fused_rope_kv_write( + positions, q, k, fused_kv_arg + ).index_select(0, ctx.gather_ids) + return ctx.attn_backend.forward( + q_rope, + None, + None, + self.attn, + out_cache_loc, + ctx.token_to_kv_pool, + ForwardMode.DECODE, + ctx.bs, + save_kv_cache=False, + ) q, k = self.rotary_emb(positions, q, k) return self.attn(q, k, v, ctx=ctx, out_cache_loc=out_cache_loc).index_select( 0, ctx.gather_ids From 555252f1fcf57741f09c6ae87a8322da71b38ad6 Mon Sep 17 00:00:00 2001 From: rjzhb Date: Wed, 10 Jun 2026 20:33:49 +0000 Subject: [PATCH 08/20] refactor(spec-decode): wrap Qwen3.5 NextN attention via base hooks Signed-off-by: rjzhb --- .../runtime/execution/drafter/eagle.py | 10 +- .../runtime/execution/model_executor.py | 19 +++- python/tokenspeed/runtime/models/qwen3_5.py | 75 ++++++++----- .../runtime/models/qwen3_5_nextn.py | 101 +++++++++++++++++- 4 files changed, 172 insertions(+), 33 deletions(-) diff --git a/python/tokenspeed/runtime/execution/drafter/eagle.py b/python/tokenspeed/runtime/execution/drafter/eagle.py index 67317f012..ac7892e37 100644 --- a/python/tokenspeed/runtime/execution/drafter/eagle.py +++ b/python/tokenspeed/runtime/execution/drafter/eagle.py @@ -37,6 +37,7 @@ ForwardMode, ) from tokenspeed.runtime.models.llama_eagle3 import LlamaForCausalLMEagle3 +from tokenspeed.runtime.models.qwen3_5_nextn import Qwen3_5ForConditionalGenerationNextN from tokenspeed.runtime.multimodal.inputs import maybe_substitute_mm_pad from tokenspeed.runtime.utils import get_colorful_logger from tokenspeed.runtime.utils.nvtx import nvtx_range @@ -218,10 +219,13 @@ def _run_first_step( draft_input, bs, draft_input.input_num_tokens ) input_ids = maybe_substitute_mm_pad(input_ids, self.mm_pad_substitute_id) - # Llama Eagle3 narrows for any non-idle catch-up (EXTEND/MIXED/ - # TARGET_VERIFY/DECODE); Qwen/DeepSeek keep is_decode() only. + # Llama Eagle3 and Qwen3.5 NextN narrow for any non-idle catch-up + # (EXTEND/MIXED/TARGET_VERIFY/DECODE); DeepSeek keeps is_decode() only. draft_first_step_reduce = forward_mode.is_decode() or ( - isinstance(self.draft_model_runner.model, LlamaForCausalLMEagle3) + isinstance( + self.draft_model_runner.model, + (LlamaForCausalLMEagle3, Qwen3_5ForConditionalGenerationNextN), + ) and not forward_mode.is_idle() ) diff --git a/python/tokenspeed/runtime/execution/model_executor.py b/python/tokenspeed/runtime/execution/model_executor.py index df9276488..39a64a329 100644 --- a/python/tokenspeed/runtime/execution/model_executor.py +++ b/python/tokenspeed/runtime/execution/model_executor.py @@ -46,6 +46,8 @@ from tokenspeed.runtime.execution.types import ModelExecutionResult from tokenspeed.runtime.grammar.capturable_grammar import setup_grammar_step from tokenspeed.runtime.layers.logits_processor import LogitsProcessorOutput +from tokenspeed.runtime.models.llama_eagle3 import LlamaForCausalLMEagle3 +from tokenspeed.runtime.models.qwen3_5_nextn import Qwen3_5ForConditionalGenerationNextN from tokenspeed.runtime.sampling.backends.base import SamplingBackend from tokenspeed.runtime.sampling.dp_sampling_config import ( DpSamplingRuntimeLimits, @@ -1078,7 +1080,22 @@ def execute_idle_forward( global_num_tokens=draft_global_num_tokens, global_bs=global_bs, all_decode_or_idle=all_decode_or_idle, - draft_first_step_reduce=(step_idx == 0 and all_decode_or_idle), + # Mirror the active-rank broaden in eagle.py: Llama Eagle3 + # and Qwen3.5 NextN narrow for any non-idle catch-up, so the + # idle peer must size collectives the same way. + draft_first_step_reduce=( + step_idx == 0 + and ( + all_decode_or_idle + or isinstance( + self.drafter.draft_model_runner.model, + ( + LlamaForCausalLMEagle3, + Qwen3_5ForConditionalGenerationNextN, + ), + ) + ) + ), ) self.drafter.draft_model_runner.forward( draft_ctx, diff --git a/python/tokenspeed/runtime/models/qwen3_5.py b/python/tokenspeed/runtime/models/qwen3_5.py index 28a6ebd97..be12bd108 100644 --- a/python/tokenspeed/runtime/models/qwen3_5.py +++ b/python/tokenspeed/runtime/models/qwen3_5.py @@ -710,16 +710,13 @@ def _apply_qk_norm( self.q_norm.variance_epsilon, ) - def self_attention( + def _project_qkv_rope( self, positions: torch.Tensor, hidden_states: torch.Tensor, - ctx: ForwardContext, - out_cache_loc: torch.Tensor, - ) -> torch.Tensor: - """Full attention forward pass.""" + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]: + """qkv_proj + split + rope (+ optional gate). ``gate`` is ``None`` when ``attn_output_gate=False``.""" qkv, _ = self.qkv_proj(hidden_states) - if self.attn_output_gate: q_gate, k, v = qkv.split( [self.q_size * 2, self.kv_size, self.kv_size], dim=-1 @@ -737,23 +734,48 @@ def self_attention( self.head_dim, self.rotary_emb.rotary_dim, ) - else: - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self._apply_qk_norm(q, k) - q, k = self.rotary_emb(positions, q, k) + return q, k, v, gate + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self._apply_qk_norm(q, k) + q, k = self.rotary_emb(positions, q, k) + return q, k, v, None + def _attn( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + gate: torch.Tensor | None, + ctx: ForwardContext, + out_cache_loc: torch.Tensor, + ) -> torch.Tensor: + """Backend attention call + optional gate apply. Subclasses override.""" attn_output = self.attn(q, k, v, ctx, out_cache_loc) - - if self.attn_output_gate: + if gate is not None: sigmoid_mul(attn_output, gate) + return attn_output - if ctx.draft_first_step_reduce: - # Slice attn_output to [bs, H] so o_proj runs on live rows only. - attn_output = attn_output.index_select(0, ctx.gather_ids) - + def self_attention( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ctx: ForwardContext, + out_cache_loc: torch.Tensor, + ) -> torch.Tensor: + """Full attention forward pass.""" + q, k, v, gate = self._project_qkv_rope(positions, hidden_states) + attn_output = self._attn(q, k, v, gate, ctx, out_cache_loc) output, _ = self.o_proj(attn_output) return output + def _maybe_narrow_residual( + self, + residual: torch.Tensor, + ctx: ForwardContext, + ) -> torch.Tensor: + """Hook: subclasses narrow residual to match a sliced attn output.""" + return residual + def forward( self, positions: torch.Tensor, @@ -778,9 +800,7 @@ def forward( ctx=ctx, out_cache_loc=out_cache_loc, ) - if ctx.draft_first_step_reduce: - # Gather residual to self_attention's [bs, H]. - residual = residual.index_select(0, ctx.gather_ids) + residual = self._maybe_narrow_residual(residual, ctx) hidden_states, residual = self.comm_manager.post_attn_reduce_norm( hidden_states, residual, ctx ) @@ -816,15 +836,12 @@ def forward_mlp( return hidden_states -ALL_DECODER_LAYER_TYPES = { - "attention": Qwen3_5AttentionDecoderLayer, - "linear_attention": Qwen3_5LinearDecoderLayer, -} - - class Qwen3_5ForCausalLM(nn.Module): """Qwen3.5 Model with support for dense variant.""" + ATTENTION_LAYER_CLS: type = Qwen3_5AttentionDecoderLayer + LINEAR_LAYER_CLS: type = Qwen3_5LinearDecoderLayer + def __init__( self, config: Qwen3_5TextConfig, @@ -849,10 +866,14 @@ def __init__( tp_group=self.mapping.attn.tp_group, ) - # Decoder layers + layer_cls_by_type = { + "attention": self.ATTENTION_LAYER_CLS, + "linear_attention": self.LINEAR_LAYER_CLS, + } + def get_layer(idx: int, prefix: str): layer_type = config.layers_block_type[idx] - layer_class = ALL_DECODER_LAYER_TYPES[layer_type] + layer_class = layer_cls_by_type[layer_type] if layer_type == "attention": prefix = add_prefix("self_attn", prefix) else: diff --git a/python/tokenspeed/runtime/models/qwen3_5_nextn.py b/python/tokenspeed/runtime/models/qwen3_5_nextn.py index 85826266c..9e544e4fb 100644 --- a/python/tokenspeed/runtime/models/qwen3_5_nextn.py +++ b/python/tokenspeed/runtime/models/qwen3_5_nextn.py @@ -24,11 +24,13 @@ from collections.abc import Iterable import torch +from tokenspeed_kernel.ops.activation.triton import sigmoid_mul from torch import nn from transformers import PretrainedConfig from tokenspeed.runtime.distributed.mapping import Mapping from tokenspeed.runtime.execution.context import ForwardContext +from tokenspeed.runtime.execution.forward_batch_info import ForwardMode from tokenspeed.runtime.layers.layernorm import GemmaRMSNorm from tokenspeed.runtime.layers.linear import ReplicatedLinear from tokenspeed.runtime.layers.logits_processor import LogitsMetadata, LogitsProcessor @@ -38,12 +40,107 @@ ) from tokenspeed.runtime.layers.vocab_parallel_embedding import ParallelLMHead from tokenspeed.runtime.model_loader.weight_utils import default_weight_loader -from tokenspeed.runtime.models.qwen3_5 import Qwen3_5ForCausalLM +from tokenspeed.runtime.models.qwen3_5 import ( + Qwen3_5AttentionDecoderLayer, + Qwen3_5ForCausalLM, +) from tokenspeed.runtime.utils import add_prefix logger = logging.getLogger(__name__) +class Qwen3_5DraftAttentionDecoderLayer(Qwen3_5AttentionDecoderLayer): + """NextN draft variant: skip dead catch-up rows on the first draft step. + + On the first draft step the backend runs in DECODE mode with ``q`` sliced + to ``bs`` while ``self.attn`` still writes the full ``N`` rope-d KV rows + from the just-drafted tokens. Multi-step decode delegates to base. + + MIXED catch-up requires a backend that populates a decode-slot metadata + under EXTEND/MIXED at draft init (e.g. trtllm-mha); MHA-family backends + that assert ``not is_mixed()`` at metadata init are not supported. + """ + + def _attn( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + gate: torch.Tensor | None, + ctx: ForwardContext, + out_cache_loc: torch.Tensor, + ) -> torch.Tensor: + if ctx.accept_lengths is None or ctx.forward_mode.is_idle(): + return super()._attn(q, k, v, gate, ctx, out_cache_loc) + + self._apply_correction(ctx) + q = q.index_select(0, ctx.gather_ids) + if gate is not None: + gate = gate.index_select(0, ctx.gather_ids) + attn_output = ctx.attn_backend.forward( + q, + k, + v, + self.attn, + out_cache_loc, + ctx.token_to_kv_pool, + ForwardMode.DECODE, + ctx.bs, + save_kv_cache=True, + ) + if gate is not None: + sigmoid_mul(attn_output, gate) + return attn_output + + def _apply_correction(self, ctx: ForwardContext) -> None: + """Trim decode rows' cache_seqlens by ``spec_num_tokens - accept_lengths``.""" + seq_lens_buf = ctx.draft_seq_lens_buf + if seq_lens_buf is None or ctx.accept_lengths is None: + return + num_extends = ctx.num_extends + if num_extends >= ctx.bs: + return + correction = ( + ctx.attn_backend.spec_num_tokens - ctx.accept_lengths[num_extends:] + ).to(seq_lens_buf.dtype) + seq_lens_buf[num_extends : ctx.bs].sub_(correction) + + def _maybe_narrow_residual( + self, + residual: torch.Tensor, + ctx: ForwardContext, + ) -> torch.Tensor: + if ctx.accept_lengths is None or ctx.forward_mode.is_idle(): + return residual + return residual.index_select(0, ctx.gather_ids) + + +class Qwen3_5DraftForCausalLM(Qwen3_5ForCausalLM): + """Causal LM with the draft-variant attention layer injected. + + Restricted to single-layer drafts: ``_apply_correction`` mutates + ``ctx.draft_seq_lens_buf`` in place and is not idempotent across layers. + A multi-layer draft would double-trim cache_seqlens. Lift the correction + out of the per-layer hook (e.g. into the drafter) before relaxing this. + """ + + ATTENTION_LAYER_CLS: type = Qwen3_5DraftAttentionDecoderLayer + + def __init__( + self, + config, + mapping, + quant_config=None, + prefix: str = "", + ) -> None: + assert config.num_hidden_layers == 1, ( + "Qwen3_5DraftForCausalLM requires num_hidden_layers == 1 " + f"(got {config.num_hidden_layers}); _apply_correction is not " + "idempotent across layers." + ) + super().__init__(config, mapping, quant_config=quant_config, prefix=prefix) + + class Qwen3_5ForConditionalGenerationNextN(nn.Module): def __init__( self, @@ -74,7 +171,7 @@ def __init__( self.pre_fc_norm_hidden = RMSNorm_cls(config.hidden_size, config.rms_norm_eps) config.num_hidden_layers = 1 config.full_attention_interval = 1 - self.model = Qwen3_5ForCausalLM( + self.model = Qwen3_5DraftForCausalLM( config, mapping=self.mapping, quant_config=quant_config, From fafc901f1ca09bc9f33343ce768ce8d82e09ed18 Mon Sep 17 00:00:00 2001 From: rjzhb Date: Thu, 11 Jun 2026 20:13:55 +0000 Subject: [PATCH 09/20] fix(qwen3.5-nextn): drop idle-mode early return in draft attn Signed-off-by: rjzhb --- python/tokenspeed/runtime/models/qwen3_5_nextn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tokenspeed/runtime/models/qwen3_5_nextn.py b/python/tokenspeed/runtime/models/qwen3_5_nextn.py index 9e544e4fb..942edb9c3 100644 --- a/python/tokenspeed/runtime/models/qwen3_5_nextn.py +++ b/python/tokenspeed/runtime/models/qwen3_5_nextn.py @@ -70,7 +70,7 @@ def _attn( ctx: ForwardContext, out_cache_loc: torch.Tensor, ) -> torch.Tensor: - if ctx.accept_lengths is None or ctx.forward_mode.is_idle(): + if ctx.accept_lengths is None: return super()._attn(q, k, v, gate, ctx, out_cache_loc) self._apply_correction(ctx) From 404279bd417a1bc84d7ef9fb1a26da8d0fd1e053 Mon Sep 17 00:00:00 2001 From: rjzhb Date: Fri, 12 Jun 2026 05:00:02 +0000 Subject: [PATCH 10/20] fix(spec-decode): pad decode seq_lens to spec_num_tokens Signed-off-by: rjzhb --- python/tokenspeed/runtime/execution/input_buffer.py | 11 +++++++++-- python/tokenspeed/runtime/execution/model_executor.py | 1 + 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/python/tokenspeed/runtime/execution/input_buffer.py b/python/tokenspeed/runtime/execution/input_buffer.py index c28624865..186bb72bc 100644 --- a/python/tokenspeed/runtime/execution/input_buffer.py +++ b/python/tokenspeed/runtime/execution/input_buffer.py @@ -56,6 +56,7 @@ def __init__( dummy_kv_slot: int, device: str = "cuda", has_mamba: bool = False, + min_padding_seq_len: int = 1, ): self.device = device self.page_size = page_size @@ -64,6 +65,10 @@ def __init__( self.max_bs = max_bs self.all_extends_mid_chunk = False self.has_mamba = has_mamba + # Padding rows' seq_lens; >= spec_num_tokens so the EAGLE catch-up trim + # (seq_lens -= spec_num_tokens - accept_len) can't underflow them. 1 for + # non-spec. + self.min_padding_seq_len = min_padding_seq_len with torch.device(device): # Initialise buffers to the *padding* values the captured graph @@ -385,7 +390,7 @@ def write_decode_input_ids( self.mrope_positions_buf[:, total_tokens:].zero_() if batch_size < self.max_bs: self.req_pool_indices_buf[batch_size:].fill_(0) - self.seq_lens_buf[batch_size:].fill_(1) + self.seq_lens_buf[batch_size:].fill_(self.min_padding_seq_len) if ( self.has_mamba @@ -437,4 +442,6 @@ def fill_dummy_decode_buffers(self, batch_size: int, total_tokens: int): # seq_lens must be >= spec_num_tokens so the drafter's prewrite # correction never goes negative. num_tokens_per_req = total_tokens // batch_size if batch_size > 0 else 1 - self.seq_lens_buf[:batch_size].fill_(max(num_tokens_per_req, 1)) + self.seq_lens_buf[:batch_size].fill_( + max(num_tokens_per_req, self.min_padding_seq_len) + ) diff --git a/python/tokenspeed/runtime/execution/model_executor.py b/python/tokenspeed/runtime/execution/model_executor.py index 39a64a329..7e88e416a 100644 --- a/python/tokenspeed/runtime/execution/model_executor.py +++ b/python/tokenspeed/runtime/execution/model_executor.py @@ -234,6 +234,7 @@ def __init__( dummy_kv_slot=0, device=self.device, has_mamba=(mamba_pool is not None), + min_padding_seq_len=spec_num_tokens, ) self.runtime_states = RuntimeStates( req_pool_size=config.max_req_pool_size, From e9e39d068ed118db8e900ce4522f1be2daa4b463 Mon Sep 17 00:00:00 2001 From: rjzhb Date: Fri, 12 Jun 2026 18:28:34 +0000 Subject: [PATCH 11/20] Revert "fix(spec-decode): pad decode seq_lens to spec_num_tokens" This reverts commit 404279bd417a1bc84d7ef9fb1a26da8d0fd1e053. --- python/tokenspeed/runtime/execution/input_buffer.py | 11 ++--------- python/tokenspeed/runtime/execution/model_executor.py | 1 - 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/python/tokenspeed/runtime/execution/input_buffer.py b/python/tokenspeed/runtime/execution/input_buffer.py index 186bb72bc..c28624865 100644 --- a/python/tokenspeed/runtime/execution/input_buffer.py +++ b/python/tokenspeed/runtime/execution/input_buffer.py @@ -56,7 +56,6 @@ def __init__( dummy_kv_slot: int, device: str = "cuda", has_mamba: bool = False, - min_padding_seq_len: int = 1, ): self.device = device self.page_size = page_size @@ -65,10 +64,6 @@ def __init__( self.max_bs = max_bs self.all_extends_mid_chunk = False self.has_mamba = has_mamba - # Padding rows' seq_lens; >= spec_num_tokens so the EAGLE catch-up trim - # (seq_lens -= spec_num_tokens - accept_len) can't underflow them. 1 for - # non-spec. - self.min_padding_seq_len = min_padding_seq_len with torch.device(device): # Initialise buffers to the *padding* values the captured graph @@ -390,7 +385,7 @@ def write_decode_input_ids( self.mrope_positions_buf[:, total_tokens:].zero_() if batch_size < self.max_bs: self.req_pool_indices_buf[batch_size:].fill_(0) - self.seq_lens_buf[batch_size:].fill_(self.min_padding_seq_len) + self.seq_lens_buf[batch_size:].fill_(1) if ( self.has_mamba @@ -442,6 +437,4 @@ def fill_dummy_decode_buffers(self, batch_size: int, total_tokens: int): # seq_lens must be >= spec_num_tokens so the drafter's prewrite # correction never goes negative. num_tokens_per_req = total_tokens // batch_size if batch_size > 0 else 1 - self.seq_lens_buf[:batch_size].fill_( - max(num_tokens_per_req, self.min_padding_seq_len) - ) + self.seq_lens_buf[:batch_size].fill_(max(num_tokens_per_req, 1)) diff --git a/python/tokenspeed/runtime/execution/model_executor.py b/python/tokenspeed/runtime/execution/model_executor.py index 7e88e416a..39a64a329 100644 --- a/python/tokenspeed/runtime/execution/model_executor.py +++ b/python/tokenspeed/runtime/execution/model_executor.py @@ -234,7 +234,6 @@ def __init__( dummy_kv_slot=0, device=self.device, has_mamba=(mamba_pool is not None), - min_padding_seq_len=spec_num_tokens, ) self.runtime_states = RuntimeStates( req_pool_size=config.max_req_pool_size, From ca10fca6fdab212f3efedafe8d90af78ef4cb2ed Mon Sep 17 00:00:00 2001 From: rjzhb Date: Fri, 12 Jun 2026 18:28:45 +0000 Subject: [PATCH 12/20] fix(spec-decode): clamp catch-up trim result to avoid negative seq_lens Signed-off-by: rjzhb --- python/tokenspeed/runtime/models/llama_eagle3.py | 2 +- python/tokenspeed/runtime/models/qwen3_5_nextn.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tokenspeed/runtime/models/llama_eagle3.py b/python/tokenspeed/runtime/models/llama_eagle3.py index 258f7c703..b56e1c606 100644 --- a/python/tokenspeed/runtime/models/llama_eagle3.py +++ b/python/tokenspeed/runtime/models/llama_eagle3.py @@ -126,7 +126,7 @@ def _apply_correction(self, ctx: ForwardContext) -> None: correction = ( ctx.attn_backend.spec_num_tokens - ctx.accept_lengths[num_extends:] ).to(seq_lens_buf.dtype) - seq_lens_buf[num_extends : ctx.bs].sub_(correction) + seq_lens_buf[num_extends : ctx.bs].sub_(correction).clamp_(min=1) # --------------------------------------------------------------------------- diff --git a/python/tokenspeed/runtime/models/qwen3_5_nextn.py b/python/tokenspeed/runtime/models/qwen3_5_nextn.py index c2b089915..636455805 100644 --- a/python/tokenspeed/runtime/models/qwen3_5_nextn.py +++ b/python/tokenspeed/runtime/models/qwen3_5_nextn.py @@ -103,7 +103,7 @@ def _apply_correction(self, ctx: ForwardContext) -> None: correction = ( ctx.attn_backend.spec_num_tokens - ctx.accept_lengths[num_extends:] ).to(seq_lens_buf.dtype) - seq_lens_buf[num_extends : ctx.bs].sub_(correction) + seq_lens_buf[num_extends : ctx.bs].sub_(correction).clamp_(min=1) def _maybe_narrow_residual( self, From 263ef7010e9fc19b8632921f7dd0161ee5ed055b Mon Sep 17 00:00:00 2001 From: rjzhb Date: Fri, 12 Jun 2026 23:50:04 +0000 Subject: [PATCH 13/20] chore(qwen3.5-nextn): move sigmoid_mul import below first-party group Signed-off-by: rjzhb --- python/tokenspeed/runtime/models/qwen3_5_nextn.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/tokenspeed/runtime/models/qwen3_5_nextn.py b/python/tokenspeed/runtime/models/qwen3_5_nextn.py index 636455805..6c1bfd2b8 100644 --- a/python/tokenspeed/runtime/models/qwen3_5_nextn.py +++ b/python/tokenspeed/runtime/models/qwen3_5_nextn.py @@ -24,7 +24,6 @@ from collections.abc import Iterable import torch -from tokenspeed_kernel.ops.activation.triton import sigmoid_mul from torch import nn from transformers import PretrainedConfig @@ -46,6 +45,11 @@ ) from tokenspeed.runtime.utils import add_prefix +# isort: off +from tokenspeed_kernel.ops.activation.triton import sigmoid_mul + +# isort: on + logger = logging.getLogger(__name__) From 43643231780b1478cf36425444076bd6c206eec6 Mon Sep 17 00:00:00 2001 From: rjzhb Date: Fri, 12 Jun 2026 23:53:26 +0000 Subject: [PATCH 14/20] chore(qwen3.5-nextn): drop isort skip markers, accept canonical order Signed-off-by: rjzhb --- python/tokenspeed/runtime/models/qwen3_5_nextn.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/python/tokenspeed/runtime/models/qwen3_5_nextn.py b/python/tokenspeed/runtime/models/qwen3_5_nextn.py index 6c1bfd2b8..636455805 100644 --- a/python/tokenspeed/runtime/models/qwen3_5_nextn.py +++ b/python/tokenspeed/runtime/models/qwen3_5_nextn.py @@ -24,6 +24,7 @@ from collections.abc import Iterable import torch +from tokenspeed_kernel.ops.activation.triton import sigmoid_mul from torch import nn from transformers import PretrainedConfig @@ -45,11 +46,6 @@ ) from tokenspeed.runtime.utils import add_prefix -# isort: off -from tokenspeed_kernel.ops.activation.triton import sigmoid_mul - -# isort: on - logger = logging.getLogger(__name__) From 72ddde04acdf970df1938e1ebd40ca2666e6d6be Mon Sep 17 00:00:00 2001 From: rjzhb Date: Sun, 14 Jun 2026 03:37:13 +0000 Subject: [PATCH 15/20] fix(pd): record draft layerwise cache step on EXTEND catch-up Signed-off-by: rjzhb --- python/tokenspeed/runtime/models/llama_eagle3.py | 8 +++++++- python/tokenspeed/runtime/models/qwen3_5_nextn.py | 5 +++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/python/tokenspeed/runtime/models/llama_eagle3.py b/python/tokenspeed/runtime/models/llama_eagle3.py index b56e1c606..ee5b5f904 100644 --- a/python/tokenspeed/runtime/models/llama_eagle3.py +++ b/python/tokenspeed/runtime/models/llama_eagle3.py @@ -99,7 +99,7 @@ def _attn( q_rope = self._fused_rope_kv_write( positions, q, k, fused_kv_arg ).index_select(0, ctx.gather_ids) - return ctx.attn_backend.forward( + attn_output = ctx.attn_backend.forward( q_rope, None, None, @@ -110,6 +110,12 @@ def _attn( ctx.bs, save_kv_cache=False, ) + step_counter = ctx.attn_backend.step_counter + if step_counter is not None and not ctx.forward_mode.is_decode_or_idle(): + # The backend call above intentionally uses DECODE metadata, + # bypassing EXTEND-side layerwise cache-step accounting. + step_counter.record_cache() + return attn_output q, k = self.rotary_emb(positions, q, k) return self.attn(q, k, v, ctx=ctx, out_cache_loc=out_cache_loc).index_select( 0, ctx.gather_ids diff --git a/python/tokenspeed/runtime/models/qwen3_5_nextn.py b/python/tokenspeed/runtime/models/qwen3_5_nextn.py index 636455805..1c8e4304e 100644 --- a/python/tokenspeed/runtime/models/qwen3_5_nextn.py +++ b/python/tokenspeed/runtime/models/qwen3_5_nextn.py @@ -88,6 +88,11 @@ def _attn( ctx.bs, save_kv_cache=True, ) + step_counter = ctx.attn_backend.step_counter + if step_counter is not None and not ctx.forward_mode.is_decode_or_idle(): + # The backend call above intentionally uses DECODE metadata, which + # bypasses the backend's EXTEND-side layerwise cache-step accounting. + step_counter.record_cache() if gate is not None: sigmoid_mul(attn_output, gate) return attn_output From 2806ea44fdd89817103e909a14c02997beafe33f Mon Sep 17 00:00:00 2001 From: rjzhb Date: Sun, 14 Jun 2026 06:27:47 +0000 Subject: [PATCH 16/20] fix(spec-decode): route draft attn via self.attn; guard step_counter Signed-off-by: rjzhb --- .../tokenspeed/runtime/models/llama_eagle3.py | 8 +++-- .../runtime/models/qwen3_5_nextn.py | 31 +++++++++---------- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/python/tokenspeed/runtime/models/llama_eagle3.py b/python/tokenspeed/runtime/models/llama_eagle3.py index ee5b5f904..d3576f8d1 100644 --- a/python/tokenspeed/runtime/models/llama_eagle3.py +++ b/python/tokenspeed/runtime/models/llama_eagle3.py @@ -110,11 +110,13 @@ def _attn( ctx.bs, save_kv_cache=False, ) - step_counter = ctx.attn_backend.step_counter - if step_counter is not None and not ctx.forward_mode.is_decode_or_idle(): + if ( + getattr(ctx.attn_backend, "step_counter", None) + and not ctx.forward_mode.is_decode_or_idle() + ): # The backend call above intentionally uses DECODE metadata, # bypassing EXTEND-side layerwise cache-step accounting. - step_counter.record_cache() + ctx.attn_backend.step_counter.record_cache() return attn_output q, k = self.rotary_emb(positions, q, k) return self.attn(q, k, v, ctx=ctx, out_cache_loc=out_cache_loc).index_select( diff --git a/python/tokenspeed/runtime/models/qwen3_5_nextn.py b/python/tokenspeed/runtime/models/qwen3_5_nextn.py index 1c8e4304e..9a54fb1de 100644 --- a/python/tokenspeed/runtime/models/qwen3_5_nextn.py +++ b/python/tokenspeed/runtime/models/qwen3_5_nextn.py @@ -22,6 +22,7 @@ import logging from collections.abc import Iterable +from dataclasses import replace import torch from tokenspeed_kernel.ops.activation.triton import sigmoid_mul @@ -77,22 +78,20 @@ def _attn( q = q.index_select(0, ctx.gather_ids) if gate is not None: gate = gate.index_select(0, ctx.gather_ids) - attn_output = ctx.attn_backend.forward( - q, - k, - v, - self.attn, - out_cache_loc, - ctx.token_to_kv_pool, - ForwardMode.DECODE, - ctx.bs, - save_kv_cache=True, - ) - step_counter = ctx.attn_backend.step_counter - if step_counter is not None and not ctx.forward_mode.is_decode_or_idle(): - # The backend call above intentionally uses DECODE metadata, which - # bypasses the backend's EXTEND-side layerwise cache-step accounting. - step_counter.record_cache() + # Route through self.attn so the backend reshapes k/v and writes the KV + # cache (the KV write depends only on k/v, so the sliced q above is + # safe). Force DECODE via a ctx copy: the sliced q is one live query per + # request, which matches the decode metadata's [bs, 1] semantics (the + # extend metadata expects the full N-token segment). + decode_ctx = replace(ctx, forward_mode=ForwardMode.DECODE) + attn_output = self.attn(q, k, v, decode_ctx, out_cache_loc) + if ( + getattr(ctx.attn_backend, "step_counter", None) + and not ctx.forward_mode.is_decode_or_idle() + ): + # DECODE metadata above bypasses the backend's EXTEND-side layerwise + # cache-step accounting; record it here. + ctx.attn_backend.step_counter.record_cache() if gate is not None: sigmoid_mul(attn_output, gate) return attn_output From b9645c4e0df5fd2352d58e9a19c96e69cc0ba9a6 Mon Sep 17 00:00:00 2001 From: rjzhb Date: Sun, 14 Jun 2026 06:55:33 +0000 Subject: [PATCH 17/20] docs(spec-decode): clarify draft catch-up attention comments Signed-off-by: rjzhb --- python/tokenspeed/runtime/models/llama_eagle3.py | 6 ++++-- python/tokenspeed/runtime/models/qwen3_5_nextn.py | 15 ++++++++------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/python/tokenspeed/runtime/models/llama_eagle3.py b/python/tokenspeed/runtime/models/llama_eagle3.py index d3576f8d1..76022e975 100644 --- a/python/tokenspeed/runtime/models/llama_eagle3.py +++ b/python/tokenspeed/runtime/models/llama_eagle3.py @@ -114,8 +114,10 @@ def _attn( getattr(ctx.attn_backend, "step_counter", None) and not ctx.forward_mode.is_decode_or_idle() ): - # The backend call above intentionally uses DECODE metadata, - # bypassing EXTEND-side layerwise cache-step accounting. + # Under PD disaggregation the backend records a layerwise + # cache step on its EXTEND path so KV transfer can track + # per-layer readiness. The DECODE call above skips that, so + # record it here to keep an EXTEND/MIXED catch-up in sync. ctx.attn_backend.step_counter.record_cache() return attn_output q, k = self.rotary_emb(positions, q, k) diff --git a/python/tokenspeed/runtime/models/qwen3_5_nextn.py b/python/tokenspeed/runtime/models/qwen3_5_nextn.py index 9a54fb1de..6cbc591d2 100644 --- a/python/tokenspeed/runtime/models/qwen3_5_nextn.py +++ b/python/tokenspeed/runtime/models/qwen3_5_nextn.py @@ -78,19 +78,20 @@ def _attn( q = q.index_select(0, ctx.gather_ids) if gate is not None: gate = gate.index_select(0, ctx.gather_ids) - # Route through self.attn so the backend reshapes k/v and writes the KV - # cache (the KV write depends only on k/v, so the sliced q above is - # safe). Force DECODE via a ctx copy: the sliced q is one live query per - # request, which matches the decode metadata's [bs, 1] semantics (the - # extend metadata expects the full N-token segment). + # The catch-up runs as DECODE over the sliced live rows (see the class + # docstring). Going through self.attn keeps the standard k/v reshape and + # KV-cache write that every other attention call relies on, rather than + # invoking the backend directly. decode_ctx = replace(ctx, forward_mode=ForwardMode.DECODE) attn_output = self.attn(q, k, v, decode_ctx, out_cache_loc) if ( getattr(ctx.attn_backend, "step_counter", None) and not ctx.forward_mode.is_decode_or_idle() ): - # DECODE metadata above bypasses the backend's EXTEND-side layerwise - # cache-step accounting; record it here. + # Under PD disaggregation the backend records a layerwise cache step + # on its EXTEND path so KV transfer can track per-layer readiness. + # The DECODE dispatch above skips that, so record it here to keep an + # EXTEND/MIXED catch-up in sync. ctx.attn_backend.step_counter.record_cache() if gate is not None: sigmoid_mul(attn_output, gate) From af1ab659ffad3bc222dd28cdc75c8355e0e34b65 Mon Sep 17 00:00:00 2001 From: rjzhb Date: Mon, 15 Jun 2026 20:48:43 +0000 Subject: [PATCH 18/20] refactor(spec-decode): centralize draft cache-step record via record_kv_cache Signed-off-by: rjzhb --- .../runtime/layers/attention/backends/base.py | 40 +++++++++++-------- .../attention/backends/hybrid_linear_attn.py | 34 +++++++++------- .../tokenspeed/runtime/models/llama_eagle3.py | 16 +++----- .../runtime/models/qwen3_5_nextn.py | 27 ++++++------- 4 files changed, 60 insertions(+), 57 deletions(-) diff --git a/python/tokenspeed/runtime/layers/attention/backends/base.py b/python/tokenspeed/runtime/layers/attention/backends/base.py index 6bc84a1a9..79e1d6b93 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/base.py +++ b/python/tokenspeed/runtime/layers/attention/backends/base.py @@ -138,11 +138,29 @@ def forward( forward_mode: ForwardMode, bs: int, save_kv_cache: bool = True, + record_kv_cache: bool | None = None, **kwargs, ): - """Run forward on an attention layer with explicit scheduler metadata.""" + """Run forward on an attention layer with explicit scheduler metadata. + + ``record_kv_cache`` overrides the PD layerwise cache-step recording: + ``None`` keeps the default (record on the EXTEND-side path), an explicit + bool forces it so a DECODE-dispatched draft catch-up can still record. + """ + # Anchor the record to the KV write: before forward_extend when the KV + # was pre-written (save_kv_cache=False), after it otherwise. + if record_kv_cache is None: + record_cache = not forward_mode.is_decode() and not forward_mode.is_idle() + else: + record_cache = record_kv_cache + record_cache = record_cache and getattr(self, "step_counter", None) is not None + pre_attn_record = record_cache and not save_kv_cache + post_attn_record = record_cache and save_kv_cache + + if pre_attn_record: + self.step_counter.record_cache() if forward_mode.is_decode(): - return self.forward_decode( + ret = self.forward_decode( q, k, v, @@ -154,13 +172,6 @@ def forward( **kwargs, ) else: - if ( - not forward_mode.is_idle() - and getattr(self, "step_counter", None) - and not save_kv_cache - ): - self.step_counter.record_cache() - ret = self.forward_extend( q, k, @@ -173,14 +184,9 @@ def forward( forward_mode=forward_mode, **kwargs, ) - - if ( - not forward_mode.is_idle() - and getattr(self, "step_counter", None) - and save_kv_cache - ): - self.step_counter.record_cache() - return ret + if post_attn_record: + self.step_counter.record_cache() + return ret def forward_decode( self, diff --git a/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py b/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py index d0c5af50e..5d4eb928b 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py +++ b/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py @@ -1263,6 +1263,7 @@ def forward( forward_mode: ForwardMode, bs: int, save_kv_cache: bool = True, + record_kv_cache: bool | None = None, **kwargs, ): if forward_mode is None: @@ -1276,6 +1277,7 @@ def forward( forward_mode, bs, save_kv_cache, + record_kv_cache=record_kv_cache, **kwargs, ) @@ -1287,8 +1289,21 @@ def forward( layer_id = layer.layer_id if layer else kwargs["layer_id"] backend = self._backend_for_layer(layer_id) + # See AttentionBackend.forward for the record_kv_cache contract; the step + # is recorded in this wrapper (not the child backends) to keep one step + # per model layer across full-attn + mamba. + if record_kv_cache is None: + record_cache = not forward_mode.is_decode() + else: + record_cache = record_kv_cache + record_cache = record_cache and getattr(self, "step_counter", None) is not None + pre_attn_record = record_cache and not save_kv_cache + post_attn_record = record_cache and save_kv_cache + + if pre_attn_record: + self.step_counter.record_cache() if forward_mode.is_decode(): - return backend.forward_decode( + ret = backend.forward_decode( q, k, v, @@ -1300,13 +1315,6 @@ def forward( **kwargs, ) else: - step_counter = getattr(self, "step_counter", None) - if ( - not forward_mode.is_idle() - and step_counter is not None - and not save_kv_cache - ): - step_counter.record_cache() ret = backend.forward_extend( q, k, @@ -1319,13 +1327,9 @@ def forward( forward_mode=forward_mode, **kwargs, ) - if ( - not forward_mode.is_idle() - and step_counter is not None - and save_kv_cache - ): - step_counter.record_cache() - return ret + if post_attn_record: + self.step_counter.record_cache() + return ret def forward_decode( self, q, k, v, layer, out_cache_loc, token_to_kv_pool, bs, **kwargs diff --git a/python/tokenspeed/runtime/models/llama_eagle3.py b/python/tokenspeed/runtime/models/llama_eagle3.py index 76022e975..54d3c92b1 100644 --- a/python/tokenspeed/runtime/models/llama_eagle3.py +++ b/python/tokenspeed/runtime/models/llama_eagle3.py @@ -99,7 +99,10 @@ def _attn( q_rope = self._fused_rope_kv_write( positions, q, k, fused_kv_arg ).index_select(0, ctx.gather_ids) - attn_output = ctx.attn_backend.forward( + # record_kv_cache (keyed off the real mode) forces the backend's + # PD layerwise cache-step record that the DECODE dispatch would + # otherwise skip on an EXTEND/MIXED catch-up. + return ctx.attn_backend.forward( q_rope, None, None, @@ -109,17 +112,8 @@ def _attn( ForwardMode.DECODE, ctx.bs, save_kv_cache=False, + record_kv_cache=not ctx.forward_mode.is_decode_or_idle(), ) - if ( - getattr(ctx.attn_backend, "step_counter", None) - and not ctx.forward_mode.is_decode_or_idle() - ): - # Under PD disaggregation the backend records a layerwise - # cache step on its EXTEND path so KV transfer can track - # per-layer readiness. The DECODE call above skips that, so - # record it here to keep an EXTEND/MIXED catch-up in sync. - ctx.attn_backend.step_counter.record_cache() - return attn_output q, k = self.rotary_emb(positions, q, k) return self.attn(q, k, v, ctx=ctx, out_cache_loc=out_cache_loc).index_select( 0, ctx.gather_ids diff --git a/python/tokenspeed/runtime/models/qwen3_5_nextn.py b/python/tokenspeed/runtime/models/qwen3_5_nextn.py index 6cbc591d2..9d56a7b38 100644 --- a/python/tokenspeed/runtime/models/qwen3_5_nextn.py +++ b/python/tokenspeed/runtime/models/qwen3_5_nextn.py @@ -78,21 +78,20 @@ def _attn( q = q.index_select(0, ctx.gather_ids) if gate is not None: gate = gate.index_select(0, ctx.gather_ids) - # The catch-up runs as DECODE over the sliced live rows (see the class - # docstring). Going through self.attn keeps the standard k/v reshape and - # KV-cache write that every other attention call relies on, rather than - # invoking the backend directly. + # Dispatch as DECODE over the sliced live rows via self.attn (see the + # class docstring), which keeps the standard k/v reshape and KV write. + # A ctx copy overrides only the forward mode; record_kv_cache (keyed off + # the real mode) forces the backend's PD layerwise cache-step record that + # DECODE would otherwise skip on an EXTEND/MIXED catch-up. decode_ctx = replace(ctx, forward_mode=ForwardMode.DECODE) - attn_output = self.attn(q, k, v, decode_ctx, out_cache_loc) - if ( - getattr(ctx.attn_backend, "step_counter", None) - and not ctx.forward_mode.is_decode_or_idle() - ): - # Under PD disaggregation the backend records a layerwise cache step - # on its EXTEND path so KV transfer can track per-layer readiness. - # The DECODE dispatch above skips that, so record it here to keep an - # EXTEND/MIXED catch-up in sync. - ctx.attn_backend.step_counter.record_cache() + attn_output = self.attn( + q, + k, + v, + decode_ctx, + out_cache_loc, + record_kv_cache=not ctx.forward_mode.is_decode_or_idle(), + ) if gate is not None: sigmoid_mul(attn_output, gate) return attn_output From f8c9490869ef1a1843f44855ddfd1a2ebbba6cd9 Mon Sep 17 00:00:00 2001 From: rjzhb Date: Wed, 17 Jun 2026 19:20:27 +0000 Subject: [PATCH 19/20] refactor(attn): share pre/post cache-step record via record_cache_step Signed-off-by: rjzhb --- .../runtime/layers/attention/backends/base.py | 94 +++++++++++-------- .../attention/backends/hybrid_linear_attn.py | 65 ++++++------- 2 files changed, 82 insertions(+), 77 deletions(-) diff --git a/python/tokenspeed/runtime/layers/attention/backends/base.py b/python/tokenspeed/runtime/layers/attention/backends/base.py index 79e1d6b93..98c75d097 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/base.py +++ b/python/tokenspeed/runtime/layers/attention/backends/base.py @@ -127,6 +127,35 @@ def configure_runtime(self, **kwargs) -> None: def register_step_counter(self, step_counter: StepCounter): self.step_counter = step_counter + @contextmanager + def record_cache_step( + self, + forward_mode: ForwardMode, + save_kv_cache: bool, + record_kv_cache: bool | None, + ): + """Anchor the PD layerwise cache-step record to the wrapped KV write. + + Records the ``StepCounter`` step before the attention call when the KV + was pre-written (``save_kv_cache=False``) and after it otherwise, so a + layerwise cache transfer always observes a fully written layer. See + ``forward`` for the ``record_kv_cache`` override contract. No-op when no + step counter is registered. Backends that own the record (e.g. the + hybrid wrapper, which counts once per model layer across full-attn + + mamba children) reuse this to avoid duplicating the gate logic. + """ + if record_kv_cache is None: + record_cache = not forward_mode.is_decode() and not forward_mode.is_idle() + else: + record_cache = record_kv_cache + record_cache = record_cache and getattr(self, "step_counter", None) is not None + + if record_cache and not save_kv_cache: + self.step_counter.record_cache() + yield + if record_cache and save_kv_cache: + self.step_counter.record_cache() + def forward( self, q: torch.Tensor, @@ -147,45 +176,32 @@ def forward( ``None`` keeps the default (record on the EXTEND-side path), an explicit bool forces it so a DECODE-dispatched draft catch-up can still record. """ - # Anchor the record to the KV write: before forward_extend when the KV - # was pre-written (save_kv_cache=False), after it otherwise. - if record_kv_cache is None: - record_cache = not forward_mode.is_decode() and not forward_mode.is_idle() - else: - record_cache = record_kv_cache - record_cache = record_cache and getattr(self, "step_counter", None) is not None - pre_attn_record = record_cache and not save_kv_cache - post_attn_record = record_cache and save_kv_cache - - if pre_attn_record: - self.step_counter.record_cache() - if forward_mode.is_decode(): - ret = self.forward_decode( - q, - k, - v, - layer, - out_cache_loc, - token_to_kv_pool, - bs, - save_kv_cache=save_kv_cache, - **kwargs, - ) - else: - ret = self.forward_extend( - q, - k, - v, - layer, - out_cache_loc, - token_to_kv_pool, - bs, - save_kv_cache=save_kv_cache, - forward_mode=forward_mode, - **kwargs, - ) - if post_attn_record: - self.step_counter.record_cache() + with self.record_cache_step(forward_mode, save_kv_cache, record_kv_cache): + if forward_mode.is_decode(): + ret = self.forward_decode( + q, + k, + v, + layer, + out_cache_loc, + token_to_kv_pool, + bs, + save_kv_cache=save_kv_cache, + **kwargs, + ) + else: + ret = self.forward_extend( + q, + k, + v, + layer, + out_cache_loc, + token_to_kv_pool, + bs, + save_kv_cache=save_kv_cache, + forward_mode=forward_mode, + **kwargs, + ) return ret def forward_decode( diff --git a/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py b/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py index 5d4eb928b..1a857b353 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py +++ b/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py @@ -1291,44 +1291,33 @@ def forward( # See AttentionBackend.forward for the record_kv_cache contract; the step # is recorded in this wrapper (not the child backends) to keep one step - # per model layer across full-attn + mamba. - if record_kv_cache is None: - record_cache = not forward_mode.is_decode() - else: - record_cache = record_kv_cache - record_cache = record_cache and getattr(self, "step_counter", None) is not None - pre_attn_record = record_cache and not save_kv_cache - post_attn_record = record_cache and save_kv_cache - - if pre_attn_record: - self.step_counter.record_cache() - if forward_mode.is_decode(): - ret = backend.forward_decode( - q, - k, - v, - layer, - out_cache_loc, - token_to_kv_pool, - bs, - save_kv_cache=save_kv_cache, - **kwargs, - ) - else: - ret = backend.forward_extend( - q, - k, - v, - layer, - out_cache_loc, - token_to_kv_pool, - bs, - save_kv_cache=save_kv_cache, - forward_mode=forward_mode, - **kwargs, - ) - if post_attn_record: - self.step_counter.record_cache() + # per model layer across full-attn + mamba. Idle already returned above. + with self.record_cache_step(forward_mode, save_kv_cache, record_kv_cache): + if forward_mode.is_decode(): + ret = backend.forward_decode( + q, + k, + v, + layer, + out_cache_loc, + token_to_kv_pool, + bs, + save_kv_cache=save_kv_cache, + **kwargs, + ) + else: + ret = backend.forward_extend( + q, + k, + v, + layer, + out_cache_loc, + token_to_kv_pool, + bs, + save_kv_cache=save_kv_cache, + forward_mode=forward_mode, + **kwargs, + ) return ret def forward_decode( From d609578ff46a313e390cddf6fa49e8abfab8236b Mon Sep 17 00:00:00 2001 From: rjzhb Date: Wed, 17 Jun 2026 19:30:34 +0000 Subject: [PATCH 20/20] refactor(attn): rename helper to record_pd_cache_step Signed-off-by: rjzhb --- python/tokenspeed/runtime/layers/attention/backends/base.py | 4 ++-- .../runtime/layers/attention/backends/hybrid_linear_attn.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tokenspeed/runtime/layers/attention/backends/base.py b/python/tokenspeed/runtime/layers/attention/backends/base.py index 98c75d097..fa95c3025 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/base.py +++ b/python/tokenspeed/runtime/layers/attention/backends/base.py @@ -128,7 +128,7 @@ def register_step_counter(self, step_counter: StepCounter): self.step_counter = step_counter @contextmanager - def record_cache_step( + def record_pd_cache_step( self, forward_mode: ForwardMode, save_kv_cache: bool, @@ -176,7 +176,7 @@ def forward( ``None`` keeps the default (record on the EXTEND-side path), an explicit bool forces it so a DECODE-dispatched draft catch-up can still record. """ - with self.record_cache_step(forward_mode, save_kv_cache, record_kv_cache): + with self.record_pd_cache_step(forward_mode, save_kv_cache, record_kv_cache): if forward_mode.is_decode(): ret = self.forward_decode( q, diff --git a/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py b/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py index 1a857b353..b1158a54a 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py +++ b/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py @@ -1292,7 +1292,7 @@ def forward( # See AttentionBackend.forward for the record_kv_cache contract; the step # is recorded in this wrapper (not the child backends) to keep one step # per model layer across full-attn + mamba. Idle already returned above. - with self.record_cache_step(forward_mode, save_kv_cache, record_kv_cache): + with self.record_pd_cache_step(forward_mode, save_kv_cache, record_kv_cache): if forward_mode.is_decode(): ret = backend.forward_decode( q,