Skip to content

Commit 215acaf

Browse files
tuanzhangCS and tuanzhangCS authored
fix pd decode TP4->DP4EP4 bugs (#448)
Co-authored-by: tuanzhangCS <tuan@lightseed.org>
1 parent d130da2 commit 215acaf

5 files changed

Lines changed: 226 additions & 17 deletions

File tree

python/tokenspeed/runtime/engine/event_loop.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,23 @@ def calc_l3_query_hashes(scheduler, tokens: list[int]) -> list[str]:
108108
_PAUSED_IDLE_SLEEP_S = 0.001
109109

110110

111+
def _forward_op_executes_model_forward(forward_op, *, is_disagg_decode: bool) -> bool:
112+
"""Return whether ``forward_op`` will enter the model forward path.
113+
114+
On decode-side PD, EXTEND ops only start remote KV receive; the model
115+
forward runs after the remote prefill completes and the scheduler advances
116+
the request into decode. Treating those EXTEND ops as model work makes
117+
idle DP ranks enter dummy collectives that the active rank will not match.
118+
"""
119+
if forward_op is None:
120+
return False
121+
if sum(forward_op.input_lengths) <= 0:
122+
return False
123+
if is_disagg_decode and forward_op.num_extends() > 0:
124+
return False
125+
return True
126+
127+
111128
class _NullSender:
112129
"""No-op ZMQ sender for non-rank-0 workers."""
113130

@@ -1123,9 +1140,13 @@ def _dp_sync_and_check(self, forward_op) -> DpForwardMetadata:
11231140
"""
11241141
import torch.distributed as dist
11251142

1126-
num_tokens = sum(forward_op.input_lengths) if forward_op is not None else 0
1127-
batch_size = len(forward_op.request_ids) if forward_op is not None else 0
1128-
if forward_op is None:
1143+
executes_model_forward = _forward_op_executes_model_forward(
1144+
forward_op,
1145+
is_disagg_decode=isinstance(self.pd_kv_transfer, DisaggDecodeExecutor),
1146+
)
1147+
num_tokens = sum(forward_op.input_lengths) if executes_model_forward else 0
1148+
batch_size = len(forward_op.request_ids) if executes_model_forward else 0
1149+
if not executes_model_forward:
11291150
forward_mode = ForwardMode.IDLE
11301151
else:
11311152
forward_mode = ForwardMode.from_num_extends(

python/tokenspeed/runtime/pd/mooncake/receiver.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,14 @@ def _build_buffer_layout_pair(
164164
prefill_tp_size: int,
165165
decode_tp_size: int,
166166
):
167+
"""Build compatible logical layouts for one prefill/decode buffer pair.
168+
169+
Besides normal TP sharding and fully replicated buffers, this handles GQA
170+
KV caches where prefill TP is larger than the number of distinct KV heads.
171+
In that case multiple prefill TP ranks carry the same KV head, so the
172+
transfer plan uses one representative rank from each replica group.
173+
"""
174+
167175
if prefill_unit_len != decode_unit_len:
168176
raise ValueError(
169177
f"prefill/decode unit sizes differ for {buffer_kind.value}: "
@@ -184,12 +192,23 @@ def _build_buffer_layout_pair(
184192
decode_local_units = decode_item_len // decode_unit_len
185193
prefill_global_units = prefill_local_units * prefill_tp_size
186194
decode_global_units = decode_local_units * decode_tp_size
195+
prefill_tp_replica_group_size = 1
187196
if prefill_global_units == decode_global_units:
188197
logical_axis = sharded_axis
189198
logical_size = decode_global_units
190199
elif prefill_item_len == decode_item_len:
191200
logical_axis = "replicated"
192201
logical_size = decode_local_units
202+
elif (
203+
sharded_axis == "kv_head"
204+
and decode_global_units % prefill_local_units == 0
205+
and decode_global_units // prefill_local_units <= prefill_tp_size
206+
and prefill_tp_size % (decode_global_units // prefill_local_units) == 0
207+
):
208+
logical_axis = sharded_axis
209+
logical_size = decode_global_units
210+
prefill_distinct_tp_size = decode_global_units // prefill_local_units
211+
prefill_tp_replica_group_size = prefill_tp_size // prefill_distinct_tp_size
193212
else:
194213
raise ValueError(
195214
f"unsupported heterogeneous TP buffer layout for {buffer_kind.value}: "
@@ -207,6 +226,7 @@ def _build_buffer_layout_pair(
207226
page_size=1,
208227
bytes_per_logical_unit=decode_unit_len,
209228
item_stride_bytes=prefill_item_len,
229+
tp_replica_group_size=prefill_tp_replica_group_size,
210230
),
211231
BufferLayout(
212232
buffer_index=buffer_index,

python/tokenspeed/runtime/pd/transfer_plan.py

Lines changed: 76 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,21 @@ def tp_size_per_dp(self) -> int:
6161

6262
@dataclass(frozen=True)
6363
class BufferLayout:
64+
"""Logical layout for one cache/state buffer.
65+
66+
``tp_replica_group_size`` describes TP ranks that hold the same logical
67+
shard. It is used by GQA/MQA-style KV caches when the prefill TP size is
68+
larger than the number of distinct KV heads.
69+
"""
70+
6471
buffer_index: int
6572
buffer_kind: BufferKind
6673
logical_axis: Literal["kv_head", "state_channel", "replicated"]
6774
logical_size: int
6875
page_size: int
6976
bytes_per_logical_unit: int
7077
item_stride_bytes: int
78+
tp_replica_group_size: int = 1
7179

7280
def __post_init__(self):
7381
if self.logical_size <= 0:
@@ -78,6 +86,8 @@ def __post_init__(self):
7886
raise UnsupportedPDLayoutError("bytes_per_logical_unit must be positive")
7987
if self.item_stride_bytes <= 0:
8088
raise UnsupportedPDLayoutError("item_stride_bytes must be positive")
89+
if self.tp_replica_group_size <= 0:
90+
raise UnsupportedPDLayoutError("tp_replica_group_size must be positive")
8191

8292

8393
@dataclass(frozen=True)
@@ -216,7 +226,9 @@ def plan_for_decode_rank(self, decode_rank: int) -> RankTransferPlan:
216226
target_dp_group = decode_rank // decode_tp_size
217227
decode_tp_rank = decode_rank % decode_tp_size
218228

219-
if self.prefill_layout.tp_size_per_dp == decode_tp_size:
229+
if self._can_use_identity_plan() and (
230+
self.prefill_layout.tp_size_per_dp == decode_tp_size
231+
):
220232
prefill_rank = (
221233
target_dp_group * self.prefill_layout.tp_size_per_dp + decode_tp_rank
222234
)
@@ -261,19 +273,25 @@ def plan_for_decode_rank(self, decode_rank: int) -> RankTransferPlan:
261273
fragments.setdefault(prefill_rank, []).append(fragment)
262274
continue
263275

264-
decode_interval = self._rank_interval(
265-
decode_buffer.logical_size, decode_tp_size, decode_tp_rank
276+
decode_interval = self._rank_interval_for_buffer(
277+
decode_buffer,
278+
self.decode_layout,
279+
decode_tp_rank,
266280
)
281+
if decode_interval is None:
282+
continue
267283
for prefill_tp_rank in range(self.prefill_layout.tp_size_per_dp):
268284
prefill_rank = (
269285
target_dp_group * self.prefill_layout.tp_size_per_dp
270286
+ prefill_tp_rank
271287
)
272-
prefill_interval = self._rank_interval(
273-
prefill_buffer.logical_size,
274-
self.prefill_layout.tp_size_per_dp,
288+
prefill_interval = self._rank_interval_for_buffer(
289+
prefill_buffer,
290+
self.prefill_layout,
275291
prefill_tp_rank,
276292
)
293+
if prefill_interval is None:
294+
continue
277295
intersection = prefill_interval.intersect(decode_interval)
278296
if intersection is None:
279297
continue
@@ -341,11 +359,29 @@ def _validate_alignment(self) -> None:
341359
for buffer in buffers:
342360
if buffer.logical_axis == "replicated":
343361
continue
344-
if buffer.logical_size % layout.tp_size_per_dp != 0:
362+
if layout.tp_size_per_dp % buffer.tp_replica_group_size != 0:
363+
raise UnsupportedPDLayoutError(
364+
"tp replica group must divide TP size for "
365+
f"buffer_kind={buffer.buffer_kind.value}: "
366+
f"tp_size_per_dp={layout.tp_size_per_dp}, "
367+
f"tp_replica_group_size={buffer.tp_replica_group_size}"
368+
)
369+
effective_tp_size = (
370+
layout.tp_size_per_dp // buffer.tp_replica_group_size
371+
)
372+
if buffer.logical_size % effective_tp_size != 0:
345373
raise UnsupportedPDLayoutError(
346374
"non-aligned TP heterogeneous mapping for "
347375
f"buffer_kind={buffer.buffer_kind.value}: logical_size="
348-
f"{buffer.logical_size}, tp_size_per_dp={layout.tp_size_per_dp}"
376+
f"{buffer.logical_size}, effective_tp_size={effective_tp_size}"
377+
)
378+
item_units = buffer.item_stride_bytes // buffer.bytes_per_logical_unit
379+
required_units = buffer.logical_size // effective_tp_size
380+
if item_units < required_units:
381+
raise UnsupportedPDLayoutError(
382+
"buffer item is smaller than its logical shard for "
383+
f"buffer_kind={buffer.buffer_kind.value}: item_units="
384+
f"{item_units}, required_units={required_units}"
349385
)
350386

351387
def _calc_source_fanout(self) -> dict[int, int]:
@@ -370,17 +406,21 @@ def _calc_source_fanout(self) -> dict[int, int]:
370406
intersected_prefill_ranks.add(prefill_rank)
371407
continue
372408

373-
decode_interval = self._rank_interval(
374-
decode_buffer.logical_size,
375-
self.decode_layout.tp_size_per_dp,
409+
decode_interval = self._rank_interval_for_buffer(
410+
decode_buffer,
411+
self.decode_layout,
376412
decode_tp_rank,
377413
)
414+
if decode_interval is None:
415+
continue
378416
for prefill_tp_rank in range(self.prefill_layout.tp_size_per_dp):
379-
prefill_interval = self._rank_interval(
380-
prefill_buffer.logical_size,
381-
self.prefill_layout.tp_size_per_dp,
417+
prefill_interval = self._rank_interval_for_buffer(
418+
prefill_buffer,
419+
self.prefill_layout,
382420
prefill_tp_rank,
383421
)
422+
if prefill_interval is None:
423+
continue
384424
if prefill_interval.intersect(decode_interval) is None:
385425
continue
386426
prefill_rank = (
@@ -392,12 +432,34 @@ def _calc_source_fanout(self) -> dict[int, int]:
392432
fanout[prefill_rank] += 1
393433
return fanout
394434

435+
def _can_use_identity_plan(self) -> bool:
436+
return all(
437+
prefill_buffer.tp_replica_group_size == 1
438+
and decode_buffer.tp_replica_group_size == 1
439+
for prefill_buffer, decode_buffer in zip(
440+
self.prefill_buffers, self.decode_buffers
441+
)
442+
)
443+
395444
@staticmethod
def _rank_interval(logical_size: int, tp_size: int, tp_rank: int) -> _Interval:
    """Return the contiguous slice of logical units assigned to ``tp_rank``.

    Each rank gets ``logical_size // tp_size`` units, laid out back to back
    in rank order.
    """
    shard_len = logical_size // tp_size
    begin = shard_len * tp_rank
    end = begin + shard_len
    return _Interval(begin, end)
400449

450+
@staticmethod
def _rank_interval_for_buffer(
    buffer: BufferLayout, layout: ParallelLayout, tp_rank: int
) -> _Interval | None:
    """Return the logical interval ``tp_rank`` contributes for ``buffer``.

    Only TP ranks that are a multiple of ``buffer.tp_replica_group_size`` get
    an interval; every other rank yields ``None``. The interval is computed
    over the TP size and rank after dividing both by the replica group size.
    """
    group_size = buffer.tp_replica_group_size
    group_index, position_in_group = divmod(tp_rank, group_size)
    if position_in_group != 0:
        return None
    distinct_shards = layout.tp_size_per_dp // group_size
    return PDTransferPlanner._rank_interval(
        buffer.logical_size, distinct_shards, group_index
    )
462+
401463
@staticmethod
402464
def _replicated_source_tp_rank(
403465
prefill_tp_size: int, decode_tp_size: int, decode_tp_rank: int
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) 2026 LightSeek Foundation
2+
#
3+
# Permission is hereby granted, free of charge, to any person obtaining a copy
4+
# of this software and associated documentation files (the "Software"), to deal
5+
# in the Software without restriction, including without limitation the rights
6+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7+
# copies of the Software, and to permit persons to whom the Software is
8+
# furnished to do so, subject to the following conditions:
9+
#
10+
# The above copyright notice and this permission notice shall be included in
11+
# all copies or substantial portions of the Software.
12+
#
13+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19+
# SOFTWARE.
20+
21+
from tokenspeed.runtime.engine.event_loop import _forward_op_executes_model_forward
22+
23+
24+
class FakeForwardOp:
    """Bare-bones forward op exposing only what the helper under test reads."""

    def __init__(self, *, input_lengths, request_ids=None, num_extends=0):
        # A missing or empty request_ids gets one synthetic id per input length.
        if not request_ids:
            request_ids = [f"req-{i}" for i in range(len(input_lengths))]
        self.input_lengths = input_lengths
        self.request_ids = request_ids
        self._num_extends = num_extends

    def num_extends(self):
        """Return the extend count given at construction."""
        return self._num_extends
34+
35+
36+
def test_pd_decode_extend_only_does_not_require_idle_forward():
    """A decode-side PD EXTEND op must not count as model work."""
    # On decode-side PD an EXTEND only starts the KV receive, so the active DP
    # rank runs no model collectives yet and idle DP ranks must stay out of the
    # dummy forward.
    extend_only_op = FakeForwardOp(input_lengths=[17], num_extends=1)

    executes = _forward_op_executes_model_forward(
        extend_only_op, is_disagg_decode=True
    )

    assert not executes
42+
43+
44+
def test_pd_decode_decode_step_requires_idle_forward():
    """A decode-side PD op with no extends counts as model work."""
    decode_step_op = FakeForwardOp(input_lengths=[1], num_extends=0)

    executes = _forward_op_executes_model_forward(
        decode_step_op, is_disagg_decode=True
    )

    assert executes
48+
49+
50+
def test_non_pd_extend_still_executes_model_forward():
    """Outside decode-side PD, an EXTEND op counts as model work."""
    extend_op = FakeForwardOp(input_lengths=[17], num_extends=1)

    executes = _forward_op_executes_model_forward(extend_op, is_disagg_decode=False)

    assert executes
54+
55+
56+
def test_zero_token_forward_op_is_not_model_work():
    """An op whose input lengths sum to zero is never model work."""
    zero_token_op = FakeForwardOp(input_lengths=[0], num_extends=1)

    executes = _forward_op_executes_model_forward(
        zero_token_op, is_disagg_decode=False
    )

    assert not executes
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from tokenspeed.runtime.pd.mooncake.receiver import _build_buffer_layout_pair
2+
from tokenspeed.runtime.pd.transfer_plan import (
3+
BufferKind,
4+
ParallelLayout,
5+
PDTransferPlanner,
6+
)
7+
8+
9+
def test_replicated_prefill_kv_heads_transfer_to_decode_full_kv_heads():
    """Prefill TP=4 with replica groups of two feeds a single decode rank.

    The decode rank is expected to pull one fragment each from prefill ranks
    0 and 2, placed back to back in the decode item.
    """
    prefill_item_bytes = 16_384
    decode_item_bytes = 32_768

    prefill_buffer, decode_buffer = _build_buffer_layout_pair(
        buffer_index=0,
        buffer_kind=BufferKind.TARGET_K,
        sharded_axis="kv_head",
        prefill_item_len=prefill_item_bytes,
        decode_item_len=decode_item_bytes,
        prefill_unit_len=256,
        decode_unit_len=256,
        prefill_tp_size=4,
        decode_tp_size=1,
    )

    # Both sides describe the same 128 logical units; only prefill is replicated.
    assert prefill_buffer.logical_size == 128
    assert prefill_buffer.tp_replica_group_size == 2
    assert decode_buffer.logical_size == 128
    assert decode_buffer.tp_replica_group_size == 1

    prefill_layout = ParallelLayout(role="prefill", world_size=4)
    decode_layout = ParallelLayout(role="decode", world_size=1)
    planner = PDTransferPlanner(
        prefill_layout=prefill_layout,
        decode_layout=decode_layout,
        prefill_buffers=(prefill_buffer,),
        decode_buffers=(decode_buffer,),
    )
    plan = planner.plan_for_decode_rank(0)

    assert plan.plan_kind == "fragmented"
    assert plan.target_prefill_ranks == (0, 2)
    assert plan.required_prefill_response_num == 2
    assert plan.required_dst_info_num_by_prefill_rank == {0: 1, 2: 1}

    # Each source rank sends from offset 0; destinations are laid out in order.
    expected_dst_offset_by_rank = {0: 0, 2: 16_384}
    for prefill_rank, dst_offset in expected_dst_offset_by_rank.items():
        fragment = plan.fragments_by_prefill_rank[prefill_rank][0]
        assert fragment.src_byte_offset == 0
        assert fragment.dst_byte_offset == dst_offset
        assert fragment.bytes_per_page == 16_384

0 commit comments

Comments
 (0)