fix(runtime): use scattered token counts for MoE RSAG#407
Conversation
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: a528a56b2e
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| start = self.mapping.moe.dp_rank * tp_ep_size | ||
| return list(global_counts[start : start + tp_ep_size]) | ||
| scattered = self.scattered_num_tokens(ctx) | ||
| return [scattered[rank] for rank in self.mapping.moe.tp_ep_group] |
There was a problem hiding this comment.
Preserve counts for CP ranks in MoE groups
When context parallelism is enabled, scattered_num_tokens() only builds attn.dp_size * attn.tp_size entries, not one entry per world rank, while tp_ep_group still contains world-rank ids. In a valid CP+MoE RSAG setup such as world_size=4, attn_tp_size=1, attn_cp_size=4, and moe_tp_size=4, this list has length 1 but the MoE group is (0, 1, 2, 3), so pre_moe_comm/post_moe_comm crashes with IndexError as soon as global token counts are present. Please account for CP ranks (or avoid indexing the attention-scattered list by world rank) before using it for the MoE group.
Useful? React with 👍 / 👎.
Signed-off-by: william <williamto2048@gmail.com>
a528a56 to
4301355
Compare
|
Hi, thanks for the fix — the direction is right, but it still breaks under CP + DP(Although current version CP doesn't work. In Future will add CP.): your lookup divides ranks by tp*cp while scattered_num_tokens() builds the table with a tp-only stride, so DP groups read the wrong slots (e.g. tp1/cp2/dp4 with [3,3,5,5,7,7,9,9] returns [3,3,3,3,5,5,5,5]). Also the compiled-layer path has a verbatim copy of the old MOE_TP_EP slicing in models/base/comm_ops.py, so some models still hit the exact same #405 crash. I will open new PR. |
Summary
Fixes #405.
When attention TP and MoE TP differ,
post_attn_commreduce-scatters hidden states across the attention TP group beforepre_moe_comm. The MoE token-aware all-gather must therefore size its per-rank rows from the post-attention scattered layout, not from the original per-rankglobal_num_tokens.This changes
moe_tp_ep_group_scattered_num_tokensto reuse the already-computed attention-scattered counts and select the ranks in the current MoEtp_ep_group.Test Plan
14 passed, 1 deselected.71 passed.python -m py_compile python/tokenspeed/runtime/distributed/comm_manager.py test/runtime/distributed/test_draft_moe_capture_global_bs.pygit diff --checkpre-commit run --all-files