perf(gdn): remove gdn prefill unnecessary h_state copy#409

Merged
zhyncs merged 4 commits into main from jjd/support_h_track
Jun 11, 2026

@minedec minedec commented Jun 10, 2026

Contributor

Summary

Remove the redundant h-state conversion that was performed when persisting intermediate SSM states into the tracking pool (h-track and final-track) on the flashinfer GDN prefill fast path. Instead of rebuilding an FLA-style h tensor from flashinfer's checkpoints, tracking now indexes the native checkpoint buffer directly.

Approach:

  • Reuse flashinfer's native checkpoints; drop the FLA-h rebuild: The prefill wrapper no longer runs a host-side per-sequence loop that splices the initial state into slot 0 and drops the trailing checkpoint to fabricate an FLA-style h tensor. It now returns the raw checkpoint buffer together with cumulative start offsets (checkpoint_cu_starts), and the caller indexes the buffer with flashinfer-native offsets. This removes a full-state allocation and copy, as well as a host loop that incurred .item() syncs.
  • Adjust index math to the new convention: Track indices now address the raw checkpoint buffer directly. The per-request entry count changes from ceil(L/C) to L//C, and the source index shifts by -1 accordingly (FLA h[m] = ckpts[m-1]).
  • Use a fused kernel for final-track copies: The final-state movement of conv_states/ssm_states switches from per-tensor index-copy to a fused Triton kernel. The kernel is also extended to support 2D non-contiguous per-layer views (relaxing the earlier contiguity requirement) and to cast indices to int32 automatically.
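
The index convention described above can be illustrated with a small pure-Python sketch. It is not the code from this PR: the helper names `build_cu_starts` and `flat_checkpoint_index` are hypothetical. It also makes three assumptions: L is the sequence length and C the chunk size, each sequence contributes exactly L//C checkpoints stored contiguously, and the offsets array has one entry per sequence plus a trailing total.

```python
from itertools import accumulate


def build_cu_starts(seq_lens, chunk_size):
    """Exclusive prefix sum of per-sequence checkpoint counts (assumed L // C each)."""
    counts = [length // chunk_size for length in seq_lens]
    # Leading 0 so that cu_starts[i] is the first flat slot of sequence i.
    return [0, *accumulate(counts)]


def flat_checkpoint_index(cu_starts, seq_idx, fla_m):
    """Map an FLA-style index m (m >= 1) to a slot in the raw checkpoint buffer.

    Uses the relation stated in the PR: FLA h[m] == ckpts[m - 1].
    FLA h[0] is the initial state, which is assumed not to live in the buffer.
    """
    if fla_m < 1:
        raise ValueError("h[0] is the initial state, not a checkpoint")
    count = cu_starts[seq_idx + 1] - cu_starts[seq_idx]
    if fla_m > count:
        raise IndexError("sequence has no checkpoint for this index")
    return cu_starts[seq_idx] + fla_m - 1


seq_lens = [130, 64, 10]   # three requests in one prefill batch
chunk_size = 64
cu_starts = build_cu_starts(seq_lens, chunk_size)
print(cu_starts)                                  # [0, 2, 3, 3]
print(flat_checkpoint_index(cu_starts, 0, 2))     # 1
print(flat_checkpoint_index(cu_starts, 1, 1))     # 2
```

Under these assumptions, the third request (10 tokens, shorter than one chunk) owns no checkpoint slots, so any lookup for it raises `IndexError` rather than reading a neighbouring request's state.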

Test Plan

@minedec minedec requested a review from a team as a code owner June 10, 2026 09:11
@minedec minedec changed the title from "[WIP] perf(gdn): remove unnecessary h_state copy" to "[WIP] perf(gdn): remove gdn prefill unnecessary h_state copy" Jun 10, 2026
@minedec minedec requested a review from tuanzhangCS June 10, 2026 09:33
@lightseek-bot lightseek-bot changed the title from "[WIP] perf(gdn): remove gdn prefill unnecessary h_state copy" to "perf(gdn): remove gdn prefill unnecessary h_state copy" Jun 10, 2026
@minedec minedec force-pushed the jjd/support_h_track branch from 72475ac to 353ddf4 June 11, 2026 02:31
@zhyncs zhyncs merged commit ae69e5e into main Jun 11, 2026
31 of 35 checks passed
@zhyncs zhyncs deleted the jjd/support_h_track branch June 11, 2026 03:33
2 participants