@@ -61,13 +61,21 @@ def tp_size_per_dp(self) -> int:
6161
6262@dataclass (frozen = True )
6363class BufferLayout :
64+ """Logical layout for one cache/state buffer.
65+
66+ ``tp_replica_group_size`` describes TP ranks that hold the same logical
67+ shard. It is used by GQA/MQA-style KV caches when the prefill TP size is
68+ larger than the number of distinct KV heads.
69+ """
70+
6471 buffer_index : int
6572 buffer_kind : BufferKind
6673 logical_axis : Literal ["kv_head" , "state_channel" , "replicated" ]
6774 logical_size : int
6875 page_size : int
6976 bytes_per_logical_unit : int
7077 item_stride_bytes : int
78+ tp_replica_group_size : int = 1
7179
7280 def __post_init__ (self ):
7381 if self .logical_size <= 0 :
@@ -78,6 +86,8 @@ def __post_init__(self):
7886 raise UnsupportedPDLayoutError ("bytes_per_logical_unit must be positive" )
7987 if self .item_stride_bytes <= 0 :
8088 raise UnsupportedPDLayoutError ("item_stride_bytes must be positive" )
89+ if self .tp_replica_group_size <= 0 :
90+ raise UnsupportedPDLayoutError ("tp_replica_group_size must be positive" )
8191
8292
8393@dataclass (frozen = True )
@@ -216,7 +226,9 @@ def plan_for_decode_rank(self, decode_rank: int) -> RankTransferPlan:
216226 target_dp_group = decode_rank // decode_tp_size
217227 decode_tp_rank = decode_rank % decode_tp_size
218228
219- if self .prefill_layout .tp_size_per_dp == decode_tp_size :
229+ if self ._can_use_identity_plan () and (
230+ self .prefill_layout .tp_size_per_dp == decode_tp_size
231+ ):
220232 prefill_rank = (
221233 target_dp_group * self .prefill_layout .tp_size_per_dp + decode_tp_rank
222234 )
@@ -261,19 +273,25 @@ def plan_for_decode_rank(self, decode_rank: int) -> RankTransferPlan:
261273 fragments .setdefault (prefill_rank , []).append (fragment )
262274 continue
263275
264- decode_interval = self ._rank_interval (
265- decode_buffer .logical_size , decode_tp_size , decode_tp_rank
276+ decode_interval = self ._rank_interval_for_buffer (
277+ decode_buffer ,
278+ self .decode_layout ,
279+ decode_tp_rank ,
266280 )
281+ if decode_interval is None :
282+ continue
267283 for prefill_tp_rank in range (self .prefill_layout .tp_size_per_dp ):
268284 prefill_rank = (
269285 target_dp_group * self .prefill_layout .tp_size_per_dp
270286 + prefill_tp_rank
271287 )
272- prefill_interval = self ._rank_interval (
273- prefill_buffer . logical_size ,
274- self .prefill_layout . tp_size_per_dp ,
288+ prefill_interval = self ._rank_interval_for_buffer (
289+ prefill_buffer ,
290+ self .prefill_layout ,
275291 prefill_tp_rank ,
276292 )
293+ if prefill_interval is None :
294+ continue
277295 intersection = prefill_interval .intersect (decode_interval )
278296 if intersection is None :
279297 continue
@@ -341,11 +359,29 @@ def _validate_alignment(self) -> None:
341359 for buffer in buffers :
342360 if buffer .logical_axis == "replicated" :
343361 continue
344- if buffer .logical_size % layout .tp_size_per_dp != 0 :
362+ if layout .tp_size_per_dp % buffer .tp_replica_group_size != 0 :
363+ raise UnsupportedPDLayoutError (
364+ "tp replica group must divide TP size for "
365+ f"buffer_kind={ buffer .buffer_kind .value } : "
366+ f"tp_size_per_dp={ layout .tp_size_per_dp } , "
367+ f"tp_replica_group_size={ buffer .tp_replica_group_size } "
368+ )
369+ effective_tp_size = (
370+ layout .tp_size_per_dp // buffer .tp_replica_group_size
371+ )
372+ if buffer .logical_size % effective_tp_size != 0 :
345373 raise UnsupportedPDLayoutError (
346374 "non-aligned TP heterogeneous mapping for "
347375 f"buffer_kind={ buffer .buffer_kind .value } : logical_size="
348- f"{ buffer .logical_size } , tp_size_per_dp={ layout .tp_size_per_dp } "
376+ f"{ buffer .logical_size } , effective_tp_size={ effective_tp_size } "
377+ )
378+ item_units = buffer .item_stride_bytes // buffer .bytes_per_logical_unit
379+ required_units = buffer .logical_size // effective_tp_size
380+ if item_units < required_units :
381+ raise UnsupportedPDLayoutError (
382+ "buffer item is smaller than its logical shard for "
383+ f"buffer_kind={ buffer .buffer_kind .value } : item_units="
384+ f"{ item_units } , required_units={ required_units } "
349385 )
350386
351387 def _calc_source_fanout (self ) -> dict [int , int ]:
@@ -370,17 +406,21 @@ def _calc_source_fanout(self) -> dict[int, int]:
370406 intersected_prefill_ranks .add (prefill_rank )
371407 continue
372408
373- decode_interval = self ._rank_interval (
374- decode_buffer . logical_size ,
375- self .decode_layout . tp_size_per_dp ,
409+ decode_interval = self ._rank_interval_for_buffer (
410+ decode_buffer ,
411+ self .decode_layout ,
376412 decode_tp_rank ,
377413 )
414+ if decode_interval is None :
415+ continue
378416 for prefill_tp_rank in range (self .prefill_layout .tp_size_per_dp ):
379- prefill_interval = self ._rank_interval (
380- prefill_buffer . logical_size ,
381- self .prefill_layout . tp_size_per_dp ,
417+ prefill_interval = self ._rank_interval_for_buffer (
418+ prefill_buffer ,
419+ self .prefill_layout ,
382420 prefill_tp_rank ,
383421 )
422+ if prefill_interval is None :
423+ continue
384424 if prefill_interval .intersect (decode_interval ) is None :
385425 continue
386426 prefill_rank = (
@@ -392,12 +432,34 @@ def _calc_source_fanout(self) -> dict[int, int]:
392432 fanout [prefill_rank ] += 1
393433 return fanout
394434
def _can_use_identity_plan(self) -> bool:
    """Report whether every paired buffer is free of TP replica groups.

    Prefill and decode buffers are walked in lockstep. The answer is False
    as soon as either buffer of a pair has a ``tp_replica_group_size``
    other than 1, and True when no such pair is found.
    """
    paired_buffers = zip(self.prefill_buffers, self.decode_buffers)
    for source_buffer, target_buffer in paired_buffers:
        # Check the prefill side first, mirroring the short-circuit order.
        if source_buffer.tp_replica_group_size != 1:
            return False
        if target_buffer.tp_replica_group_size != 1:
            return False
    return True
443+
@staticmethod
def _rank_interval(logical_size: int, tp_size: int, tp_rank: int) -> _Interval:
    """Build the ``_Interval`` of logical units assigned to one TP rank.

    Each rank gets ``logical_size // tp_size`` units. The interval for
    ``tp_rank`` starts at ``tp_rank`` times that amount and ends one shard
    later.
    """
    units_per_rank = logical_size // tp_size
    lower = units_per_rank * tp_rank
    upper = lower + units_per_rank
    return _Interval(lower, upper)
400449
@staticmethod
def _rank_interval_for_buffer(
    buffer: BufferLayout, layout: ParallelLayout, tp_rank: int
) -> _Interval | None:
    """Return the logical interval for ``tp_rank``, honoring replica groups.

    Only the first rank of each replica group (``tp_rank`` divisible by
    ``buffer.tp_replica_group_size``) yields an interval; every other rank
    gets ``None``. For a yielding rank, both the TP size and the rank are
    divided by the replica group size before delegating to
    ``PDTransferPlanner._rank_interval``.
    """
    group_size = buffer.tp_replica_group_size
    group_index, position_in_group = divmod(tp_rank, group_size)
    if position_in_group:
        # Not the first rank of its replica group: no interval is produced.
        return None
    group_count = layout.tp_size_per_dp // group_size
    return PDTransferPlanner._rank_interval(
        buffer.logical_size, group_count, group_index
    )
462+
401463 @staticmethod
402464 def _replicated_source_tp_rank (
403465 prefill_tp_size : int , decode_tp_size : int , decode_tp_rank : int
0 commit comments