Skip to content

perf(gdn): fuse causal_conv1d and QKV split for GDN prefill#382

Open
elwhyjay wants to merge 2 commits into
lightseekorg:mainfrom
elwhyjay:perf/conv1d-qkv-split-fusion
Open

perf(gdn): fuse causal_conv1d and QKV split for GDN prefill#382
elwhyjay wants to merge 2 commits into
lightseekorg:mainfrom
elwhyjay:perf/conv1d-qkv-split-fusion

Conversation

@elwhyjay

@elwhyjay elwhyjay commented Jun 8, 2026

Copy link
Copy Markdown
Contributor

What this does

In the GDN prefill path, causal_conv1d_fn writes its output to a staging buffer (conv_out, bf16, shape [conv_dim, T]), and then fused_qkv_split_gdn_prefill reads that buffer to produce Q, K, V. This PR replaces those two kernel calls with a single Triton kernel that writes directly into Q/K/V, skipping the intermediate buffer entirely.

The staging buffer grows with sequence length at T=8192 it is 96 MB, which is 76% of B200's L2 (126 MB). At that size the split kernel cannot reuse L2-cached data from the conv kernel, so there is a real DRAM round-trip cost.

Changes

  • gdn_qkv_split.py: added _causal_conv1d_qkv_split_fwd_kernel and causal_conv1d_qkv_split_gdn_prefill. The kernel body is almost identical to _causal_conv1d_fwd_kernel, only the store path changes (three output pointers instead of one).
  • hybrid_linear_attn.py: routes need_h_track=False prefill through the fused call. need_h_track=True keeps the original two-call path unchanged.

The Q/K/V boundary alignment is clean because Q_DIM = K_DIM = V_DIM = 2048, all exact multiples of BLOCK_N=256, so each Triton program falls entirely within one output region.

Perf (B200, bf16, single-seq)

     T    sequential      fused    gain
   512      128.4µs     79.1µs   +38%
  2048      131.6µs     70.8µs   +46%
  4096      150.4µs     74.9µs   +50%
  8192      203.8µs    126.7µs   +38%
 16384      324.6µs    204.8µs   +37%

Tests

Added test/runtime/layers/test_causal_conv1d_qkv_split_fused.py: 22 tests covering varlen, initial conv state, bf16/fp16. All pass. Conv state update is also verified to be identical between the two paths.

@elwhyjay elwhyjay requested a review from a team as a code owner June 8, 2026 07:53
@elwhyjay elwhyjay force-pushed the perf/conv1d-qkv-split-fusion branch from 22ba473 to 057c156 Compare June 8, 2026 08:04

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 057c1568b5

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +530 to +535
if pid_n < Q_DIM // BLOCK_N:
o_ptrs = q_ptr + g * Q_DIM + idx_feats
elif pid_n < (Q_DIM + K_DIM) // BLOCK_N:
o_ptrs = k_ptr + g * K_DIM + (idx_feats - Q_DIM)
else:
o_ptrs = v_ptr + g * V_DIM + (idx_feats - Q_DIM - K_DIM)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Route partial channel blocks before storing

For supported-looking GDN shapes where a Q or K region is not an exact multiple of BLOCK_N (for example head_dim=128 with an odd number of heads, which passes the later gdn_flashinfer.is_supported checks), this block-level test sends the boundary block to the next tensor. The remaining Q lanes in that block then store through v_ptr/k_ptr with negative offsets while mask_1d is still true, corrupting memory instead of matching the previous torch.split path. Either keep the fused path gated to dimensions divisible by 256 or route/store with per-lane masks like the standalone split kernel.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wrapper now raises ValueError for shapes where Q/K/V dims aren't divisible by BLOCK_N=256. Unsupported shapes fall back to the two-kernel path.

elwhyjay added 2 commits June 10, 2026 02:41
… buffer

Signed-off-by: leeyongjun <jqueen.astro@gmail.com>
…t kernel

Signed-off-by: leeyongjun <jqueen.astro@gmail.com>
@elwhyjay elwhyjay force-pushed the perf/conv1d-qkv-split-fusion branch from 9e0d460 to 088663c Compare June 10, 2026 02:42
@elwhyjay

Copy link
Copy Markdown
Contributor Author

Resolved conflict.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 088663c728

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

head_v=head_v_dim,
)
else:
query, key, value = causal_conv1d_qkv_split_gdn_prefill(

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Fall back before invoking the 256-wide fused split

When need_h_track is false this now always calls causal_conv1d_qkv_split_gdn_prefill, whose wrapper raises unless each per-rank Q/K/V dimension is divisible by 256. Valid Qwen3.5 layouts with higher attention TP, e.g. the default 16 key heads at attn_tp_size=16 (Q_DIM=K_DIM=128, V_DIM=256), pass the later gdn_flashinfer.is_supported checks but will now fail during prefill instead of using the previous causal_conv1d_fn + split path. Fresh evidence beyond the earlier thread is that the current caller does not catch the wrapper's ValueError or branch around unsupported per-rank dimensions here.

Useful? React with 👍 / 👎.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant