Skip to content

Commit ce9f2f6

Browse files
committed
fix(scheduler): skip mid-flight prefix publish for sliding-window models
The prior commit's mid-flight prefix publish regressed gpt-oss-120b GPQA-diamond (~0.71 -> 0.547, both NVIDIA + AMD): gpt-oss is non-hybrid and uses sliding-window attention, where sharing a prefix mid-flight (while the publishing request is still decoding) corrupts SWA prefix reuse. Full-attention prefix caching (ut-runtime-prefix-cache-e2e) and hybrid/MLA models were unaffected. Add has_sliding_window to SchedulerConfig, derived in event_loop.py from hf_config.sliding_window (mirroring ModelRunner's SWA detection). For SWA models the scheduler passes a null kv_prefix_cache to SchedulePrefillEvent/ScheduleDecodeEvent so InsertPrefixCache skips the mid-flight publish; the prefix is published only at FinishEvent -- the prior, known-correct behavior. Full-attention non-hybrid models keep the mid-flight reuse; hybrid (DeepSeek-V4) is unchanged. Signed-off-by: Qingyang Wu <willqywu@gmail.com>
1 parent 8373813 commit ce9f2f6

6 files changed

Lines changed: 24 additions & 5 deletions

File tree

python/tokenspeed/runtime/engine/event_loop.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,10 @@ def __init__(
179179
has_mamba = getattr(self.model_config, "mambaish_config", None) is not None or (
180180
text_config is not None and hasattr(text_config, "mamba2_cache_params")
181181
)
182+
# Sliding-window-attention models must not publish their prefix mid-flight (they
183+
# publish only at FinishEvent); the SWA prefix-reuse path corrupts outputs
184+
# otherwise. Mirror ModelRunner's SWA detection (hf_config.sliding_window).
185+
has_sliding_window = getattr(hf_config, "sliding_window", None) is not None
182186

183187
model_executor_config = ModelExecutorConfig.from_server_args(
184188
server_args=server_args,
@@ -327,6 +331,7 @@ def __init__(
327331
else 1
328332
),
329333
disable_prefix_cache=not server_args.enable_prefix_caching,
334+
has_sliding_window=has_sliding_window,
330335
enable_mamba=has_mamba,
331336
mamba_cache_chunk_size=server_args.mamba_cache_chunk_size,
332337
mamba_pool_total_chunks=mamba_pool_total_chunks,

python/tokenspeed/runtime/engine/scheduler_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def make_config(
6464
enable_kv_cache_events: bool = False,
6565
decode_input_tokens: int = 1,
6666
disable_prefix_cache: bool = False,
67+
has_sliding_window: bool = False,
6768
enable_mamba: bool = False,
6869
mamba_cache_chunk_size: int = 64,
6970
mamba_pool_total_chunks: int = 0,
@@ -93,6 +94,7 @@ def make_config(
9394
cfg.num_device_pages = num_device_pages
9495
cfg.decode_input_tokens = decode_input_tokens
9596
cfg.disable_prefix_cache = disable_prefix_cache
97+
cfg.has_sliding_window = has_sliding_window
9698
cfg.disable_l2_cache = disable_l2_cache
9799

98100
cfg.enable_mamba = enable_mamba

tokenspeed-scheduler/bindings/python_module.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ NB_MODULE(tokenspeed_scheduler_ext, m) {
237237
.def_rw("enable_kv_cache_events", &tokenspeed::SchedulerConfig::enable_kv_cache_events)
238238
.def_rw("enable_mixed_prefill_decode", &tokenspeed::SchedulerConfig::enable_mixed_prefill_decode)
239239
.def_rw("disable_prefix_cache", &tokenspeed::SchedulerConfig::disable_prefix_cache)
240+
.def_rw("has_sliding_window", &tokenspeed::SchedulerConfig::has_sliding_window)
240241
.def_rw("enable_mamba", &tokenspeed::SchedulerConfig::enable_mamba)
241242
.def_rw("mamba_cache_chunk_size", &tokenspeed::SchedulerConfig::mamba_cache_chunk_size)
242243
.def_rw("mamba_pool_total_chunks", &tokenspeed::SchedulerConfig::mamba_pool_total_chunks)

tokenspeed-scheduler/csrc/fsm/forward_events.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,11 @@ void InsertPrefixCache(KVPrefixCache* kv_prefix_cache, HybridPrefixCache* hybrid
108108
LocalMambaAllocator* local_mamba_allocator, std::int32_t chunk_begin, std::int32_t chunk_size,
109109
std::int32_t page_size) {
110110
// Hybrid models publish through the hybrid cache's wrapped KV cache (and additionally
111-
// track a Mamba checkpoint); plain models publish through the base KV prefix cache.
112-
// Either way the freshly-computed prefix becomes matchable by concurrent requests
113-
// now, rather than only when this request finishes (FinishEvent).
111+
// track a Mamba checkpoint); plain (non-hybrid) models publish through the base KV
112+
// prefix cache, making the freshly-computed prefix matchable by concurrent requests
113+
// now rather than only at FinishEvent. A null kv_prefix_cache (passed by the scheduler
114+
// for sliding-window-attention models) disables this mid-flight publish, so SWA models
115+
// fall back to the finish-only publish whose prefix reuse is known-correct.
114116
KVPrefixCache* kv = (hybrid_cache != nullptr) ? &hybrid_cache->GetKVPrefixCache() : kv_prefix_cache;
115117
if (kv == nullptr) return;
116118

tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,11 @@ std::optional<fsm::SchedulePrefillEvent> Scheduler::schedulePrefill(
191191
return {};
192192
}
193193

194+
// Sliding-window-attention models pass a null kv_prefix_cache so InsertPrefixCache
195+
// skips the mid-flight publish; their prefix is published only at FinishEvent.
194196
return fsm::SchedulePrefillEvent{tokens_this_round, reserve_num_tokens_in_next_schedule_event,
195-
hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr, &kv_prefix_cache_};
197+
hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr,
198+
config_.has_sliding_window ? nullptr : &kv_prefix_cache_};
196199
}
197200

198201
std::optional<fsm::ScheduleDecodeEvent> Scheduler::scheduleDecode(Request* request,
@@ -217,8 +220,10 @@ std::optional<fsm::ScheduleDecodeEvent> Scheduler::scheduleDecode(Request* reque
217220
return {};
218221
}
219222

223+
// SWA models: skip mid-flight publish (see schedulePrefill) -- publish only at FinishEvent.
220224
return fsm::ScheduleDecodeEvent{config_.decode_input_tokens,
221-
hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr, &kv_prefix_cache_};
225+
hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr,
226+
config_.has_sliding_window ? nullptr : &kv_prefix_cache_};
222227
}
223228

224229
std::optional<fsm::ScheduleDecodeFromRetractedEvent> Scheduler::scheduleDecodeFromRetracted(

tokenspeed-scheduler/csrc/scheduler/types.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,10 @@ struct SchedulerConfig {
101101
Role role{Role::kFused};
102102

103103
bool disable_prefix_cache{false};
104+
// Sliding-window-attention models publish their prefix only at FinishEvent: the
105+
// mid-flight publish (prefill->decode) enables an SWA prefix-reuse path that
106+
// corrupts outputs, so the scheduler skips it for these models.
107+
bool has_sliding_window{false};
104108
bool enable_mamba{false};
105109
std::int32_t mamba_cache_chunk_size{64};
106110
std::int32_t mamba_pool_total_chunks{0};

0 commit comments

Comments
 (0)