perf(gdn): fuse causal_conv1d and QKV split for GDN prefill #382
elwhyjay wants to merge 2 commits into
22ba473 to 057c156
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 057c1568b5
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
if pid_n < Q_DIM // BLOCK_N:
    o_ptrs = q_ptr + g * Q_DIM + idx_feats
elif pid_n < (Q_DIM + K_DIM) // BLOCK_N:
    o_ptrs = k_ptr + g * K_DIM + (idx_feats - Q_DIM)
else:
    o_ptrs = v_ptr + g * V_DIM + (idx_feats - Q_DIM - K_DIM)
Route partial channel blocks before storing
For supported-looking GDN shapes where a Q or K region is not an exact multiple of BLOCK_N (for example head_dim=128 with an odd number of heads, which passes the later gdn_flashinfer.is_supported checks), this block-level test sends the boundary block to the next tensor. The remaining Q lanes in that block then store through v_ptr/k_ptr with negative offsets while mask_1d is still true, corrupting memory instead of matching the previous torch.split path. Either keep the fused path gated to dimensions divisible by 256 or route/store with per-lane masks like the standalone split kernel.
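The misrouting described above can be reproduced on the host with plain Python. This is only an illustrative sketch: the helper names (route_block_level, route_per_lane, misrouted_lanes) are hypothetical and not part of the PR, and the block-level branch is modeled on the quoted snippet rather than copied from the kernel.

```python
def route_block_level(pid_n, block_n, q_dim, k_dim):
    """Pick one destination for a whole block, mirroring the quoted branch."""
    if pid_n < q_dim // block_n:
        return "q"
    elif pid_n < (q_dim + k_dim) // block_n:
        return "k"
    return "v"


def route_per_lane(feat, q_dim, k_dim):
    """Pick the destination and local offset for a single channel."""
    if feat < q_dim:
        return "q", feat
    if feat < q_dim + k_dim:
        return "k", feat - q_dim
    return "v", feat - q_dim - k_dim


def block_level_offsets(pid_n, block_n, q_dim, k_dim):
    """Destination and offset the block-level branch yields for every lane."""
    dest = route_block_level(pid_n, block_n, q_dim, k_dim)
    base = {"q": 0, "k": q_dim, "v": q_dim + k_dim}[dest]
    return [(dest, pid_n * block_n + lane - base) for lane in range(block_n)]


def misrouted_lanes(block_n, q_dim, k_dim, v_dim):
    """Count in-range lanes whose block-level target differs from the per-lane one."""
    total = q_dim + k_dim + v_dim
    bad = 0
    for pid_n in range((total + block_n - 1) // block_n):
        offsets = block_level_offsets(pid_n, block_n, q_dim, k_dim)
        for lane, got in enumerate(offsets):
            feat = pid_n * block_n + lane
            if feat < total and got != route_per_lane(feat, q_dim, k_dim):
                bad += 1
    return bad


# Aligned shape from the PR description: no lane is misrouted.
print(misrouted_lanes(256, 2048, 2048, 2048))
# head_dim=128 with three heads: the block straddling Q/K goes wholly to K.
print(misrouted_lanes(256, 384, 384, 384))
```

In the second call the Q lanes of the straddling block receive negative offsets relative to the K base, which is the out-of-bounds store the comment warns about.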
The wrapper now raises ValueError for shapes where the Q/K/V dims aren't divisible by BLOCK_N=256. Unsupported shapes fall back to the two-kernel path.
… buffer Signed-off-by: leeyongjun <jqueen.astro@gmail.com>
…t kernel Signed-off-by: leeyongjun <jqueen.astro@gmail.com>
9e0d460 to 088663c
Resolved conflict.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 088663c728
        head_v=head_v_dim,
    )
else:
    query, key, value = causal_conv1d_qkv_split_gdn_prefill(
Fall back before invoking the 256-wide fused split
When need_h_track is false this now always calls causal_conv1d_qkv_split_gdn_prefill, whose wrapper raises unless each per-rank Q/K/V dimension is divisible by 256. Valid Qwen3.5 layouts with higher attention TP, e.g. the default 16 key heads at attn_tp_size=16 (Q_DIM=K_DIM=128, V_DIM=256), pass the later gdn_flashinfer.is_supported checks but will now fail during prefill instead of using the previous causal_conv1d_fn + split path. Fresh evidence beyond the earlier thread is that the current caller does not catch the wrapper's ValueError or branch around unsupported per-rank dimensions here.
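One possible shape for the caller-side guard is sketched below. The function names (fused_split_supported, choose_prefill_path) and the string return values are hypothetical, chosen only to illustrate branching on per-rank dimensions before the fused wrapper is reached; the actual caller in hybrid_linear_attn.py may be structured differently.

```python
BLOCK_N = 256  # block width quoted in this review thread


def fused_split_supported(q_dim, k_dim, v_dim, block_n=BLOCK_N):
    """True when every per-rank dimension is a positive multiple of block_n."""
    return all(d > 0 and d % block_n == 0 for d in (q_dim, k_dim, v_dim))


def choose_prefill_path(need_h_track, q_dim, k_dim, v_dim):
    """Return which path to take; checked up front instead of catching ValueError."""
    if need_h_track or not fused_split_supported(q_dim, k_dim, v_dim):
        return "two_kernel"  # causal_conv1d_fn followed by the split
    return "fused"           # single fused kernel


# Shape from the PR description.
print(choose_prefill_path(False, 2048, 2048, 2048))
# Per-rank shape from the example above (Q_DIM=K_DIM=128, V_DIM=256).
print(choose_prefill_path(False, 128, 128, 256))
```

Checking divisibility in the caller keeps the wrapper's ValueError as a last line of defence rather than a condition valid layouts can hit during prefill.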
What this does
In the GDN prefill path, causal_conv1d_fn writes its output to a staging buffer (conv_out, bf16, shape [conv_dim, T]), and then fused_qkv_split_gdn_prefill reads that buffer to produce Q, K, V. This PR replaces those two kernel calls with a single Triton kernel that writes directly into Q/K/V, skipping the intermediate buffer entirely.

The staging buffer grows with sequence length: at T=8192 it is 96 MB, which is 76% of B200's L2 (126 MB). At that size the split kernel cannot reuse L2-cached data from the conv kernel, so there is a real DRAM round-trip cost.
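The buffer-size figure can be checked with a few lines of arithmetic. The sketch assumes conv_dim = Q_DIM + K_DIM + V_DIM = 3 × 2048 (taken from the dimensions quoted under Changes), 2 bytes per bf16 element, and that the 96 MB and 126 MB figures are in the same unit; staging_buffer_bytes is a hypothetical helper, not a function from the PR.

```python
def staging_buffer_bytes(conv_dim, seq_len, bytes_per_elem=2):
    """Size of a [conv_dim, T] buffer; bf16 uses 2 bytes per element."""
    return conv_dim * seq_len * bytes_per_elem


conv_dim = 3 * 2048   # assumption: Q_DIM + K_DIM + V_DIM
l2_size = 126         # L2 size quoted in the description, same unit as size_mib

size_mib = staging_buffer_bytes(conv_dim, 8192) / 2**20
l2_fraction = size_mib / l2_size

print(size_mib)                  # buffer size at T=8192
print(round(l2_fraction * 100))  # percentage of L2

# The buffer scales linearly with T, so doubling T doubles the footprint.
print(staging_buffer_bytes(conv_dim, 16384) / 2**20)
```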
Changes
- gdn_qkv_split.py: added _causal_conv1d_qkv_split_fwd_kernel and causal_conv1d_qkv_split_gdn_prefill. The kernel body is almost identical to _causal_conv1d_fwd_kernel; only the store path changes (three output pointers instead of one).
- hybrid_linear_attn.py: routes need_h_track=False prefill through the fused call. need_h_track=True keeps the original two-call path unchanged.

The Q/K/V boundary alignment is clean because Q_DIM = K_DIM = V_DIM = 2048, all exact multiples of BLOCK_N=256, so each Triton program falls entirely within one output region.
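For readers who want the intended equivalence spelled out, here is a small pure-Python reference, not the Triton kernel. It assumes a depthwise causal convolution with zero left padding, no bias or activation, and outputs laid out as [T, dim]; those details and all function names are assumptions for illustration, and the real kernels may differ (for example in activation or initial conv state handling).

```python
def causal_conv1d_ref(x, w):
    """x: [channels][T], w: [channels][width]; output[c][t] uses x[c][<=t] only."""
    width = len(w[0])
    out = []
    for xc, wc in zip(x, w):
        row = []
        for t in range(len(xc)):
            acc = 0.0
            for j in range(width):
                src = t - (width - 1) + j
                if src >= 0:  # zero padding on the left
                    acc += wc[j] * xc[src]
            row.append(acc)
        out.append(row)
    return out


def two_step(x, w, q_dim, k_dim):
    """Conv into a staging buffer, then split it into Q, K, V as [T][dim]."""
    conv_out = causal_conv1d_ref(x, w)  # staging buffer [conv_dim][T]
    seq_len = len(x[0])

    def to_rows(chans):
        return [[c[t] for c in chans] for t in range(seq_len)]

    return (to_rows(conv_out[:q_dim]),
            to_rows(conv_out[q_dim:q_dim + k_dim]),
            to_rows(conv_out[q_dim + k_dim:]))


def fused(x, w, q_dim, k_dim):
    """Same math, but each result is stored straight into Q, K, or V."""
    conv_dim, seq_len, width = len(x), len(x[0]), len(w[0])
    v_dim = conv_dim - q_dim - k_dim
    q = [[0.0] * q_dim for _ in range(seq_len)]
    k = [[0.0] * k_dim for _ in range(seq_len)]
    v = [[0.0] * v_dim for _ in range(seq_len)]
    for c in range(conv_dim):
        for t in range(seq_len):
            acc = 0.0
            for j in range(width):
                src = t - (width - 1) + j
                if src >= 0:
                    acc += w[c][j] * x[c][src]
            if c < q_dim:
                q[t][c] = acc
            elif c < q_dim + k_dim:
                k[t][c - q_dim] = acc
            else:
                v[t][c - q_dim - k_dim] = acc
    return q, k, v


x = [[c + t for t in range(5)] for c in range(6)]  # 6 channels, T=5
w = [[1, 2, 3, 4] for _ in range(6)]               # kernel width 4
print(two_step(x, w, 2, 2) == fused(x, w, 2, 2))
```

The fused variant never materializes conv_out, which is the point of the change: the intermediate [conv_dim, T] buffer disappears while the Q/K/V contents stay the same.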
Perf (B200, bf16, single-seq)
Tests
Added test/runtime/layers/test_causal_conv1d_qkv_split_fused.py: 22 tests covering varlen, initial conv state, bf16/fp16. All pass. Conv state update is also verified to be identical between the two paths.