Skip to content

fix(mla): apply MTP causal mask in BF16/FP16 decode kernel#469

Draft
zcnrex wants to merge 1 commit into
lightseekorg:mainfrom
zcnrex:fix/mla-decode-bf16-causal-mask
Draft

fix(mla): apply MTP causal mask in BF16/FP16 decode kernel#469
zcnrex wants to merge 1 commit into
lightseekorg:mainfrom
zcnrex:fix/mla-decode-bf16-causal-mask

Conversation

@zcnrex

@zcnrex zcnrex commented Jun 17, 2026

Copy link
Copy Markdown

Problem

The BF16/FP16 MLA decode kernel (BlackwellMultiHeadLatentAttentionForwardFP16) applies only a plain K-bound mask (col >= K) and omits the per-row spec-decoding (MTP) causal bound that the FP8 kernel already implements. For seq_len_q > 1 (MTP / spec-decode), every query row attends to up to seq_len_q - 1 future keys it should not see → wrong output, increasingly so at short context (error ≈ (seq_len_q-1)/K).

The kernel doesn't even accept is_causal — the wrapper only wires it for the fp8 path — so causal_mask=True is a silent no-op for bf16.

Concretely (B200, h=16, bs=10, bf16, vs an fp32 torch MTP reference):

q seq before (this kernel) reference
8 2048 29.74 dB SNR, max_abs 2.06e-1 54.84 dB
8 50000 52.64 dB 55.44 dB

(The FP8 kernel is correct here — only fp16/bf16 was affected.)

Fix

Port the FP8 kernel's causal machinery (mla_decode_fp8.py) into the fp16 kernel, keeping the same conventions:

  • __init__: add is_causal, num_heads, seq_len_q.
  • compute(): replace the single-K-bound-tile loop with the is_causal-gated 3-phase mask_tile_count loop (unmasked bulk → masked intermediate → masked final), covering the causal region's tile-boundary-crossing case.
  • softmax() (both sm_100 and sm_103 branches): replace col >= K with the per-row k_bound = K - (seq_len_q - 1) + q_tok (fold-aware q_tok), gated by is_causal (non-causal falls back to the plain K-bound). Rename the flag is_last_tileapply_mask to match fp8.
  • Wrapper: wire is_causal/num_heads/seq_len_q through for the fp16 path too (previously fp8-only). DCP (cp_world) stays fp8-only.

Validation

B200, h=16, bs=10, bf16, vs fp32 torch MTP reference — bf16 output is now bit-identical to the (correct) reference at both q=4 and q=8:

q seq after max_abs
8 2048 54.84 dB 2.043e-3
8 50000 55.44 dB 1.028e-3
4 2048 54.85 dB 3.425e-3
4 50000 55.44 dB 9.691e-4

q_len=1 (where MTP causal ≡ plain K-bound) is unaffected.

Note: independent of, and complementary to, the LSE-guard fix (the bf16 default return_lse=False path also needs that guard to run without crashing). Correctness here was validated on a build carrying both fixes.

🤖 Generated with Claude Code

BlackwellMultiHeadLatentAttentionForwardFP16 applied only a plain K-bound
mask (col >= K), omitting the per-row spec-decoding (MTP) causal bound that
the FP8 kernel already implements. For seq_len_q > 1 every query row attended
to up to seq_len_q-1 future keys -> wrong output, increasingly so at short
context (error ~ (seq_len_q-1)/K). The kernel didn't even accept is_causal,
so causal_mask=True was a silent no-op for bf16.

Port the FP8 kernel's causal machinery (mla_decode_fp8.py): add is_causal/
num_heads/seq_len_q to __init__; the is_causal-gated 3-phase mask_tile_count
loop in compute(); and the per-row k_bound = K-(seq_len_q-1)+q_tok (fold-aware)
in softmax() (both arch branches). Wire is_causal/num_heads/seq_len_q through
the wrapper for the fp16 path. DCP/cp_world stays fp8-only.

Verified on B200 (h=16, bs=10, bf16, vs fp32 torch MTP reference): bf16 output
is now bit-identical to the FP8-kernel-style causal reference at q=4 and q=8
(e.g. q=8 seq=2048: SNR 29.74 -> 54.84 dB, matching max_abs).

AI-assisted (Claude Code).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant