Skip to content

feat(mla): support custom tree masks in decode#338

Open
mesaleh wants to merge 1 commit into
lightseekorg:mainfrom
mesaleh:community/eagle-tree-mask-mla-20260602
Open

feat(mla): support custom tree masks in decode#338
mesaleh wants to merge 1 commit into
lightseekorg:mainfrom
mesaleh:community/eagle-tree-mask-mla-20260602

Conversation

@mesaleh

@mesaleh mesaleh commented Jun 2, 2026

Copy link
Copy Markdown

Summary

Adds optional FP8 custom tree-mask support to TokenSpeed MLA decode. The new path lets callers provide a flattened per-request custom_mask plus cmask_off offsets so tree-style speculative verification can use the TokenSpeed MLA kernel instead of falling back to dense/causal-only behavior.

What changed:

  • Add custom_mask / cmask_off kwargs to tokenspeed_mla_decode.
  • Extend the FP8 CuTe MLA decode ABI with mask pointer, offset pointer, and mask length.
  • Apply the custom mask in the softmax tail region, including folded S_q handling and non-tile-aligned K.
  • Keep FP16/BF16 on the old call signature and keep FP8 non-tree behavior on stable dummy mask tensors.
  • Forward mask kwargs through the runtime tokenspeed_mla attention backend.
  • Add a CPU mask-layout oracle and a manual GB200 kernel-vs-reference harness.

Validation

  • Extracted from GB200 Kimi EAGLE tree validation work. The TokenSpeed MLA tree-mask path was deployed in the integrated inference stack and exercised through server startup and decode serving without runtime errors or crashes.
  • The included tokenspeed-mla/test/microbench_tree_decode.py harness provides a reproducible GB200 check against an independent absorbed-MLA PyTorch reference, including non-tile-aligned K and batched mask-offset cases.
  • Static checks:
    • python3 -m py_compile tokenspeed-mla/python/tokenspeed_mla/mla_decode.py tokenspeed-mla/python/tokenspeed_mla/mla_decode_fp8.py python/tokenspeed/runtime/layers/attention/backends/tokenspeed_mla.py tokenspeed-mla/test/test_tree_mask_decode.py tokenspeed-mla/test/microbench_tree_decode.py
    • git diff --check

Review

Reviewed with an external collaborative code reviewer before pushing. Initial findings were fixed:

  • Added kernel-side bounds guarding with custom_mask_len to prevent malformed offsets from causing out-of-bounds mask reads.
  • Forwarded custom_mask / cmask_off through the runtime TokenSpeed MLA backend.

Second review returned LGTM.

@mesaleh mesaleh requested a review from a team as a code owner June 2, 2026 04:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant