From d5d3fc1579a7b11606c13aa87e22798024b4a57d Mon Sep 17 00:00:00 2001 From: Qingyang Wu Date: Mon, 8 Jun 2026 16:59:41 +0000 Subject: [PATCH 1/7] fix(scheduler): mid-flight prefix publish for non-hybrid full-attention models Publish a request's prompt-prefix KV into the radix tree at the prefill->decode transition (not only at FinishEvent) for non-hybrid models, so concurrent same-prefix requests (RL rollouts, shared chat-template prefixes) reuse it. Rename InsertHybridCache->InsertPrefixCache; publish via the base KV prefix cache when there is no hybrid cache (hybrid path unchanged; Mamba checkpoint stays hybrid-only). Thread kv_prefix_cache_ into SchedulePrefillEvent/ScheduleDecodeEvent. Sliding-window-attention models (gpt-oss) are excluded: mid-flight SWA prefix reuse corrupts outputs (regressed gpt-oss GPQA 0.71->0.547). Add has_sliding_window to SchedulerConfig (from hf_config.sliding_window); SWA models pass a null kv_prefix_cache so the publish is skipped and they fall back to finish-only. deepseek-v4 (hybrid SWA) is unaffected. Signed-off-by: Qingyang Wu --- .../tokenspeed/runtime/engine/event_loop.py | 5 ++++ .../runtime/engine/scheduler_utils.py | 2 ++ .../bindings/python_module.cpp | 1 + .../csrc/fsm/forward_events.cpp | 21 +++++++++----- .../csrc/fsm/forward_events.h | 29 ++++++++++++++----- .../csrc/scheduler/operations/forward.cpp | 9 ++++-- tokenspeed-scheduler/csrc/scheduler/types.h | 4 +++ 7 files changed, 55 insertions(+), 16 deletions(-) diff --git a/python/tokenspeed/runtime/engine/event_loop.py b/python/tokenspeed/runtime/engine/event_loop.py index c3068c8b6..e656f86d4 100644 --- a/python/tokenspeed/runtime/engine/event_loop.py +++ b/python/tokenspeed/runtime/engine/event_loop.py @@ -179,6 +179,10 @@ def __init__( has_mamba = getattr(self.model_config, "mambaish_config", None) is not None or ( text_config is not None and hasattr(text_config, "mamba2_cache_params") ) + # Sliding-window-attention models must not publish their prefix mid-flight (they + # publish only at FinishEvent); the SWA prefix-reuse path corrupts outputs + # otherwise. Mirror ModelRunner's SWA detection (hf_config.sliding_window). + has_sliding_window = getattr(hf_config, "sliding_window", None) is not None model_executor_config = ModelExecutorConfig.from_server_args( server_args=server_args, @@ -327,6 +331,7 @@ def __init__( else 1 ), disable_prefix_cache=not server_args.enable_prefix_caching, + has_sliding_window=has_sliding_window, enable_mamba=has_mamba, mamba_cache_chunk_size=server_args.mamba_cache_chunk_size, mamba_pool_total_chunks=mamba_pool_total_chunks, diff --git a/python/tokenspeed/runtime/engine/scheduler_utils.py b/python/tokenspeed/runtime/engine/scheduler_utils.py index f80566871..949eefd0d 100644 --- a/python/tokenspeed/runtime/engine/scheduler_utils.py +++ b/python/tokenspeed/runtime/engine/scheduler_utils.py @@ -64,6 +64,7 @@ def make_config( enable_kv_cache_events: bool = False, decode_input_tokens: int = 1, disable_prefix_cache: bool = False, + has_sliding_window: bool = False, enable_mamba: bool = False, mamba_cache_chunk_size: int = 64, mamba_pool_total_chunks: int = 0, @@ -93,6 +94,7 @@ def make_config( cfg.num_device_pages = num_device_pages cfg.decode_input_tokens = decode_input_tokens cfg.disable_prefix_cache = disable_prefix_cache + cfg.has_sliding_window = has_sliding_window cfg.disable_l2_cache = disable_l2_cache cfg.enable_mamba = enable_mamba diff --git a/tokenspeed-scheduler/bindings/python_module.cpp b/tokenspeed-scheduler/bindings/python_module.cpp index eaa825b29..0460ddee5 100644 --- a/tokenspeed-scheduler/bindings/python_module.cpp +++ b/tokenspeed-scheduler/bindings/python_module.cpp @@ -237,6 +237,7 @@ NB_MODULE(tokenspeed_scheduler_ext, m) { .def_rw("enable_kv_cache_events", &tokenspeed::SchedulerConfig::enable_kv_cache_events) .def_rw("enable_mixed_prefill_decode", &tokenspeed::SchedulerConfig::enable_mixed_prefill_decode) .def_rw("disable_prefix_cache", &tokenspeed::SchedulerConfig::disable_prefix_cache) + .def_rw("has_sliding_window", &tokenspeed::SchedulerConfig::has_sliding_window) .def_rw("enable_mamba", &tokenspeed::SchedulerConfig::enable_mamba) .def_rw("mamba_cache_chunk_size", &tokenspeed::SchedulerConfig::mamba_cache_chunk_size) .def_rw("mamba_pool_total_chunks", &tokenspeed::SchedulerConfig::mamba_pool_total_chunks) diff --git a/tokenspeed-scheduler/csrc/fsm/forward_events.cpp b/tokenspeed-scheduler/csrc/fsm/forward_events.cpp index 49b5c45a7..40b7e2964 100644 --- a/tokenspeed-scheduler/csrc/fsm/forward_events.cpp +++ b/tokenspeed-scheduler/csrc/fsm/forward_events.cpp @@ -102,12 +102,19 @@ bool ShouldPublishMambaCheckpoint(tokenspeed::HybridPrefixCache* hybrid_cache, s namespace tokenspeed::fsm { -void InsertHybridCache(HybridPrefixCache* hybrid_cache, +void InsertPrefixCache(KVPrefixCache* kv_prefix_cache, HybridPrefixCache* hybrid_cache, const std::vector>& full_paged_tokens, std::unique_ptr& device_node_ref, LocalKVAllocator* local_kv_allocator, LocalMambaAllocator* local_mamba_allocator, std::int32_t chunk_begin, std::int32_t chunk_size, std::int32_t page_size, const std::vector* prefix_pages_override) { - if (hybrid_cache == nullptr) return; + // Hybrid models publish through the hybrid cache's wrapped KV cache (and additionally + // track a Mamba checkpoint); plain (non-hybrid) models publish through the base KV + // prefix cache, making the freshly-computed prefix matchable by concurrent requests + // now rather than only at FinishEvent. A null kv_prefix_cache (passed by the scheduler + // for sliding-window-attention models) disables this mid-flight publish, so SWA models + // fall back to the finish-only publish whose prefix reuse is known-correct. + KVPrefixCache* kv = (hybrid_cache != nullptr) ? &hybrid_cache->GetKVPrefixCache() : kv_prefix_cache; + if (kv == nullptr) return; std::vector computed_prefix_pages; const std::vector* prefix_pages = prefix_pages_override; @@ -125,10 +132,10 @@ void InsertHybridCache(HybridPrefixCache* hybrid_cache, } OwnedPages pages_to_insert = local_kv_allocator->TakeFirst(new_page_count); - auto insert_result = hybrid_cache->GetKVPrefixCache().Insert(full_paged_tokens, *prefix_pages, - std::move(pages_to_insert)); + auto insert_result = kv->Insert(full_paged_tokens, *prefix_pages, std::move(pages_to_insert)); - if (local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) { + // Mamba checkpoint publication is hybrid-only. + if (hybrid_cache != nullptr && local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) { if (ShouldPublishMambaCheckpoint(hybrid_cache, chunk_begin, chunk_size, page_size)) { hybrid_cache->InsertMamba(insert_result.last_node, local_mamba_allocator->DetachCheckpoint()); } else { @@ -218,7 +225,7 @@ std::variant SchedulePrefillEvent::operator()(Prefillin if (end_of_window_pages < static_cast(paged_tokens.size())) { paged_tokens.resize(end_of_window_pages); } - InsertHybridCache(hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(), + InsertPrefixCache(kv_prefix_cache_, hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(), local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize()); // Allocate KV pages for the new chunk local_kv_allocator->Acquire(tokens_this_round_); @@ -268,7 +275,7 @@ Decoding ScheduleDecodeEvent::operator()(PrefillDone&& state) { if (end_of_window_pages < static_cast(paged_tokens.size())) { paged_tokens.resize(end_of_window_pages); } - InsertHybridCache(hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(), + InsertPrefixCache(kv_prefix_cache_, hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(), local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize()); // Allocate fresh checkpoint for decode-phase mamba state tracking if (hybrid_prefix_cache_ != nullptr && local_mamba_allocator != nullptr) { diff --git a/tokenspeed-scheduler/csrc/fsm/forward_events.h b/tokenspeed-scheduler/csrc/fsm/forward_events.h index 095ff7de4..6308150c7 100644 --- a/tokenspeed-scheduler/csrc/fsm/forward_events.h +++ b/tokenspeed-scheduler/csrc/fsm/forward_events.h @@ -53,7 +53,14 @@ namespace tokenspeed::fsm { struct PrefetchDone; struct Prefetching; -void InsertHybridCache(HybridPrefixCache* hybrid_prefix_cache, +// Publish a request's freshly-computed prefix into the device radix tree *mid-flight* +// (during prefill / at the prefill->decode transition) so other in-flight requests that +// share the prefix can reuse it -- instead of only after the request finishes +// (FinishEvent). Works for both the plain KV prefix cache (kv_prefix_cache) and, when +// present, the hybrid cache (which additionally publishes the Mamba checkpoint). The +// published node is pinned via the request's device_node_ref so it is not evicted while +// the request is still using it. +void InsertPrefixCache(KVPrefixCache* kv_prefix_cache, HybridPrefixCache* hybrid_prefix_cache, const std::vector>& full_paged_tokens, std::unique_ptr& device_node_ref, LocalKVAllocator* local_kv_allocator, LocalMambaAllocator* local_mamba_allocator, std::int32_t chunk_begin, std::int32_t chunk_size, @@ -107,10 +114,11 @@ struct SchedulePrefillFirstChunkEvent : InvalidTransitionHandler { using InvalidTransitionHandler::operator(); SchedulePrefillEvent(std::int32_t tokens_this_round, std::int32_t reserve_num_tokens_in_next_schedule_event, - HybridPrefixCache* hybrid_prefix_cache = nullptr) + HybridPrefixCache* hybrid_prefix_cache = nullptr, KVPrefixCache* kv_prefix_cache = nullptr) : tokens_this_round_(tokens_this_round), reserve_num_tokens_in_next_schedule_event_(reserve_num_tokens_in_next_schedule_event), - hybrid_prefix_cache_(hybrid_prefix_cache) {} + hybrid_prefix_cache_(hybrid_prefix_cache), + kv_prefix_cache_(kv_prefix_cache) {} // Returns PrefillDone (last chunk) or Prefilling (more chunks remain). std::variant operator()(Prefilling&& state); @@ -119,13 +127,17 @@ struct SchedulePrefillEvent : InvalidTransitionHandler { std::int32_t tokens_this_round_{}; std::int32_t reserve_num_tokens_in_next_schedule_event_{}; HybridPrefixCache* hybrid_prefix_cache_{}; + KVPrefixCache* kv_prefix_cache_{}; }; struct ScheduleDecodeEvent : InvalidTransitionHandler { using InvalidTransitionHandler::operator(); - ScheduleDecodeEvent(std::int32_t decode_input_tokens, HybridPrefixCache* hybrid_prefix_cache = nullptr) - : decode_input_tokens_(decode_input_tokens), hybrid_prefix_cache_(hybrid_prefix_cache) {} + ScheduleDecodeEvent(std::int32_t decode_input_tokens, HybridPrefixCache* hybrid_prefix_cache = nullptr, + KVPrefixCache* kv_prefix_cache = nullptr) + : decode_input_tokens_(decode_input_tokens), + hybrid_prefix_cache_(hybrid_prefix_cache), + kv_prefix_cache_(kv_prefix_cache) {} Decoding operator()(PrefillDone&& state); Decoding operator()(Decoding&& state); @@ -133,6 +145,7 @@ struct ScheduleDecodeEvent : InvalidTransitionHandler { private: std::int32_t decode_input_tokens_; HybridPrefixCache* hybrid_prefix_cache_{}; + KVPrefixCache* kv_prefix_cache_{}; }; struct ScheduleDecodeFromRetractedEvent : InvalidTransitionHandler { @@ -347,8 +360,10 @@ struct ExtendResultEvent : InvalidTransitionHandler { auto host_node_ref = std::move(state).TakeHostNodeRef(); if (new_page_count > 0 && local_kv_allocator->PageCount() >= new_page_count) { - InsertHybridCache(hybrid_prefix_cache_, full_paged_tokens, device_node_ref, local_kv_allocator.get(), - local_mamba_allocator.get(), chunk_begin, + // Hybrid (MTP) path: kv_prefix_cache is unused for hybrid models (they publish + // via hybrid_cache), so pass nullptr. + InsertPrefixCache(/*kv_prefix_cache=*/nullptr, hybrid_prefix_cache_, full_paged_tokens, device_node_ref, + local_kv_allocator.get(), local_mamba_allocator.get(), chunk_begin, static_cast(result_tokens_.size()), page_size, &prefix_pages); hybrid_prefix_cache_->CommitChunk(request_id_, device_node_ref->Node()); } diff --git a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp index fd48962fe..b640ab3e3 100644 --- a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp +++ b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp @@ -191,8 +191,11 @@ std::optional Scheduler::schedulePrefill( return {}; } + // Sliding-window-attention models pass a null kv_prefix_cache so InsertPrefixCache + // skips the mid-flight publish; their prefix is published only at FinishEvent. return fsm::SchedulePrefillEvent{tokens_this_round, reserve_num_tokens_in_next_schedule_event, - hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr}; + hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr, + config_.has_sliding_window ? nullptr : &kv_prefix_cache_}; } std::optional Scheduler::scheduleDecode(Request* request, @@ -217,8 +220,10 @@ std::optional Scheduler::scheduleDecode(Request* reque return {}; } + // SWA models: skip mid-flight publish (see schedulePrefill) -- publish only at FinishEvent. return fsm::ScheduleDecodeEvent{config_.decode_input_tokens, - hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr}; + hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr, + config_.has_sliding_window ? nullptr : &kv_prefix_cache_}; } std::optional Scheduler::scheduleDecodeFromRetracted( diff --git a/tokenspeed-scheduler/csrc/scheduler/types.h b/tokenspeed-scheduler/csrc/scheduler/types.h index 10f723b05..ebfad7c30 100644 --- a/tokenspeed-scheduler/csrc/scheduler/types.h +++ b/tokenspeed-scheduler/csrc/scheduler/types.h @@ -101,6 +101,10 @@ struct SchedulerConfig { Role role{Role::kFused}; bool disable_prefix_cache{false}; + // Sliding-window-attention models publish their prefix only at FinishEvent: the + // mid-flight publish (prefill->decode) enables an SWA prefix-reuse path that + // corrupts outputs, so the scheduler skips it for these models. + bool has_sliding_window{false}; bool enable_mamba{false}; std::int32_t mamba_cache_chunk_size{64}; std::int32_t mamba_pool_total_chunks{0}; From bee09204332321003ba774a2b5b15d216015e76d Mon Sep 17 00:00:00 2001 From: Qingyang Wu Date: Mon, 8 Jun 2026 22:14:29 +0000 Subject: [PATCH 2/7] fix(scheduler): allow safe SWA mid-flight prefix publish Signed-off-by: Qingyang Wu --- .../tokenspeed/runtime/engine/event_loop.py | 18 ++- .../runtime/engine/scheduler_utils.py | 2 + .../bindings/python_module.cpp | 1 + .../csrc/fsm/forward_events.cpp | 50 ++++-- .../csrc/fsm/forward_events.h | 74 ++++++--- .../csrc/scheduler/operations/forward.cpp | 40 ++++- .../csrc/scheduler/outside_event_handler.cpp | 4 +- .../csrc/scheduler/scheduler.h | 2 + tokenspeed-scheduler/csrc/scheduler/types.h | 9 +- .../tests/cpp/test_paged_cache_replay.cpp | 144 ++++++++++++++++++ 10 files changed, 296 insertions(+), 48 deletions(-) diff --git a/python/tokenspeed/runtime/engine/event_loop.py b/python/tokenspeed/runtime/engine/event_loop.py index e656f86d4..246cac259 100644 --- a/python/tokenspeed/runtime/engine/event_loop.py +++ b/python/tokenspeed/runtime/engine/event_loop.py @@ -179,10 +179,19 @@ def __init__( has_mamba = getattr(self.model_config, "mambaish_config", None) is not None or ( text_config is not None and hasattr(text_config, "mamba2_cache_params") ) - # Sliding-window-attention models must not publish their prefix mid-flight (they - # publish only at FinishEvent); the SWA prefix-reuse path corrupts outputs - # otherwise. Mirror ModelRunner's SWA detection (hf_config.sliding_window). - has_sliding_window = getattr(hf_config, "sliding_window", None) is not None + # Mirror ModelRunner's SWA detection (hf_config.sliding_window). Plain + # SWA mid-flight publish is capped to this window; hybrid paged-cache + # models use their windowed adjunct snapshots. gpt-oss stores an + # inclusive HF window and converts it to TokenSpeed's exclusive + # attention window inside the model. + sliding_window = getattr(hf_config, "sliding_window", None) + if ( + getattr(hf_config, "model_type", None) == "gpt_oss" + and sliding_window is not None + ): + sliding_window = max(0, int(sliding_window) - 1) + has_sliding_window = sliding_window is not None + sliding_window_size = int(sliding_window) if sliding_window is not None else 0 model_executor_config = ModelExecutorConfig.from_server_args( server_args=server_args, @@ -332,6 +341,7 @@ def __init__( ), disable_prefix_cache=not server_args.enable_prefix_caching, has_sliding_window=has_sliding_window, + sliding_window_size=sliding_window_size, enable_mamba=has_mamba, mamba_cache_chunk_size=server_args.mamba_cache_chunk_size, mamba_pool_total_chunks=mamba_pool_total_chunks, diff --git a/python/tokenspeed/runtime/engine/scheduler_utils.py b/python/tokenspeed/runtime/engine/scheduler_utils.py index 949eefd0d..f5dadc9d6 100644 --- a/python/tokenspeed/runtime/engine/scheduler_utils.py +++ b/python/tokenspeed/runtime/engine/scheduler_utils.py @@ -65,6 +65,7 @@ def make_config( decode_input_tokens: int = 1, disable_prefix_cache: bool = False, has_sliding_window: bool = False, + sliding_window_size: int = 0, enable_mamba: bool = False, mamba_cache_chunk_size: int = 64, mamba_pool_total_chunks: int = 0, @@ -95,6 +96,7 @@ def make_config( cfg.decode_input_tokens = decode_input_tokens cfg.disable_prefix_cache = disable_prefix_cache cfg.has_sliding_window = has_sliding_window + cfg.sliding_window_size = sliding_window_size cfg.disable_l2_cache = disable_l2_cache cfg.enable_mamba = enable_mamba diff --git a/tokenspeed-scheduler/bindings/python_module.cpp b/tokenspeed-scheduler/bindings/python_module.cpp index 0460ddee5..967506424 100644 --- a/tokenspeed-scheduler/bindings/python_module.cpp +++ b/tokenspeed-scheduler/bindings/python_module.cpp @@ -238,6 +238,7 @@ NB_MODULE(tokenspeed_scheduler_ext, m) { .def_rw("enable_mixed_prefill_decode", &tokenspeed::SchedulerConfig::enable_mixed_prefill_decode) .def_rw("disable_prefix_cache", &tokenspeed::SchedulerConfig::disable_prefix_cache) .def_rw("has_sliding_window", &tokenspeed::SchedulerConfig::has_sliding_window) + .def_rw("sliding_window_size", &tokenspeed::SchedulerConfig::sliding_window_size) .def_rw("enable_mamba", &tokenspeed::SchedulerConfig::enable_mamba) .def_rw("mamba_cache_chunk_size", &tokenspeed::SchedulerConfig::mamba_cache_chunk_size) .def_rw("mamba_pool_total_chunks", &tokenspeed::SchedulerConfig::mamba_pool_total_chunks) diff --git a/tokenspeed-scheduler/csrc/fsm/forward_events.cpp b/tokenspeed-scheduler/csrc/fsm/forward_events.cpp index 40b7e2964..dbab51f74 100644 --- a/tokenspeed-scheduler/csrc/fsm/forward_events.cpp +++ b/tokenspeed-scheduler/csrc/fsm/forward_events.cpp @@ -106,16 +106,43 @@ void InsertPrefixCache(KVPrefixCache* kv_prefix_cache, HybridPrefixCache* hybrid const std::vector>& full_paged_tokens, std::unique_ptr& device_node_ref, LocalKVAllocator* local_kv_allocator, LocalMambaAllocator* local_mamba_allocator, std::int32_t chunk_begin, std::int32_t chunk_size, - std::int32_t page_size, const std::vector* prefix_pages_override) { + std::int32_t page_size, const std::vector* prefix_pages_override, + bool enable_midflight_publish, std::int32_t max_publish_tokens) { // Hybrid models publish through the hybrid cache's wrapped KV cache (and additionally // track a Mamba checkpoint); plain (non-hybrid) models publish through the base KV // prefix cache, making the freshly-computed prefix matchable by concurrent requests - // now rather than only at FinishEvent. A null kv_prefix_cache (passed by the scheduler - // for sliding-window-attention models) disables this mid-flight publish, so SWA models - // fall back to the finish-only publish whose prefix reuse is known-correct. + // now rather than only at FinishEvent. Callers disable this explicitly for + // SWA configs that cannot restore the windowed state required for reuse. + auto detach_checkpoint = [&]() { + if (local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) { + local_mamba_allocator->DetachCheckpoint(); + } + }; + if (!enable_midflight_publish) { + detach_checkpoint(); + return; + } KVPrefixCache* kv = (hybrid_cache != nullptr) ? &hybrid_cache->GetKVPrefixCache() : kv_prefix_cache; if (kv == nullptr) return; + std::vector> capped_paged_tokens; + const std::vector>* publish_paged_tokens = &full_paged_tokens; + if (max_publish_tokens > 0) { + if (page_size <= 0) { + detach_checkpoint(); + return; + } + const std::int32_t max_publish_pages = max_publish_tokens / page_size; + if (max_publish_pages <= 0) { + detach_checkpoint(); + return; + } + const auto publish_pages = + std::min(max_publish_pages, static_cast(full_paged_tokens.size())); + capped_paged_tokens.assign(full_paged_tokens.begin(), full_paged_tokens.begin() + publish_pages); + publish_paged_tokens = &capped_paged_tokens; + } + std::vector computed_prefix_pages; const std::vector* prefix_pages = prefix_pages_override; if (prefix_pages == nullptr) { @@ -123,16 +150,15 @@ void InsertPrefixCache(KVPrefixCache* kv_prefix_cache, HybridPrefixCache* hybrid prefix_pages = &computed_prefix_pages; } std::int32_t new_page_count = - static_cast(full_paged_tokens.size()) - static_cast(prefix_pages->size()); + static_cast(publish_paged_tokens->size()) - static_cast(prefix_pages->size()); if (new_page_count <= 0) { - if (local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) { - local_mamba_allocator->DetachCheckpoint(); - } + detach_checkpoint(); return; } OwnedPages pages_to_insert = local_kv_allocator->TakeFirst(new_page_count); - auto insert_result = kv->Insert(full_paged_tokens, *prefix_pages, std::move(pages_to_insert)); + auto insert_result = + kv->Insert(*publish_paged_tokens, *prefix_pages, std::move(pages_to_insert)); // Mamba checkpoint publication is hybrid-only. if (hybrid_cache != nullptr && local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) { @@ -226,7 +252,8 @@ std::variant SchedulePrefillEvent::operator()(Prefillin paged_tokens.resize(end_of_window_pages); } InsertPrefixCache(kv_prefix_cache_, hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(), - local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize()); + local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize(), + /*prefix_pages_override=*/nullptr, enable_midflight_publish_, max_midflight_publish_tokens_); // Allocate KV pages for the new chunk local_kv_allocator->Acquire(tokens_this_round_); @@ -276,7 +303,8 @@ Decoding ScheduleDecodeEvent::operator()(PrefillDone&& state) { paged_tokens.resize(end_of_window_pages); } InsertPrefixCache(kv_prefix_cache_, hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(), - local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize()); + local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize(), + /*prefix_pages_override=*/nullptr, enable_midflight_publish_, max_midflight_publish_tokens_); // Allocate fresh checkpoint for decode-phase mamba state tracking if (hybrid_prefix_cache_ != nullptr && local_mamba_allocator != nullptr) { if (!local_mamba_allocator->AllocateCheckpoint()) { diff --git a/tokenspeed-scheduler/csrc/fsm/forward_events.h b/tokenspeed-scheduler/csrc/fsm/forward_events.h index 6308150c7..50a9f1f72 100644 --- a/tokenspeed-scheduler/csrc/fsm/forward_events.h +++ b/tokenspeed-scheduler/csrc/fsm/forward_events.h @@ -54,17 +54,18 @@ struct PrefetchDone; struct Prefetching; // Publish a request's freshly-computed prefix into the device radix tree *mid-flight* -// (during prefill / at the prefill->decode transition) so other in-flight requests that -// share the prefix can reuse it -- instead of only after the request finishes -// (FinishEvent). Works for both the plain KV prefix cache (kv_prefix_cache) and, when -// present, the hybrid cache (which additionally publishes the Mamba checkpoint). The -// published node is pinned via the request's device_node_ref so it is not evicted while -// the request is still using it. +// (during prefill, at the prefill->decode transition, or after accepted decode +// tokens) so other in-flight requests that share the prefix can reuse it -- instead +// of only after the request finishes (FinishEvent). Works for both the plain KV prefix +// cache (kv_prefix_cache) and, when present, the hybrid cache (which additionally +// publishes the Mamba checkpoint). The published node is pinned via the request's +// device_node_ref so it is not evicted while the request is still using it. void InsertPrefixCache(KVPrefixCache* kv_prefix_cache, HybridPrefixCache* hybrid_prefix_cache, const std::vector>& full_paged_tokens, std::unique_ptr& device_node_ref, LocalKVAllocator* local_kv_allocator, LocalMambaAllocator* local_mamba_allocator, std::int32_t chunk_begin, std::int32_t chunk_size, - std::int32_t page_size, const std::vector* prefix_pages_override = nullptr); + std::int32_t page_size, const std::vector* prefix_pages_override = nullptr, + bool enable_midflight_publish = true, std::int32_t max_publish_tokens = 0); struct SchedulePrefillFirstChunkEvent : InvalidTransitionHandler { using InvalidTransitionHandler::operator(); @@ -114,11 +115,14 @@ struct SchedulePrefillFirstChunkEvent : InvalidTransitionHandler { using InvalidTransitionHandler::operator(); SchedulePrefillEvent(std::int32_t tokens_this_round, std::int32_t reserve_num_tokens_in_next_schedule_event, - HybridPrefixCache* hybrid_prefix_cache = nullptr, KVPrefixCache* kv_prefix_cache = nullptr) + HybridPrefixCache* hybrid_prefix_cache = nullptr, KVPrefixCache* kv_prefix_cache = nullptr, + bool enable_midflight_publish = true, std::int32_t max_midflight_publish_tokens = 0) : tokens_this_round_(tokens_this_round), reserve_num_tokens_in_next_schedule_event_(reserve_num_tokens_in_next_schedule_event), hybrid_prefix_cache_(hybrid_prefix_cache), - kv_prefix_cache_(kv_prefix_cache) {} + kv_prefix_cache_(kv_prefix_cache), + enable_midflight_publish_(enable_midflight_publish), + max_midflight_publish_tokens_(max_midflight_publish_tokens) {} // Returns PrefillDone (last chunk) or Prefilling (more chunks remain). std::variant operator()(Prefilling&& state); @@ -128,16 +132,21 @@ struct SchedulePrefillEvent : InvalidTransitionHandler { std::int32_t reserve_num_tokens_in_next_schedule_event_{}; HybridPrefixCache* hybrid_prefix_cache_{}; KVPrefixCache* kv_prefix_cache_{}; + bool enable_midflight_publish_{true}; + std::int32_t max_midflight_publish_tokens_{}; }; struct ScheduleDecodeEvent : InvalidTransitionHandler { using InvalidTransitionHandler::operator(); ScheduleDecodeEvent(std::int32_t decode_input_tokens, HybridPrefixCache* hybrid_prefix_cache = nullptr, - KVPrefixCache* kv_prefix_cache = nullptr) + KVPrefixCache* kv_prefix_cache = nullptr, bool enable_midflight_publish = true, + std::int32_t max_midflight_publish_tokens = 0) : decode_input_tokens_(decode_input_tokens), hybrid_prefix_cache_(hybrid_prefix_cache), - kv_prefix_cache_(kv_prefix_cache) {} + kv_prefix_cache_(kv_prefix_cache), + enable_midflight_publish_(enable_midflight_publish), + max_midflight_publish_tokens_(max_midflight_publish_tokens) {} Decoding operator()(PrefillDone&& state); Decoding operator()(Decoding&& state); @@ -146,6 +155,8 @@ struct ScheduleDecodeEvent : InvalidTransitionHandler { std::int32_t decode_input_tokens_; HybridPrefixCache* hybrid_prefix_cache_{}; KVPrefixCache* kv_prefix_cache_{}; + bool enable_midflight_publish_{true}; + std::int32_t max_midflight_publish_tokens_{}; }; struct ScheduleDecodeFromRetractedEvent : InvalidTransitionHandler { @@ -311,10 +322,14 @@ struct ExtendResultEvent : InvalidTransitionHandler { ExtendResultEvent() = delete; ExtendResultEvent(std::string request_id, std::vector result_tokens, - HybridPrefixCache* hybrid_prefix_cache = nullptr) + KVPrefixCache* kv_prefix_cache = nullptr, HybridPrefixCache* hybrid_prefix_cache = nullptr, + bool enable_midflight_publish = true, std::int32_t max_midflight_publish_tokens = 0) : request_id_(std::move(request_id)), result_tokens_(std::move(result_tokens)), - hybrid_prefix_cache_(hybrid_prefix_cache) {} + kv_prefix_cache_(kv_prefix_cache), + hybrid_prefix_cache_(hybrid_prefix_cache), + enable_midflight_publish_(enable_midflight_publish), + max_midflight_publish_tokens_(max_midflight_publish_tokens) {} public: template @@ -331,7 +346,7 @@ struct ExtendResultEvent : InvalidTransitionHandler { const std::int32_t page_size = state.GetPageSize(); const std::int32_t reserve = state.GetReserveNumTokensInNextScheduleEvent(); - if (hybrid_prefix_cache_ == nullptr) { + if (kv_prefix_cache_ == nullptr && hybrid_prefix_cache_ == nullptr) { return std::move(state); } @@ -344,30 +359,38 @@ struct ExtendResultEvent : InvalidTransitionHandler { const std::int32_t new_publishable_pages = publishable_pages(accepted_token_size); if (new_publishable_pages <= old_publishable_pages) { - hybrid_prefix_cache_->RewindRequest(request_id_, accepted_token_size); + if (hybrid_prefix_cache_ != nullptr) { + hybrid_prefix_cache_->RewindRequest(request_id_, accepted_token_size); + } return std::move(state); } const std::int32_t chunk_begin = accepted_token_size - static_cast(result_tokens_.size()); auto full_paged_tokens = state.GetFullPagedTokens(/*except_last=*/true); std::vector prefix_pages = DevicePagesFromRoot(state.GetDeviceNode()); - const std::int32_t new_page_count = - static_cast(full_paged_tokens.size()) - static_cast(prefix_pages.size()); + std::int32_t publish_page_count = static_cast(full_paged_tokens.size()); + if (max_midflight_publish_tokens_ > 0 && page_size > 0) { + publish_page_count = std::min(publish_page_count, max_midflight_publish_tokens_ / page_size); + } + const std::int32_t new_page_count = publish_page_count - static_cast(prefix_pages.size()); auto local_kv_allocator = std::move(state).TakeLocalKVAllocator(); auto local_mamba_allocator = std::move(state).TakeLocalMambaAllocator(); auto device_node_ref = std::move(state).TakeDeviceNodeRef(); auto host_node_ref = std::move(state).TakeHostNodeRef(); - if (new_page_count > 0 && local_kv_allocator->PageCount() >= new_page_count) { - // Hybrid (MTP) path: kv_prefix_cache is unused for hybrid models (they publish - // via hybrid_cache), so pass nullptr. - InsertPrefixCache(/*kv_prefix_cache=*/nullptr, hybrid_prefix_cache_, full_paged_tokens, device_node_ref, + if (enable_midflight_publish_ && new_page_count > 0 && local_kv_allocator->PageCount() >= new_page_count) { + InsertPrefixCache(kv_prefix_cache_, hybrid_prefix_cache_, full_paged_tokens, device_node_ref, local_kv_allocator.get(), local_mamba_allocator.get(), chunk_begin, - static_cast(result_tokens_.size()), page_size, &prefix_pages); - hybrid_prefix_cache_->CommitChunk(request_id_, device_node_ref->Node()); + static_cast(result_tokens_.size()), page_size, &prefix_pages, + enable_midflight_publish_, max_midflight_publish_tokens_); + if (hybrid_prefix_cache_ != nullptr) { + hybrid_prefix_cache_->CommitChunk(request_id_, device_node_ref->Node()); + } + } + if (hybrid_prefix_cache_ != nullptr) { + hybrid_prefix_cache_->RewindRequest(request_id_, accepted_token_size); } - hybrid_prefix_cache_->RewindRequest(request_id_, accepted_token_size); return Decoding{token_container, page_size, @@ -389,7 +412,10 @@ struct ExtendResultEvent : InvalidTransitionHandler { private: std::string request_id_; std::vector result_tokens_; + KVPrefixCache* kv_prefix_cache_{}; HybridPrefixCache* hybrid_prefix_cache_{}; + bool enable_midflight_publish_{true}; + std::int32_t max_midflight_publish_tokens_{}; }; } // namespace tokenspeed::fsm diff --git a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp index b640ab3e3..41a973a7f 100644 --- a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp +++ b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp @@ -191,11 +191,12 @@ std::optional Scheduler::schedulePrefill( return {}; } - // Sliding-window-attention models pass a null kv_prefix_cache so InsertPrefixCache - // skips the mid-flight publish; their prefix is published only at FinishEvent. + // Plain SWA models publish only up to the attention window; hybrid + // paged-cache SWA models publish when their adjunct carries window state. return fsm::SchedulePrefillEvent{tokens_this_round, reserve_num_tokens_in_next_schedule_event, hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr, - config_.has_sliding_window ? nullptr : &kv_prefix_cache_}; + &kv_prefix_cache_, enableMidflightPrefixPublish(), + maxMidflightPrefixPublishTokens()}; } std::optional Scheduler::scheduleDecode(Request* request, @@ -220,10 +221,39 @@ std::optional Scheduler::scheduleDecode(Request* reque return {}; } - // SWA models: skip mid-flight publish (see schedulePrefill) -- publish only at FinishEvent. return fsm::ScheduleDecodeEvent{config_.decode_input_tokens, hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr, - config_.has_sliding_window ? nullptr : &kv_prefix_cache_}; + &kv_prefix_cache_, enableMidflightPrefixPublish(), + maxMidflightPrefixPublishTokens()}; +} + +bool Scheduler::enableMidflightPrefixPublish() const { + if (!config_.has_sliding_window) { + return true; + } + if (!hybrid_prefix_cache_ && config_.sliding_window_size > 0) { + return true; + } + if (!hybrid_prefix_cache_ || !hybrid_prefix_cache_->HasPagedCacheAdjunct()) { + return false; + } + for (const auto& group : config_.paged_cache_groups) { + if (group.family == PagedCacheGroupFamily::State && + group.retention == PagedCacheGroupConfig::Retention::SlidingWindow) { + return true; + } + } + return false; +} + +std::int32_t Scheduler::maxMidflightPrefixPublishTokens() const { + if (!config_.has_sliding_window) { + return 0; + } + if (hybrid_prefix_cache_ && hybrid_prefix_cache_->HasPagedCacheAdjunct()) { + return 0; + } + return std::max(0, config_.sliding_window_size); } std::optional Scheduler::scheduleDecodeFromRetracted( diff --git a/tokenspeed-scheduler/csrc/scheduler/outside_event_handler.cpp b/tokenspeed-scheduler/csrc/scheduler/outside_event_handler.cpp index 02e520c45..9b2ab7a91 100644 --- a/tokenspeed-scheduler/csrc/scheduler/outside_event_handler.cpp +++ b/tokenspeed-scheduler/csrc/scheduler/outside_event_handler.cpp @@ -127,7 +127,9 @@ void Scheduler::handleEvent(const forward::UpdateReserveNumTokens& event) { void Scheduler::handleEvent(const forward::ExtendResult& event) { if (auto req = find_request(event.request_id)) { req->Apply(fsm::ExtendResultEvent{event.request_id, event.tokens, - hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr}); + &kv_prefix_cache_, + hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr, + enableMidflightPrefixPublish(), maxMidflightPrefixPublishTokens()}); } } diff --git a/tokenspeed-scheduler/csrc/scheduler/scheduler.h b/tokenspeed-scheduler/csrc/scheduler/scheduler.h index c36c3a413..492eebad9 100644 --- a/tokenspeed-scheduler/csrc/scheduler/scheduler.h +++ b/tokenspeed-scheduler/csrc/scheduler/scheduler.h @@ -103,6 +103,8 @@ class Scheduler { std::map& simulated_free); std::optional scheduleDecode(Request* request, std::map& simulated_free); + bool enableMidflightPrefixPublish() const; + std::int32_t maxMidflightPrefixPublishTokens() const; std::optional scheduleDecodeFromRetracted( Request* request, std::map& simulated_free); std::optional scheduleRetract(Request* request); diff --git a/tokenspeed-scheduler/csrc/scheduler/types.h b/tokenspeed-scheduler/csrc/scheduler/types.h index ebfad7c30..fd8a342ab 100644 --- a/tokenspeed-scheduler/csrc/scheduler/types.h +++ b/tokenspeed-scheduler/csrc/scheduler/types.h @@ -101,10 +101,13 @@ struct SchedulerConfig { Role role{Role::kFused}; bool disable_prefix_cache{false}; - // Sliding-window-attention models publish their prefix only at FinishEvent: the - // mid-flight publish (prefill->decode) enables an SWA prefix-reuse path that - // corrupts outputs, so the scheduler skips it for these models. + // Plain sliding-window-attention models publish their prefix only at FinishEvent. + // Hybrid paged-cache SWA models may publish mid-flight when their adjunct can + // restore the windowed state needed for prefix reuse. bool has_sliding_window{false}; + // Token length of the SWA window. Plain SWA mid-flight publish is capped to + // this many tokens so reuse stays in the full-history-equivalent region. + std::int32_t sliding_window_size{0}; bool enable_mamba{false}; std::int32_t mamba_cache_chunk_size{64}; std::int32_t mamba_pool_total_chunks{0}; diff --git a/tokenspeed-scheduler/tests/cpp/test_paged_cache_replay.cpp b/tokenspeed-scheduler/tests/cpp/test_paged_cache_replay.cpp index ede2e3f35..2892d9ae4 100644 --- a/tokenspeed-scheduler/tests/cpp/test_paged_cache_replay.cpp +++ b/tokenspeed-scheduler/tests/cpp/test_paged_cache_replay.cpp @@ -124,6 +124,60 @@ class PagedCacheDecodePublishTest : public SchedulerTestSuite { } }; +class PagedCacheDecodePublishPlainSwaTest : public SchedulerTestSuite { +protected: + SchedulerConfig MakeConfig() override { + auto cfg = SchedulerTestSuite::MakeConfig(); + cfg.page_size = 1; + cfg.device_allocator.total_pages = 64; + cfg.host_allocator.total_pages = 64; + cfg.max_scheduled_tokens = 64; + cfg.max_batch_size = 8; + cfg.decode_input_tokens = 4; + cfg.enable_l3_storage = false; + cfg.has_sliding_window = true; + cfg.sliding_window_size = 4; + return cfg; + } + + static const FlatForwardOperation* GetForwardOp(const ExecutionPlan& plan) { + for (const auto& op : plan.Operations()) { + if (auto* f = std::get_if(&op)) return f; + } + return nullptr; + } +}; + +class PagedCacheDecodePublishHybridHistorySwaTest : public PagedCacheDecodePublishTest { +protected: + SchedulerConfig MakeConfig() override { + auto cfg = PagedCacheDecodePublishTest::MakeConfig(); + cfg.has_sliding_window = true; + cfg.sliding_window_size = 4; + return cfg; + } +}; + +class PagedCacheDecodePublishHybridSwaTest : public PagedCacheDecodePublishTest { +protected: + SchedulerConfig MakeConfig() override { + auto cfg = PagedCacheDecodePublishTest::MakeConfig(); + cfg.has_sliding_window = true; + cfg.sliding_window_size = 4; + + PagedCacheGroupConfig state{}; + state.group_id = "swa_state"; + state.rows_per_page = 1; + state.entry_stride_tokens = 1; + state.total_pages = 64; + state.retention = PagedCacheGroupConfig::Retention::SlidingWindow; + state.sliding_window_tokens = 4; + state.family = PagedCacheGroupFamily::State; + cfg.paged_cache_groups.push_back(state); + return cfg; + } +}; + class PagedCacheTerminalContinuationTest : public ::testing::Test { protected: static constexpr std::int32_t kPageSize = 64; @@ -446,6 +500,96 @@ TEST_F(PagedCacheDecodePublishTest, ContinuingDecodePublishesAcceptedPagesOnly) EXPECT_EQ(prefix_by_request.at("probe_tail"), 4); } +TEST_F(PagedCacheDecodePublishPlainSwaTest, MidflightPublishIsCappedToSlidingWindow) { + Submit(RequestSpec{.request_id = "r1", .tokens = {1, 2, 3, 4, 5, 6, 7, 8}}); + ASSERT_NE(GetForwardOp(PlanOnce()), nullptr); + ASSERT_NE(GetForwardOp(PlanOnce()), nullptr); + + Submit(RequestSpec{.request_id = "hit4", .tokens = {1, 2, 3, 4, 5, 6, 7, 8}}); + auto plan = PlanOnce(); + auto* fwd = GetForwardOp(plan); + ASSERT_NE(fwd, nullptr); + ASSERT_GE(fwd->extend_prefix_lens.size(), 1u); + + std::unordered_map prefix_by_request; + for (std::size_t row = 0; row < fwd->extend_prefix_lens.size(); ++row) { + ASSERT_LT(row, fwd->request_ids.size()); + prefix_by_request.emplace(fwd->request_ids[row], fwd->extend_prefix_lens[row]); + } + + ASSERT_TRUE(prefix_by_request.contains("hit4")); + EXPECT_EQ(prefix_by_request.at("hit4"), 4); +} + +TEST_F(PagedCacheDecodePublishPlainSwaTest, DecodeResultPublishesWindowCappedPrefix) { + Submit(RequestSpec{.request_id = "r1", .tokens = {1, 2}}); + ASSERT_NE(GetForwardOp(PlanOnce()), nullptr); + ASSERT_NE(GetForwardOp(PlanOnce()), nullptr); + + SendForwardDone("r1", {3, 4, 5, 6}); + + Submit(RequestSpec{.request_id = "hit4", .tokens = {1, 2, 3, 4, 5, 6}}); + auto plan = PlanOnce(); + auto* fwd = GetForwardOp(plan); + ASSERT_NE(fwd, nullptr); + ASSERT_GE(fwd->extend_prefix_lens.size(), 1u); + + std::unordered_map prefix_by_request; + for (std::size_t row = 0; row < fwd->extend_prefix_lens.size(); ++row) { + ASSERT_LT(row, fwd->request_ids.size()); + prefix_by_request.emplace(fwd->request_ids[row], fwd->extend_prefix_lens[row]); + } + + ASSERT_TRUE(prefix_by_request.contains("hit4")); + EXPECT_EQ(prefix_by_request.at("hit4"), 4); +} + +TEST_F(PagedCacheDecodePublishHybridHistorySwaTest, MidflightPublishRequiresWindowState) { + Submit(RequestSpec{.request_id = "r1", .tokens = {1, 2, 3, 4, 5, 6, 7, 8}}); + ASSERT_NE(GetForwardOp(PlanOnce()), nullptr); + ASSERT_NE(GetForwardOp(PlanOnce()), nullptr); + + Submit(RequestSpec{.request_id = "hit4", .tokens = {1, 2, 3, 4, 5, 6, 7, 8}}); + auto plan = PlanOnce(); + auto* fwd = GetForwardOp(plan); + ASSERT_NE(fwd, nullptr); + ASSERT_GE(fwd->extend_prefix_lens.size(), 1u); + + std::unordered_map prefix_by_request; + for (std::size_t row = 0; row < fwd->extend_prefix_lens.size(); ++row) { + ASSERT_LT(row, fwd->request_ids.size()); + prefix_by_request.emplace(fwd->request_ids[row], fwd->extend_prefix_lens[row]); + } + + ASSERT_TRUE(prefix_by_request.contains("hit4")); + EXPECT_EQ(prefix_by_request.at("hit4"), 0); +} + +TEST_F(PagedCacheDecodePublishHybridSwaTest, ContinuingDecodePublishesWindowedHybridPrefixes) { + Submit(RequestSpec{.request_id = "r1", .tokens = {1, 2}}); + ASSERT_NE(GetForwardOp(PlanOnce()), nullptr); + + SendForwardDone("r1", {3}); + ASSERT_NE(GetForwardOp(PlanOnce()), nullptr); + + SendForwardDone("r1", {4, 5}); + + Submit(RequestSpec{.request_id = "hit4", .tokens = {1, 2, 3, 4, 5}}); + auto plan = PlanOnce(); + auto* fwd = GetForwardOp(plan); + ASSERT_NE(fwd, nullptr); + ASSERT_GE(fwd->extend_prefix_lens.size(), 1u); + + std::unordered_map prefix_by_request; + for (std::size_t row = 0; row < fwd->extend_prefix_lens.size(); ++row) { + ASSERT_LT(row, fwd->request_ids.size()); + prefix_by_request.emplace(fwd->request_ids[row], fwd->extend_prefix_lens[row]); + } + + ASSERT_TRUE(prefix_by_request.contains("hit4")); + EXPECT_EQ(prefix_by_request.at("hit4"), 4); +} + TEST_F(PagedCacheTerminalMixedSchedulerTest, MixedPrefillDecodePagedTablesCoverScheduledTokens) { std::vector decode_ids; for (int i = 0; i < 5; ++i) { From 4deb07e4da0961cba6a95079725bbc4dae5f5480 Mon Sep 17 00:00:00 2001 From: Qingyang Wu Date: Mon, 8 Jun 2026 22:44:47 +0000 Subject: [PATCH 3/7] style(scheduler): format SWA publish changes Signed-off-by: Qingyang Wu --- .../csrc/scheduler/operations/forward.cpp | 11 ++++++----- .../csrc/scheduler/outside_event_handler.cpp | 3 +-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp index 41a973a7f..aeed461bb 100644 --- a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp +++ b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp @@ -193,9 +193,11 @@ std::optional Scheduler::schedulePrefill( // Plain SWA models publish only up to the attention window; hybrid // paged-cache SWA models publish when their adjunct carries window state. - return fsm::SchedulePrefillEvent{tokens_this_round, reserve_num_tokens_in_next_schedule_event, + return fsm::SchedulePrefillEvent{tokens_this_round, + reserve_num_tokens_in_next_schedule_event, hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr, - &kv_prefix_cache_, enableMidflightPrefixPublish(), + &kv_prefix_cache_, + enableMidflightPrefixPublish(), maxMidflightPrefixPublishTokens()}; } @@ -222,9 +224,8 @@ std::optional Scheduler::scheduleDecode(Request* reque } return fsm::ScheduleDecodeEvent{config_.decode_input_tokens, - hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr, - &kv_prefix_cache_, enableMidflightPrefixPublish(), - maxMidflightPrefixPublishTokens()}; + hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr, &kv_prefix_cache_, + enableMidflightPrefixPublish(), maxMidflightPrefixPublishTokens()}; } bool Scheduler::enableMidflightPrefixPublish() const { diff --git a/tokenspeed-scheduler/csrc/scheduler/outside_event_handler.cpp b/tokenspeed-scheduler/csrc/scheduler/outside_event_handler.cpp index 9b2ab7a91..408561e73 100644 --- a/tokenspeed-scheduler/csrc/scheduler/outside_event_handler.cpp +++ b/tokenspeed-scheduler/csrc/scheduler/outside_event_handler.cpp @@ -126,8 +126,7 @@ void Scheduler::handleEvent(const forward::UpdateReserveNumTokens& event) { } void Scheduler::handleEvent(const forward::ExtendResult& event) { if (auto req = find_request(event.request_id)) { - req->Apply(fsm::ExtendResultEvent{event.request_id, event.tokens, - &kv_prefix_cache_, + req->Apply(fsm::ExtendResultEvent{event.request_id, event.tokens, &kv_prefix_cache_, hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr, enableMidflightPrefixPublish(), maxMidflightPrefixPublishTokens()}); } From 0bc9c7b6de2ce6c25ad5696c5346b7bd30e76e1c Mon Sep 17 00:00:00 2001 From: Qingyang Wu Date: Mon, 8 Jun 2026 23:26:58 +0000 Subject: [PATCH 4/7] fix(runtime): avoid ragged MHA prefix split with sinks Signed-off-by: Qingyang Wu --- .../runtime/layers/attention/backends/mha.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/tokenspeed/runtime/layers/attention/backends/mha.py b/python/tokenspeed/runtime/layers/attention/backends/mha.py index 778951072..767665649 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/mha.py +++ b/python/tokenspeed/runtime/layers/attention/backends/mha.py @@ -345,7 +345,13 @@ def forward_extend( metadata = self.forward_prefill_metadata if metadata.max_extend_prefix_len > 0: - if self.mha_extend_mode == "ragged": + sinks = kwargs.get("sinks") + # gpt-oss uses attention sinks. The ragged split path computes the + # cached-prefix and new-token attention states separately before + # merging; keep sink-attention requests on the unified KV-cache path + # so prefix-cache hits preserve the same numerical path as ordinary + # extend. + if self.mha_extend_mode == "ragged" and sinks is None: return self._forward_extend_split( q, k, @@ -355,7 +361,7 @@ def forward_extend( token_to_kv_pool, metadata, save_kv_cache, - kwargs.get("sinks"), + sinks, ) else: return self._forward_extend( @@ -367,7 +373,7 @@ def forward_extend( token_to_kv_pool, metadata, save_kv_cache, - kwargs.get("sinks"), + sinks, ) return self._forward_prefill( q, From c08c1a8c7720567d2e6223218008567ca54c8c9f Mon Sep 17 00:00:00 2001 From: Qingyang Wu Date: Mon, 8 Jun 2026 23:32:47 +0000 Subject: [PATCH 5/7] fix(runtime): preserve midflight prefixes for mixed SWA models Signed-off-by: Qingyang Wu --- python/tokenspeed/runtime/engine/event_loop.py | 17 ++++------------- .../runtime/engine/scheduler_utils.py | 13 +++++++++++++ .../layers/attention/kv_cache/deepseek_v4.py | 2 +- .../csrc/scheduler/operations/forward.cpp | 14 +++++++++++--- tokenspeed-scheduler/csrc/scheduler/types.h | 6 +++--- .../tests/cpp/test_paged_cache_replay.cpp | 1 + 6 files changed, 33 insertions(+), 20 deletions(-) diff --git a/python/tokenspeed/runtime/engine/event_loop.py b/python/tokenspeed/runtime/engine/event_loop.py index 246cac259..cbaa76cc5 100644 --- a/python/tokenspeed/runtime/engine/event_loop.py +++ b/python/tokenspeed/runtime/engine/event_loop.py @@ -53,6 +53,7 @@ pool_to_paged_cache_groups, pool_to_prefix_cache_adjunct_spec, pop_common_cache_event_payloads, + scheduler_sliding_window_args, should_use_overlap_schedule, ) from tokenspeed.runtime.execution.distributed_initializer import ( @@ -179,19 +180,9 @@ def __init__( has_mamba = getattr(self.model_config, "mambaish_config", None) is not None or ( text_config is not None and hasattr(text_config, "mamba2_cache_params") ) - # Mirror ModelRunner's SWA detection (hf_config.sliding_window). Plain - # SWA mid-flight publish is capped to this window; hybrid paged-cache - # models use their windowed adjunct snapshots. gpt-oss stores an - # inclusive HF window and converts it to TokenSpeed's exclusive - # attention window inside the model. - sliding_window = getattr(hf_config, "sliding_window", None) - if ( - getattr(hf_config, "model_type", None) == "gpt_oss" - and sliding_window is not None - ): - sliding_window = max(0, int(sliding_window) - 1) - has_sliding_window = sliding_window is not None - sliding_window_size = int(sliding_window) if sliding_window is not None else 0 + has_sliding_window, sliding_window_size = scheduler_sliding_window_args( + hf_config + ) model_executor_config = ModelExecutorConfig.from_server_args( server_args=server_args, diff --git a/python/tokenspeed/runtime/engine/scheduler_utils.py b/python/tokenspeed/runtime/engine/scheduler_utils.py index f5dadc9d6..7b4ca0444 100644 --- a/python/tokenspeed/runtime/engine/scheduler_utils.py +++ b/python/tokenspeed/runtime/engine/scheduler_utils.py @@ -44,6 +44,19 @@ _TRUTHY_ENV_VALUES = {"1", "true", "yes", "on"} +def scheduler_sliding_window_args(hf_config: Any) -> tuple[bool, int]: + sliding_window = getattr(hf_config, "sliding_window", None) + if sliding_window is None: + return False, 0 + if getattr(hf_config, "model_type", None) == "gpt_oss": + # gpt-oss mixes full-attention and sliding-attention layers in a plain + # MHA KV pool. The scheduler must publish full prefix nodes so the full + # layers can replay all cached KV; sliding layers apply their window in + # the attention kernel. + return False, 0 + return True, int(sliding_window) + + def make_spec(rid: str, tokens: list[int]) -> RequestSpec: spec = RequestSpec() spec.request_id = rid diff --git a/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py b/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py index e59882531..0b32aedbe 100644 --- a/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py +++ b/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py @@ -964,7 +964,7 @@ def prefix_cache_required_group_ids(self) -> tuple[str, ...]: return tuple( str(spec.group_id) for spec in self.paged_cache_group_specs - if spec.family == "history" + if spec.family in {"history", "state"} ) def bind_paged_cache_scheduler(self, scheduler: object) -> None: diff --git a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp index aeed461bb..960e803af 100644 --- a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp +++ b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp @@ -238,10 +238,18 @@ bool Scheduler::enableMidflightPrefixPublish() const { if (!hybrid_prefix_cache_ || !hybrid_prefix_cache_->HasPagedCacheAdjunct()) { return false; } + if (!config_.prefix_cache_adjunct.has_value()) { + return false; + } for (const auto& group : config_.paged_cache_groups) { - if (group.family == PagedCacheGroupFamily::State && - group.retention == PagedCacheGroupConfig::Retention::SlidingWindow) { - return true; + if (group.family != PagedCacheGroupFamily::State || + group.retention != PagedCacheGroupConfig::Retention::SlidingWindow) { + continue; + } + for (const auto& required_group : config_.prefix_cache_adjunct->required_groups) { + if (required_group == group.group_id) { + return true; + } } } return false; diff --git a/tokenspeed-scheduler/csrc/scheduler/types.h b/tokenspeed-scheduler/csrc/scheduler/types.h index fd8a342ab..e549e3944 100644 --- a/tokenspeed-scheduler/csrc/scheduler/types.h +++ b/tokenspeed-scheduler/csrc/scheduler/types.h @@ -101,9 +101,9 @@ struct SchedulerConfig { Role role{Role::kFused}; bool disable_prefix_cache{false}; - // Plain sliding-window-attention models publish their prefix only at FinishEvent. - // Hybrid paged-cache SWA models may publish mid-flight when their adjunct can - // restore the windowed state needed for prefix reuse. + // Plain sliding-window-attention models cap mid-flight prefix publish to the + // full-history-equivalent prefix region. Mixed-layer models that need full + // history should leave this false and rely on per-layer attention windows. bool has_sliding_window{false}; // Token length of the SWA window. Plain SWA mid-flight publish is capped to // this many tokens so reuse stays in the full-history-equivalent region. diff --git a/tokenspeed-scheduler/tests/cpp/test_paged_cache_replay.cpp b/tokenspeed-scheduler/tests/cpp/test_paged_cache_replay.cpp index 2892d9ae4..0282d060d 100644 --- a/tokenspeed-scheduler/tests/cpp/test_paged_cache_replay.cpp +++ b/tokenspeed-scheduler/tests/cpp/test_paged_cache_replay.cpp @@ -174,6 +174,7 @@ class PagedCacheDecodePublishHybridSwaTest : public PagedCacheDecodePublishTest state.sliding_window_tokens = 4; state.family = PagedCacheGroupFamily::State; cfg.paged_cache_groups.push_back(state); + cfg.prefix_cache_adjunct->required_groups.push_back("swa_state"); return cfg; } }; From c9809f7e83d0c72f0164abfd92a99e8b5d0bc78e Mon Sep 17 00:00:00 2001 From: Qingyang Wu Date: Tue, 9 Jun 2026 00:05:03 +0000 Subject: [PATCH 6/7] fix(runtime): disable gpt-oss midflight prefix publish Signed-off-by: Qingyang Wu --- .../runtime/engine/scheduler_utils.py | 11 ++++--- .../tests/cpp/test_paged_cache_replay.cpp | 32 +++++++++++++++++++ 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/python/tokenspeed/runtime/engine/scheduler_utils.py b/python/tokenspeed/runtime/engine/scheduler_utils.py index 7b4ca0444..9a6414fa9 100644 --- a/python/tokenspeed/runtime/engine/scheduler_utils.py +++ b/python/tokenspeed/runtime/engine/scheduler_utils.py @@ -49,11 +49,12 @@ def scheduler_sliding_window_args(hf_config: Any) -> tuple[bool, int]: if sliding_window is None: return False, 0 if getattr(hf_config, "model_type", None) == "gpt_oss": - # gpt-oss mixes full-attention and sliding-attention layers in a plain - # MHA KV pool. The scheduler must publish full prefix nodes so the full - # layers can replay all cached KV; sliding layers apply their window in - # the attention kernel. - return False, 0 + # gpt-oss mixes full-attention and sliding-attention layers with + # attention sinks in a plain MHA KV pool. Prefix hits remain valid after + # a request finishes, but the current mid-flight publish paths are not + # numerically equivalent for this mix, so use the SWA guard with no safe + # plain publish window. + return True, 0 return True, int(sliding_window) diff --git a/tokenspeed-scheduler/tests/cpp/test_paged_cache_replay.cpp b/tokenspeed-scheduler/tests/cpp/test_paged_cache_replay.cpp index 0282d060d..d1e033bf9 100644 --- a/tokenspeed-scheduler/tests/cpp/test_paged_cache_replay.cpp +++ b/tokenspeed-scheduler/tests/cpp/test_paged_cache_replay.cpp @@ -148,6 +148,15 @@ class PagedCacheDecodePublishPlainSwaTest : public SchedulerTestSuite { } }; +class PagedCacheDecodePublishPlainSwaNoWindowTest : public PagedCacheDecodePublishPlainSwaTest { +protected: + SchedulerConfig MakeConfig() override { + auto cfg = PagedCacheDecodePublishPlainSwaTest::MakeConfig(); + cfg.sliding_window_size = 0; + return cfg; + } +}; + class PagedCacheDecodePublishHybridHistorySwaTest : public PagedCacheDecodePublishTest { protected: SchedulerConfig MakeConfig() override { @@ -545,6 +554,29 @@ TEST_F(PagedCacheDecodePublishPlainSwaTest, DecodeResultPublishesWindowCappedPre EXPECT_EQ(prefix_by_request.at("hit4"), 4); } +TEST_F(PagedCacheDecodePublishPlainSwaNoWindowTest, MidflightPublishIsDisabledWithoutSafeWindow) { + Submit(RequestSpec{.request_id = "r1", .tokens = {1, 2}}); + ASSERT_NE(GetForwardOp(PlanOnce()), nullptr); + ASSERT_NE(GetForwardOp(PlanOnce()), nullptr); + + SendForwardDone("r1", {3, 4, 5, 6}); + + Submit(RequestSpec{.request_id = "hit4", .tokens = {1, 2, 3, 4, 5, 6}}); + auto plan = PlanOnce(); + auto* fwd = GetForwardOp(plan); + ASSERT_NE(fwd, nullptr); + ASSERT_GE(fwd->extend_prefix_lens.size(), 1u); + + std::unordered_map prefix_by_request; + for (std::size_t row = 0; row < fwd->extend_prefix_lens.size(); ++row) { + ASSERT_LT(row, fwd->request_ids.size()); + prefix_by_request.emplace(fwd->request_ids[row], fwd->extend_prefix_lens[row]); + } + + ASSERT_TRUE(prefix_by_request.contains("hit4")); + EXPECT_EQ(prefix_by_request.at("hit4"), 0); +} + TEST_F(PagedCacheDecodePublishHybridHistorySwaTest, MidflightPublishRequiresWindowState) { Submit(RequestSpec{.request_id = "r1", .tokens = {1, 2, 3, 4, 5, 6, 7, 8}}); ASSERT_NE(GetForwardOp(PlanOnce()), nullptr); From bd9ad99cc93fdfefe45184c8e7efdea66d9c3614 Mon Sep 17 00:00:00 2001 From: Qingyang Wu Date: Tue, 9 Jun 2026 05:25:54 +0000 Subject: [PATCH 7/7] fix(runtime): disable Kimi K2.5 midflight prefix publish Signed-off-by: Qingyang Wu --- .../tokenspeed/runtime/engine/scheduler_utils.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/python/tokenspeed/runtime/engine/scheduler_utils.py b/python/tokenspeed/runtime/engine/scheduler_utils.py index 9a6414fa9..6902d8a1b 100644 --- a/python/tokenspeed/runtime/engine/scheduler_utils.py +++ b/python/tokenspeed/runtime/engine/scheduler_utils.py @@ -45,16 +45,23 @@ def scheduler_sliding_window_args(hf_config: Any) -> tuple[bool, int]: - sliding_window = getattr(hf_config, "sliding_window", None) - if sliding_window is None: - return False, 0 - if getattr(hf_config, "model_type", None) == "gpt_oss": + model_type = getattr(hf_config, "model_type", None) + if model_type == "gpt_oss": # gpt-oss mixes full-attention and sliding-attention layers with # attention sinks in a plain MHA KV pool. Prefix hits remain valid after # a request finishes, but the current mid-flight publish paths are not # numerically equivalent for this mix, so use the SWA guard with no safe # plain publish window. return True, 0 + if model_type == "kimi_k25": + # Kimi-K2.5 evals use long EAGLE3 speculative decode. Terminal prefix + # hits remain valid, but mid-flight decode publication currently affects + # deterministic quality for this path, so keep reuse to completed + # requests. + return True, 0 + sliding_window = getattr(hf_config, "sliding_window", None) + if sliding_window is None: + return False, 0 return True, int(sliding_window)