Skip to content
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
8b8836e
remote split extend
borontion Jun 13, 2026
d506bbc
streamline import
borontion Jun 13, 2026
2fbd1e2
cleanup
borontion Jun 13, 2026
2e9eafe
update flashinfer prefill api
borontion Jun 14, 2026
feedffb
add comment splitter
borontion Jun 14, 2026
017d9e3
remove scheduler metadata
borontion Jun 14, 2026
20aabd6
remove the registered cudnn impl
borontion Jun 14, 2026
5fdc7df
add extend fallback
borontion Jun 14, 2026
76e98fc
support fp8
borontion Jun 14, 2026
803734b
update attention test
borontion Jun 14, 2026
351d492
remove
borontion Jun 14, 2026
baade05
gate flashinfer kernel and cuda on blackwell
borontion Jun 15, 2026
d7fc387
move kv cache management kernel to tokenspeed-kernel
borontion Jun 15, 2026
b5b5c04
support fp8 kv cache
borontion Jun 15, 2026
3d35211
drop moe backend
borontion Jun 15, 2026
bba4c8c
add cu len for kv
borontion Jun 15, 2026
90789b7
use auto moe and attn backend for qwen
borontion Jun 15, 2026
49c37b0
remove try-catch import
borontion Jun 15, 2026
b566d7d
fix fa4 numerics
borontion Jun 15, 2026
55931a7
fix test
borontion Jun 15, 2026
6187d9d
drop dead code
borontion Jun 15, 2026
9667fd8
fix
borontion Jun 15, 2026
2f4ebab
rename
borontion Jun 15, 2026
54193ff
drop fa4 test
borontion Jun 15, 2026
3d78eb0
refactor
borontion Jun 15, 2026
31fe731
remove decode scheduler metadata
borontion Jun 15, 2026
63e520c
update comments
borontion Jun 15, 2026
bc88c4f
cleanup
borontion Jun 15, 2026
4ffd66a
fix fp8 path
borontion Jun 15, 2026
d2bf4d1
remove redundant .contiguous()
borontion Jun 15, 2026
38fb5bb
remove tags
borontion Jun 15, 2026
9938db4
introduce attn plan
borontion Jun 16, 2026
8500865
inline kwargs
borontion Jun 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
312 changes: 144 additions & 168 deletions python/tokenspeed/runtime/layers/attention/backends/mha.py

Large diffs are not rendered by default.

98 changes: 14 additions & 84 deletions python/tokenspeed/runtime/layers/attention/backends/trtllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,19 @@
from typing import TYPE_CHECKING

import torch
import triton
import triton.language as tl
from tokenspeed_kernel.ops.attention.flashinfer import (
trtllm_batch_context_with_kv_cache,
trtllm_batch_decode_with_kv_cache,
)
from tokenspeed_kernel.ops.kvcache.triton import (
fused_fp8_set_kv_buffer,
gather_page_table_with_padding,
)

from tokenspeed.runtime.configs.model_config import AttentionArch
from tokenspeed.runtime.execution.forward_batch_info import ForwardMode
from tokenspeed.runtime.layers.attention.backends.base import AttentionBackend
from tokenspeed.runtime.layers.attention.configs.mha import MHAConfig
from tokenspeed.runtime.layers.attention.kv_cache.trtllm_fp8_kv_kernel import (
fused_fp8_set_kv_buffer,
)
from tokenspeed.runtime.layers.attention.registry import register_backend
from tokenspeed.runtime.layers.common import fp8_cast_contiguous
from tokenspeed.runtime.utils import get_colorful_logger
Expand Down Expand Up @@ -565,37 +564,6 @@ def _init_multi_token_metadata_capture(
self.cuda_graph_prefill_metadata[bs] = metadata
self.forward_prefill_metadata = metadata

def _replay_gather_page_table(
    self,
    bs: int,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    req_to_page: torch.Tensor,
) -> None:
    """Rebuild the first ``bs`` rows of ``self.cuda_graph_page_table``.

    Launches ``_gather_page_table_with_padding_kernel`` on a 2-D grid of
    ``(bs, ceil(max_num_pages / 128))`` programs. This stands in for
    ``torch.index_select(req_to_page, 0, req_pool_indices, out=...)``:
    the kernel reads only the pages each request actually uses and
    writes the dummy slot into every remaining column, so page IDs left
    behind by an earlier replay (with a larger ``bs`` or longer
    ``seq_lens``) are overwritten.

    Args:
        bs: Number of rows of the page table to refresh.
        req_pool_indices: Per-row index into ``req_to_page``.
        seq_lens: Per-row KV length used to derive the valid page count.
        req_to_page: Source table mapping request slots to page IDs.
    """
    BLOCK_COLS = 128
    col_blocks = triton.cdiv(self.max_num_pages, BLOCK_COLS)
    launcher = _gather_page_table_with_padding_kernel[(bs, col_blocks)]

    page_table = self.cuda_graph_page_table
    launcher(
        req_to_page,
        req_pool_indices,
        seq_lens,
        page_table,
        req_to_page.stride(0),
        page_table.stride(0),
        self.max_num_pages,
        self.page_size,
        # dummy_slot: has to agree with the value cuda_graph_page_table
        # was initialised with (zeros).
        0,
        BLOCK_COLS=BLOCK_COLS,
        num_warps=4,
    )

def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
Expand All @@ -612,7 +580,16 @@ def init_forward_metadata_replay_cuda_graph(

# cache_seqlens aliases seq_lens_buf; only page_table needs refresh.
if req_to_page is not None:
self._replay_gather_page_table(bs, req_pool_indices, seq_lens, req_to_page)
gather_page_table_with_padding(
req_to_page=req_to_page,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
out=self.cuda_graph_page_table,
bs=bs,
max_num_pages=self.max_num_pages,
page_size=self.page_size,
dummy_slot=0,
)

if bs in self.cuda_graph_prefill_metadata:
self.forward_prefill_metadata = self.cuda_graph_prefill_metadata[bs]
Expand All @@ -621,50 +598,3 @@ def init_forward_metadata_replay_cuda_graph(


register_backend("trtllm", {AttentionArch.MHA}, TRTLLMMHAAttnBackend)


# ---------------------------------------------------------------------------
# Triton kernel for cuda graph page table gather (replay path)
# ---------------------------------------------------------------------------
# Replaces torch.index_select(req_to_page, 0, req_pool_indices, out=...) which
# launches an ATen indexSelectSmallIndex kernel that (a) reads every column of
# the source row including padding (max_num_pages can be ~2048 for 128K context)
# and (b) suffers cache misses on the small-index gather pattern.
#
# This kernel only loads the actual valid pages (ceil(seq_len / page_size))
# and writes dummy_slot to padding columns, which both shrinks total reads
# (often by 10-100x) and overwrites any stale page IDs left from previous
# replays where bs or seq_lens were larger.


@triton.jit
def _gather_page_table_with_padding_kernel(
    req_to_page_ptr,  # source table, [req_pool_size+1, src_stride0] int32
    req_pool_indices_ptr,  # [bs] int32 or int64 row selector
    seq_lens_ptr,  # [bs] int32, KV length of each request
    out_ptr,  # destination table, [max_bs, max_num_pages] int32
    src_stride0,  # elements between consecutive rows of req_to_page
    out_stride0,  # elements between consecutive rows of the output
    max_num_pages: tl.constexpr,
    page_size: tl.constexpr,
    dummy_slot: tl.constexpr,
    BLOCK_COLS: tl.constexpr,
):
    # One program handles one BLOCK_COLS-wide column tile of one output row:
    # it copies the row's valid page IDs from the source table and fills
    # every other in-range column of the tile with dummy_slot.
    row = tl.program_id(0)
    tile = tl.program_id(1)

    # Columns covered by this tile; the last tile may overhang the table.
    cols = tile * BLOCK_COLS + tl.arange(0, BLOCK_COLS)
    store_mask = cols < max_num_pages

    # Number of pages the request really occupies: ceil(seq_len / page_size).
    seq_len = tl.load(seq_lens_ptr + row).to(tl.int32)
    used_pages = (seq_len + page_size - 1) // page_size
    load_mask = (cols < used_pages) & store_mask

    # Masked gather: columns past the used pages are never read and take
    # dummy_slot through `other`.
    src_row = tl.load(req_pool_indices_ptr + row).to(tl.int64)
    page_ids = tl.load(
        req_to_page_ptr + src_row * src_stride0 + cols,
        mask=load_mask,
        other=dummy_slot,
    )

    # Write the whole in-range tile so padding columns are overwritten too.
    tl.store(out_ptr + row * out_stride0 + cols, page_ids, mask=store_mask)
Loading
Loading