Skip to content
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
d7882e7
refactor(spec-decode): wrap Eagle3 attention via base llama._attn
rjzhb Jun 9, 2026
6fbab67
update
rjzhb Jun 9, 2026
399b793
feat(spec-decode): extend Llama Eagle3 dispatch B to prefill catch-up
rjzhb Jun 9, 2026
90ec73a
update
rjzhb Jun 9, 2026
3dbdb30
Merge branch 'main' into refactor/llama-attention-hooks
rjzhb Jun 9, 2026
2ca7ebe
fix(spec-decode): correct LlamaForCausalLMEagle3 import path
rjzhb Jun 9, 2026
cc698bf
fix(spec-decode): cover EXTEND/MIXED catch-up in dispatch B flag broaden
rjzhb Jun 9, 2026
b0b485d
Merge remote-tracking branch 'upstream/main' into refactor/llama-atte…
rjzhb Jun 10, 2026
bad01b8
Merge branch 'main' into refactor/llama-attention-hooks
rjzhb Jun 10, 2026
5ac5d79
fix(spec-decode): fall back when fused KV prewrite arg is None
rjzhb Jun 10, 2026
555252f
refactor(spec-decode): wrap Qwen3.5 NextN attention via base hooks
rjzhb Jun 10, 2026
f01a2ab
Merge remote-tracking branch 'upstream/main' into refactor/qwen-atten…
rjzhb Jun 11, 2026
fafc901
fix(qwen3.5-nextn): drop idle-mode early return in draft attn
rjzhb Jun 11, 2026
404279b
fix(spec-decode): pad decode seq_lens to spec_num_tokens
rjzhb Jun 12, 2026
c8f0632
Merge remote-tracking branch 'upstream/main' into refactor/qwen-atten…
rjzhb Jun 12, 2026
e9e39d0
Revert "fix(spec-decode): pad decode seq_lens to spec_num_tokens"
rjzhb Jun 12, 2026
ca10fca
fix(spec-decode): clamp catch-up trim result to avoid negative seq_lens
rjzhb Jun 12, 2026
263ef70
chore(qwen3.5-nextn): move sigmoid_mul import below first-party group
rjzhb Jun 12, 2026
4364323
chore(qwen3.5-nextn): drop isort skip markers, accept canonical order
rjzhb Jun 12, 2026
e2c032f
Merge branch 'main' into refactor/qwen-attention-hooks
LorrinWWW Jun 12, 2026
72ddde0
fix(pd): record draft layerwise cache step on EXTEND catch-up
rjzhb Jun 14, 2026
4d52f2f
Merge branch 'main' into refactor/qwen-attention-hooks
rjzhb Jun 14, 2026
2806ea4
fix(spec-decode): route draft attn via self.attn; guard step_counter
rjzhb Jun 14, 2026
b9645c4
docs(spec-decode): clarify draft catch-up attention comments
rjzhb Jun 14, 2026
af1ab65
refactor(spec-decode): centralize draft cache-step record via record_…
rjzhb Jun 15, 2026
f8c9490
refactor(attn): share pre/post cache-step record via record_cache_step
rjzhb Jun 17, 2026
d609578
refactor(attn): rename helper to record_pd_cache_step
rjzhb Jun 17, 2026
3290daf
Merge branch 'main' into refactor/qwen-attention-hooks
rjzhb Jun 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions python/tokenspeed/runtime/execution/drafter/eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
ForwardMode,
)
from tokenspeed.runtime.models.llama_eagle3 import LlamaForCausalLMEagle3
from tokenspeed.runtime.models.qwen3_5_nextn import Qwen3_5ForConditionalGenerationNextN
from tokenspeed.runtime.multimodal.inputs import maybe_substitute_mm_pad
from tokenspeed.runtime.utils import get_colorful_logger
from tokenspeed.runtime.utils.nvtx import nvtx_range
Expand Down Expand Up @@ -218,10 +219,13 @@ def _run_first_step(
draft_input, bs, draft_input.input_num_tokens
)
input_ids = maybe_substitute_mm_pad(input_ids, self.mm_pad_substitute_id)
# Llama Eagle3 narrows for any non-idle catch-up (EXTEND/MIXED/
# TARGET_VERIFY/DECODE); Qwen/DeepSeek keep is_decode() only.
# Llama Eagle3 and Qwen3.5 NextN narrow for any non-idle catch-up
# (EXTEND/MIXED/TARGET_VERIFY/DECODE); DeepSeek keeps is_decode() only.
draft_first_step_reduce = forward_mode.is_decode() or (
isinstance(self.draft_model_runner.model, LlamaForCausalLMEagle3)
isinstance(
self.draft_model_runner.model,
(LlamaForCausalLMEagle3, Qwen3_5ForConditionalGenerationNextN),
)
and not forward_mode.is_idle()
)

Expand Down
19 changes: 18 additions & 1 deletion python/tokenspeed/runtime/execution/model_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
from tokenspeed.runtime.execution.types import ModelExecutionResult
from tokenspeed.runtime.grammar.capturable_grammar import setup_grammar_step
from tokenspeed.runtime.layers.logits_processor import LogitsProcessorOutput
from tokenspeed.runtime.models.llama_eagle3 import LlamaForCausalLMEagle3
from tokenspeed.runtime.models.qwen3_5_nextn import Qwen3_5ForConditionalGenerationNextN
from tokenspeed.runtime.sampling.backends.base import SamplingBackend
from tokenspeed.runtime.sampling.dp_sampling_config import (
DpSamplingRuntimeLimits,
Expand Down Expand Up @@ -1078,7 +1080,22 @@ def execute_idle_forward(
global_num_tokens=draft_global_num_tokens,
global_bs=global_bs,
all_decode_or_idle=all_decode_or_idle,
draft_first_step_reduce=(step_idx == 0 and all_decode_or_idle),
# Mirror the active-rank broaden in eagle.py: Llama Eagle3
# and Qwen3.5 NextN narrow for any non-idle catch-up, so the
# idle peer must size collectives the same way.
draft_first_step_reduce=(
step_idx == 0
and (
all_decode_or_idle
or isinstance(
self.drafter.draft_model_runner.model,
(
LlamaForCausalLMEagle3,
Qwen3_5ForConditionalGenerationNextN,
),
)
)
),
)
self.drafter.draft_model_runner.forward(
draft_ctx,
Expand Down
10 changes: 8 additions & 2 deletions python/tokenspeed/runtime/models/llama_eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _attn(
q_rope = self._fused_rope_kv_write(
positions, q, k, fused_kv_arg
).index_select(0, ctx.gather_ids)
return ctx.attn_backend.forward(
attn_output = ctx.attn_backend.forward(
q_rope,
None,
None,
Expand All @@ -110,6 +110,12 @@ def _attn(
ctx.bs,
save_kv_cache=False,
)
step_counter = ctx.attn_backend.step_counter
if step_counter is not None and not ctx.forward_mode.is_decode_or_idle():
# The backend call above intentionally uses DECODE metadata,
# bypassing EXTEND-side layerwise cache-step accounting.
step_counter.record_cache()
return attn_output
q, k = self.rotary_emb(positions, q, k)
return self.attn(q, k, v, ctx=ctx, out_cache_loc=out_cache_loc).index_select(
0, ctx.gather_ids
Expand All @@ -126,7 +132,7 @@ def _apply_correction(self, ctx: ForwardContext) -> None:
correction = (
ctx.attn_backend.spec_num_tokens - ctx.accept_lengths[num_extends:]
).to(seq_lens_buf.dtype)
seq_lens_buf[num_extends : ctx.bs].sub_(correction)
seq_lens_buf[num_extends : ctx.bs].sub_(correction).clamp_(min=1)


# ---------------------------------------------------------------------------
Expand Down
75 changes: 48 additions & 27 deletions python/tokenspeed/runtime/models/qwen3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,16 +710,13 @@ def _apply_qk_norm(
self.q_norm.variance_epsilon,
)

def self_attention(
def _project_qkv_rope(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
ctx: ForwardContext,
out_cache_loc: torch.Tensor,
) -> torch.Tensor:
"""Full attention forward pass."""
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
"""qkv_proj + split + rope (+ optional gate). ``gate`` is ``None`` when ``attn_output_gate=False``."""
qkv, _ = self.qkv_proj(hidden_states)

if self.attn_output_gate:
q_gate, k, v = qkv.split(
[self.q_size * 2, self.kv_size, self.kv_size], dim=-1
Expand All @@ -737,23 +734,48 @@ def self_attention(
self.head_dim,
self.rotary_emb.rotary_dim,
)
else:
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k)
return q, k, v, gate
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k)
return q, k, v, None

def _attn(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
gate: torch.Tensor | None,
ctx: ForwardContext,
out_cache_loc: torch.Tensor,
) -> torch.Tensor:
"""Backend attention call + optional gate apply. Subclasses override."""
attn_output = self.attn(q, k, v, ctx, out_cache_loc)

if self.attn_output_gate:
if gate is not None:
sigmoid_mul(attn_output, gate)
return attn_output

if ctx.draft_first_step_reduce:
# Slice attn_output to [bs, H] so o_proj runs on live rows only.
attn_output = attn_output.index_select(0, ctx.gather_ids)

def self_attention(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
ctx: ForwardContext,
out_cache_loc: torch.Tensor,
) -> torch.Tensor:
"""Full attention forward pass."""
q, k, v, gate = self._project_qkv_rope(positions, hidden_states)
attn_output = self._attn(q, k, v, gate, ctx, out_cache_loc)
output, _ = self.o_proj(attn_output)
return output

def _maybe_narrow_residual(
self,
residual: torch.Tensor,
ctx: ForwardContext,
) -> torch.Tensor:
"""Hook: subclasses narrow residual to match a sliced attn output."""
return residual

def forward(
self,
positions: torch.Tensor,
Expand All @@ -778,9 +800,7 @@ def forward(
ctx=ctx,
out_cache_loc=out_cache_loc,
)
if ctx.draft_first_step_reduce:
# Gather residual to self_attention's [bs, H].
residual = residual.index_select(0, ctx.gather_ids)
residual = self._maybe_narrow_residual(residual, ctx)
hidden_states, residual = self.comm_manager.post_attn_reduce_norm(
hidden_states, residual, ctx
)
Expand Down Expand Up @@ -816,15 +836,12 @@ def forward_mlp(
return hidden_states


ALL_DECODER_LAYER_TYPES = {
"attention": Qwen3_5AttentionDecoderLayer,
"linear_attention": Qwen3_5LinearDecoderLayer,
}


class Qwen3_5ForCausalLM(nn.Module):
"""Qwen3.5 Model with support for dense variant."""

ATTENTION_LAYER_CLS: type = Qwen3_5AttentionDecoderLayer
LINEAR_LAYER_CLS: type = Qwen3_5LinearDecoderLayer

def __init__(
self,
config: Qwen3_5TextConfig,
Expand All @@ -849,10 +866,14 @@ def __init__(
tp_group=self.mapping.attn.tp_group,
)

# Decoder layers
layer_cls_by_type = {
"attention": self.ATTENTION_LAYER_CLS,
"linear_attention": self.LINEAR_LAYER_CLS,
}

def get_layer(idx: int, prefix: str):
layer_type = config.layers_block_type[idx]
layer_class = ALL_DECODER_LAYER_TYPES[layer_type]
layer_class = layer_cls_by_type[layer_type]
if layer_type == "attention":
prefix = add_prefix("self_attn", prefix)
else:
Expand Down
106 changes: 104 additions & 2 deletions python/tokenspeed/runtime/models/qwen3_5_nextn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@
from collections.abc import Iterable

import torch
from tokenspeed_kernel.ops.activation.triton import sigmoid_mul
Comment thread
rjzhb marked this conversation as resolved.
from torch import nn
from transformers import PretrainedConfig

from tokenspeed.runtime.distributed.mapping import Mapping
from tokenspeed.runtime.execution.context import ForwardContext
from tokenspeed.runtime.execution.forward_batch_info import ForwardMode
from tokenspeed.runtime.layers.layernorm import GemmaRMSNorm
from tokenspeed.runtime.layers.linear import ReplicatedLinear
from tokenspeed.runtime.layers.logits_processor import LogitsMetadata, LogitsProcessor
Expand All @@ -38,12 +40,112 @@
)
from tokenspeed.runtime.layers.vocab_parallel_embedding import ParallelLMHead
from tokenspeed.runtime.model_loader.weight_utils import default_weight_loader
from tokenspeed.runtime.models.qwen3_5 import Qwen3_5ForCausalLM
from tokenspeed.runtime.models.qwen3_5 import (
Qwen3_5AttentionDecoderLayer,
Qwen3_5ForCausalLM,
)
from tokenspeed.runtime.utils import add_prefix

logger = logging.getLogger(__name__)


class Qwen3_5DraftAttentionDecoderLayer(Qwen3_5AttentionDecoderLayer):
"""NextN draft variant: skip dead catch-up rows on the first draft step.

On the first draft step the backend runs in DECODE mode with ``q`` sliced
to ``bs`` while ``self.attn`` still writes the full ``N`` rope-d KV rows
from the just-drafted tokens. Multi-step decode delegates to base.

MIXED catch-up requires a backend that populates a decode-slot metadata
under EXTEND/MIXED at draft init (e.g. trtllm-mha); MHA-family backends
that assert ``not is_mixed()`` at metadata init are not supported.
"""

def _attn(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
gate: torch.Tensor | None,
ctx: ForwardContext,
out_cache_loc: torch.Tensor,
) -> torch.Tensor:
if ctx.accept_lengths is None:
return super()._attn(q, k, v, gate, ctx, out_cache_loc)

self._apply_correction(ctx)
q = q.index_select(0, ctx.gather_ids)
if gate is not None:
gate = gate.index_select(0, ctx.gather_ids)
attn_output = ctx.attn_backend.forward(
q,
k,
v,
Comment thread
rjzhb marked this conversation as resolved.
Outdated
self.attn,
out_cache_loc,
ctx.token_to_kv_pool,
ForwardMode.DECODE,
ctx.bs,
save_kv_cache=True,
)
step_counter = ctx.attn_backend.step_counter
if step_counter is not None and not ctx.forward_mode.is_decode_or_idle():
# The backend call above intentionally uses DECODE metadata, which
# bypasses the backend's EXTEND-side layerwise cache-step accounting.
step_counter.record_cache()
if gate is not None:
sigmoid_mul(attn_output, gate)
return attn_output

def _apply_correction(self, ctx: ForwardContext) -> None:
"""Trim decode rows' cache_seqlens by ``spec_num_tokens - accept_lengths``."""
seq_lens_buf = ctx.draft_seq_lens_buf
if seq_lens_buf is None or ctx.accept_lengths is None:
return
num_extends = ctx.num_extends
if num_extends >= ctx.bs:
return
correction = (
ctx.attn_backend.spec_num_tokens - ctx.accept_lengths[num_extends:]
).to(seq_lens_buf.dtype)
seq_lens_buf[num_extends : ctx.bs].sub_(correction).clamp_(min=1)

def _maybe_narrow_residual(
self,
residual: torch.Tensor,
ctx: ForwardContext,
) -> torch.Tensor:
if ctx.accept_lengths is None or ctx.forward_mode.is_idle():
return residual
return residual.index_select(0, ctx.gather_ids)


class Qwen3_5DraftForCausalLM(Qwen3_5ForCausalLM):
"""Causal LM with the draft-variant attention layer injected.

Restricted to single-layer drafts: ``_apply_correction`` mutates
``ctx.draft_seq_lens_buf`` in place and is not idempotent across layers.
A multi-layer draft would double-trim cache_seqlens. Lift the correction
out of the per-layer hook (e.g. into the drafter) before relaxing this.
"""

ATTENTION_LAYER_CLS: type = Qwen3_5DraftAttentionDecoderLayer

def __init__(
self,
config,
mapping,
quant_config=None,
prefix: str = "",
) -> None:
assert config.num_hidden_layers == 1, (
"Qwen3_5DraftForCausalLM requires num_hidden_layers == 1 "
f"(got {config.num_hidden_layers}); _apply_correction is not "
"idempotent across layers."
)
super().__init__(config, mapping, quant_config=quant_config, prefix=prefix)


class Qwen3_5ForConditionalGenerationNextN(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -74,7 +176,7 @@ def __init__(
self.pre_fc_norm_hidden = RMSNorm_cls(config.hidden_size, config.rms_norm_eps)
config.num_hidden_layers = 1
config.full_attention_interval = 1
self.model = Qwen3_5ForCausalLM(
self.model = Qwen3_5DraftForCausalLM(
config,
mapping=self.mapping,
quant_config=quant_config,
Expand Down
Loading