fix(mla): apply MTP causal mask in BF16/FP16 decode kernel#469
Draft
zcnrex wants to merge 1 commit into
Draft
Conversation
BlackwellMultiHeadLatentAttentionForwardFP16 applied only a plain K-bound mask (col >= K), omitting the per-row spec-decoding (MTP) causal bound that the FP8 kernel already implements. For seq_len_q > 1 every query row attended to up to seq_len_q-1 future keys -> wrong output, increasingly so at short context (error ~ (seq_len_q-1)/K). The kernel didn't even accept is_causal, so causal_mask=True was a silent no-op for bf16. Port the FP8 kernel's causal machinery (mla_decode_fp8.py): add is_causal/ num_heads/seq_len_q to __init__; the is_causal-gated 3-phase mask_tile_count loop in compute(); and the per-row k_bound = K-(seq_len_q-1)+q_tok (fold-aware) in softmax() (both arch branches). Wire is_causal/num_heads/seq_len_q through the wrapper for the fp16 path. DCP/cp_world stays fp8-only. Verified on B200 (h=16, bs=10, bf16, vs fp32 torch MTP reference): bf16 output is now bit-identical to the FP8-kernel-style causal reference at q=4 and q=8 (e.g. q=8 seq=2048: SNR 29.74 -> 54.84 dB, matching max_abs). AI-assisted (Claude Code). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Problem
The BF16/FP16 MLA decode kernel (
BlackwellMultiHeadLatentAttentionForwardFP16) applies only a plain K-bound mask (col >= K) and omits the per-row spec-decoding (MTP) causal bound that the FP8 kernel already implements. Forseq_len_q > 1(MTP / spec-decode), every query row attends to up toseq_len_q - 1future keys it should not see → wrong output, increasingly so at short context (error ≈(seq_len_q-1)/K).The kernel doesn't even accept
is_causal— the wrapper only wires it for the fp8 path — socausal_mask=Trueis a silent no-op for bf16.Concretely (B200, h=16, bs=10, bf16, vs an fp32 torch MTP reference):
(The FP8 kernel is correct here — only fp16/bf16 was affected.)
Fix
Port the FP8 kernel's causal machinery (
mla_decode_fp8.py) into the fp16 kernel, keeping the same conventions:__init__: addis_causal,num_heads,seq_len_q.compute(): replace the single-K-bound-tile loop with theis_causal-gated 3-phasemask_tile_countloop (unmasked bulk → masked intermediate → masked final), covering the causal region's tile-boundary-crossing case.softmax()(bothsm_100andsm_103branches): replacecol >= Kwith the per-rowk_bound = K - (seq_len_q - 1) + q_tok(fold-awareq_tok), gated byis_causal(non-causal falls back to the plain K-bound). Rename the flagis_last_tile→apply_maskto match fp8.is_causal/num_heads/seq_len_qthrough for the fp16 path too (previously fp8-only). DCP (cp_world) stays fp8-only.Validation
B200, h=16, bs=10, bf16, vs fp32 torch MTP reference — bf16 output is now bit-identical to the (correct) reference at both q=4 and q=8:
q_len=1(where MTP causal ≡ plain K-bound) is unaffected.🤖 Generated with Claude Code