Skip to content

fix(runtime): use scattered token counts for MoE RSAG#407

Closed
Williams500 wants to merge 1 commit into
lightseekorg:mainfrom
Williams500:fix/moe-rsag-token-counts-405
Closed

fix(runtime): use scattered token counts for MoE RSAG#407
Williams500 wants to merge 1 commit into
lightseekorg:mainfrom
Williams500:fix/moe-rsag-token-counts-405

Conversation

@Williams500

@Williams500 Williams500 commented Jun 10, 2026

Copy link
Copy Markdown

Summary

Fixes #405.

When attention TP and MoE TP differ, post_attn_comm reduce-scatters hidden states across the attention TP group before pre_moe_comm. The MoE token-aware all-gather must therefore size its per-rank rows from the post-attention scattered layout, not from the original per-rank global_num_tokens.

This changes moe_tp_ep_group_scattered_num_tokens to reuse the already-computed attention-scattered counts and select the ranks in the current MoE tp_ep_group.

Test Plan

  • Added a CPU-only regression for the token-count invariant from [Bug] MoE all-gather fails when attention TP ≠ MoE TP (attn tp=2 dp=2, moe tp=4) #405.
  • Added a CPU-only regression for attention CP ranks in MoE groups, covering Codex review feedback.
  • Ran the regression locally with lightweight stubs for unavailable GPU/runtime packages: 14 passed, 1 deselected.
  • Ran mapping tests locally with the same stubs: 71 passed.
  • python -m py_compile python/tokenspeed/runtime/distributed/comm_manager.py test/runtime/distributed/test_draft_moe_capture_global_bs.py
  • git diff --check
  • pre-commit run --all-files

@Williams500 Williams500 marked this pull request as ready for review June 10, 2026 03:55
@Williams500 Williams500 requested a review from a team as a code owner June 10, 2026 03:55

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: a528a56b2e

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

start = self.mapping.moe.dp_rank * tp_ep_size
return list(global_counts[start : start + tp_ep_size])
scattered = self.scattered_num_tokens(ctx)
return [scattered[rank] for rank in self.mapping.moe.tp_ep_group]

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve counts for CP ranks in MoE groups

When context parallelism is enabled, scattered_num_tokens() only builds attn.dp_size * attn.tp_size entries, not one entry per world rank, while tp_ep_group still contains world-rank ids. In a valid CP+MoE RSAG setup such as world_size=4, attn_tp_size=1, attn_cp_size=4, and moe_tp_size=4, this list has length 1 but the MoE group is (0, 1, 2, 3), so pre_moe_comm/post_moe_comm crashes with IndexError as soon as global token counts are present. Please account for CP ranks (or avoid indexing the attention-scattered list by world rank) before using it for the MoE group.

Useful? React with 👍 / 👎.

Signed-off-by: william <williamto2048@gmail.com>
@Williams500 Williams500 force-pushed the fix/moe-rsag-token-counts-405 branch from a528a56 to 4301355 Compare June 10, 2026 05:23
@rjzhb

rjzhb commented Jun 11, 2026

Copy link
Copy Markdown
Collaborator

Hi, thanks for the fix — the direction is right, but it still breaks under CP + DP(Although current version CP doesn't work. In Future will add CP.): your lookup divides ranks by tp*cp while scattered_num_tokens() builds the table with a tp-only stride, so DP groups read the wrong slots (e.g. tp1/cp2/dp4 with [3,3,5,5,7,7,9,9] returns [3,3,3,3,5,5,5,5]). Also the compiled-layer path has a verbatim copy of the old MOE_TP_EP slicing in models/base/comm_ops.py, so some models still hit the exact same #405 crash. I will open new PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] MoE all-gather fails when attention TP ≠ MoE TP (attn tp=2 dp=2, moe tp=4)

3 participants