diff --git a/python/tokenspeed/runtime/engine/event_loop.py b/python/tokenspeed/runtime/engine/event_loop.py index c3068c8b6..cbaa76cc5 100644 --- a/python/tokenspeed/runtime/engine/event_loop.py +++ b/python/tokenspeed/runtime/engine/event_loop.py @@ -53,6 +53,7 @@ pool_to_paged_cache_groups, pool_to_prefix_cache_adjunct_spec, pop_common_cache_event_payloads, + scheduler_sliding_window_args, should_use_overlap_schedule, ) from tokenspeed.runtime.execution.distributed_initializer import ( @@ -179,6 +180,9 @@ def __init__( has_mamba = getattr(self.model_config, "mambaish_config", None) is not None or ( text_config is not None and hasattr(text_config, "mamba2_cache_params") ) + has_sliding_window, sliding_window_size = scheduler_sliding_window_args( + hf_config + ) model_executor_config = ModelExecutorConfig.from_server_args( server_args=server_args, @@ -327,6 +331,8 @@ def __init__( else 1 ), disable_prefix_cache=not server_args.enable_prefix_caching, + has_sliding_window=has_sliding_window, + sliding_window_size=sliding_window_size, enable_mamba=has_mamba, mamba_cache_chunk_size=server_args.mamba_cache_chunk_size, mamba_pool_total_chunks=mamba_pool_total_chunks, diff --git a/python/tokenspeed/runtime/engine/scheduler_utils.py b/python/tokenspeed/runtime/engine/scheduler_utils.py index f80566871..6902d8a1b 100644 --- a/python/tokenspeed/runtime/engine/scheduler_utils.py +++ b/python/tokenspeed/runtime/engine/scheduler_utils.py @@ -44,6 +44,27 @@ _TRUTHY_ENV_VALUES = {"1", "true", "yes", "on"} +def scheduler_sliding_window_args(hf_config: Any) -> tuple[bool, int]: + model_type = getattr(hf_config, "model_type", None) + if model_type == "gpt_oss": + # gpt-oss mixes full-attention and sliding-attention layers with + # attention sinks in a plain MHA KV pool. Prefix hits remain valid after + # a request finishes, but the current mid-flight publish paths are not + # numerically equivalent for this mix, so use the SWA guard with no safe + # plain publish window. + return True, 0 + if model_type == "kimi_k25": + # Kimi-K2.5 evals use long EAGLE3 speculative decode. Terminal prefix + # hits remain valid, but mid-flight decode publication currently affects + # deterministic quality for this path, so keep reuse to completed + # requests. + return True, 0 + sliding_window = getattr(hf_config, "sliding_window", None) + if sliding_window is None: + return False, 0 + return True, int(sliding_window) + + def make_spec(rid: str, tokens: list[int]) -> RequestSpec: spec = RequestSpec() spec.request_id = rid @@ -64,6 +85,8 @@ def make_config( enable_kv_cache_events: bool = False, decode_input_tokens: int = 1, disable_prefix_cache: bool = False, + has_sliding_window: bool = False, + sliding_window_size: int = 0, enable_mamba: bool = False, mamba_cache_chunk_size: int = 64, mamba_pool_total_chunks: int = 0, @@ -93,6 +116,8 @@ def make_config( cfg.num_device_pages = num_device_pages cfg.decode_input_tokens = decode_input_tokens cfg.disable_prefix_cache = disable_prefix_cache + cfg.has_sliding_window = has_sliding_window + cfg.sliding_window_size = sliding_window_size cfg.disable_l2_cache = disable_l2_cache cfg.enable_mamba = enable_mamba diff --git a/python/tokenspeed/runtime/layers/attention/backends/mha.py b/python/tokenspeed/runtime/layers/attention/backends/mha.py index 778951072..767665649 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/mha.py +++ b/python/tokenspeed/runtime/layers/attention/backends/mha.py @@ -345,7 +345,13 @@ def forward_extend( metadata = self.forward_prefill_metadata if metadata.max_extend_prefix_len > 0: - if self.mha_extend_mode == "ragged": + sinks = kwargs.get("sinks") + # gpt-oss uses attention sinks. The ragged split path computes the + # cached-prefix and new-token attention states separately before + # merging; keep sink-attention requests on the unified KV-cache path + # so prefix-cache hits preserve the same numerical path as ordinary + # extend. + if self.mha_extend_mode == "ragged" and sinks is None: return self._forward_extend_split( q, k, @@ -355,7 +361,7 @@ def forward_extend( token_to_kv_pool, metadata, save_kv_cache, - kwargs.get("sinks"), + sinks, ) else: return self._forward_extend( @@ -367,7 +373,7 @@ def forward_extend( token_to_kv_pool, metadata, save_kv_cache, - kwargs.get("sinks"), + sinks, ) return self._forward_prefill( q, diff --git a/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py b/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py index e59882531..0b32aedbe 100644 --- a/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py +++ b/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py @@ -964,7 +964,7 @@ def prefix_cache_required_group_ids(self) -> tuple[str, ...]: return tuple( str(spec.group_id) for spec in self.paged_cache_group_specs - if spec.family == "history" + if spec.family in {"history", "state"} ) def bind_paged_cache_scheduler(self, scheduler: object) -> None: diff --git a/tokenspeed-scheduler/bindings/python_module.cpp b/tokenspeed-scheduler/bindings/python_module.cpp index eaa825b29..967506424 100644 --- a/tokenspeed-scheduler/bindings/python_module.cpp +++ b/tokenspeed-scheduler/bindings/python_module.cpp @@ -237,6 +237,8 @@ NB_MODULE(tokenspeed_scheduler_ext, m) { .def_rw("enable_kv_cache_events", &tokenspeed::SchedulerConfig::enable_kv_cache_events) .def_rw("enable_mixed_prefill_decode", &tokenspeed::SchedulerConfig::enable_mixed_prefill_decode) .def_rw("disable_prefix_cache", &tokenspeed::SchedulerConfig::disable_prefix_cache) + .def_rw("has_sliding_window", &tokenspeed::SchedulerConfig::has_sliding_window) + .def_rw("sliding_window_size", &tokenspeed::SchedulerConfig::sliding_window_size) .def_rw("enable_mamba", &tokenspeed::SchedulerConfig::enable_mamba) .def_rw("mamba_cache_chunk_size", &tokenspeed::SchedulerConfig::mamba_cache_chunk_size) .def_rw("mamba_pool_total_chunks", &tokenspeed::SchedulerConfig::mamba_pool_total_chunks) diff --git a/tokenspeed-scheduler/csrc/fsm/forward_events.cpp b/tokenspeed-scheduler/csrc/fsm/forward_events.cpp index 49b5c45a7..dbab51f74 100644 --- a/tokenspeed-scheduler/csrc/fsm/forward_events.cpp +++ b/tokenspeed-scheduler/csrc/fsm/forward_events.cpp @@ -102,12 +102,46 @@ bool ShouldPublishMambaCheckpoint(tokenspeed::HybridPrefixCache* hybrid_cache, s namespace tokenspeed::fsm { -void InsertHybridCache(HybridPrefixCache* hybrid_cache, +void InsertPrefixCache(KVPrefixCache* kv_prefix_cache, HybridPrefixCache* hybrid_cache, const std::vector>& full_paged_tokens, std::unique_ptr& device_node_ref, LocalKVAllocator* local_kv_allocator, LocalMambaAllocator* local_mamba_allocator, std::int32_t chunk_begin, std::int32_t chunk_size, - std::int32_t page_size, const std::vector* prefix_pages_override) { - if (hybrid_cache == nullptr) return; + std::int32_t page_size, const std::vector* prefix_pages_override, + bool enable_midflight_publish, std::int32_t max_publish_tokens) { + // Hybrid models publish through the hybrid cache's wrapped KV cache (and additionally + // track a Mamba checkpoint); plain (non-hybrid) models publish through the base KV + // prefix cache, making the freshly-computed prefix matchable by concurrent requests + // now rather than only at FinishEvent. Callers disable this explicitly for + // SWA configs that cannot restore the windowed state required for reuse. + auto detach_checkpoint = [&]() { + if (local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) { + local_mamba_allocator->DetachCheckpoint(); + } + }; + if (!enable_midflight_publish) { + detach_checkpoint(); + return; + } + KVPrefixCache* kv = (hybrid_cache != nullptr) ? &hybrid_cache->GetKVPrefixCache() : kv_prefix_cache; + if (kv == nullptr) return; + + std::vector> capped_paged_tokens; + const std::vector>* publish_paged_tokens = &full_paged_tokens; + if (max_publish_tokens > 0) { + if (page_size <= 0) { + detach_checkpoint(); + return; + } + const std::int32_t max_publish_pages = max_publish_tokens / page_size; + if (max_publish_pages <= 0) { + detach_checkpoint(); + return; + } + const auto publish_pages = + std::min(max_publish_pages, static_cast(full_paged_tokens.size())); + capped_paged_tokens.assign(full_paged_tokens.begin(), full_paged_tokens.begin() + publish_pages); + publish_paged_tokens = &capped_paged_tokens; + } std::vector computed_prefix_pages; const std::vector* prefix_pages = prefix_pages_override; @@ -116,19 +150,18 @@ void InsertHybridCache(HybridPrefixCache* hybrid_cache, prefix_pages = &computed_prefix_pages; } std::int32_t new_page_count = - static_cast(full_paged_tokens.size()) - static_cast(prefix_pages->size()); + static_cast(publish_paged_tokens->size()) - static_cast(prefix_pages->size()); if (new_page_count <= 0) { - if (local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) { - local_mamba_allocator->DetachCheckpoint(); - } + detach_checkpoint(); return; } OwnedPages pages_to_insert = local_kv_allocator->TakeFirst(new_page_count); - auto insert_result = hybrid_cache->GetKVPrefixCache().Insert(full_paged_tokens, *prefix_pages, - std::move(pages_to_insert)); + auto insert_result = + kv->Insert(*publish_paged_tokens, *prefix_pages, std::move(pages_to_insert)); - if (local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) { + // Mamba checkpoint publication is hybrid-only. + if (hybrid_cache != nullptr && local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) { if (ShouldPublishMambaCheckpoint(hybrid_cache, chunk_begin, chunk_size, page_size)) { hybrid_cache->InsertMamba(insert_result.last_node, local_mamba_allocator->DetachCheckpoint()); } else { @@ -218,8 +251,9 @@ std::variant SchedulePrefillEvent::operator()(Prefillin if (end_of_window_pages < static_cast(paged_tokens.size())) { paged_tokens.resize(end_of_window_pages); } - InsertHybridCache(hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(), - local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize()); + InsertPrefixCache(kv_prefix_cache_, hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(), + local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize(), + /*prefix_pages_override=*/nullptr, enable_midflight_publish_, max_midflight_publish_tokens_); // Allocate KV pages for the new chunk local_kv_allocator->Acquire(tokens_this_round_); @@ -268,8 +302,9 @@ Decoding ScheduleDecodeEvent::operator()(PrefillDone&& state) { if (end_of_window_pages < static_cast(paged_tokens.size())) { paged_tokens.resize(end_of_window_pages); } - InsertHybridCache(hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(), - local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize()); + InsertPrefixCache(kv_prefix_cache_, hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(), + local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize(), + /*prefix_pages_override=*/nullptr, enable_midflight_publish_, max_midflight_publish_tokens_); // Allocate fresh checkpoint for decode-phase mamba state tracking if (hybrid_prefix_cache_ != nullptr && local_mamba_allocator != nullptr) { if (!local_mamba_allocator->AllocateCheckpoint()) { diff --git a/tokenspeed-scheduler/csrc/fsm/forward_events.h b/tokenspeed-scheduler/csrc/fsm/forward_events.h index 095ff7de4..50a9f1f72 100644 --- a/tokenspeed-scheduler/csrc/fsm/forward_events.h +++ b/tokenspeed-scheduler/csrc/fsm/forward_events.h @@ -53,11 +53,19 @@ namespace tokenspeed::fsm { struct PrefetchDone; struct Prefetching; -void InsertHybridCache(HybridPrefixCache* hybrid_prefix_cache, +// Publish a request's freshly-computed prefix into the device radix tree *mid-flight* +// (during prefill, at the prefill->decode transition, or after accepted decode +// tokens) so other in-flight requests that share the prefix can reuse it -- instead +// of only after the request finishes (FinishEvent). Works for both the plain KV prefix +// cache (kv_prefix_cache) and, when present, the hybrid cache (which additionally +// publishes the Mamba checkpoint). The published node is pinned via the request's +// device_node_ref so it is not evicted while the request is still using it. +void InsertPrefixCache(KVPrefixCache* kv_prefix_cache, HybridPrefixCache* hybrid_prefix_cache, const std::vector>& full_paged_tokens, std::unique_ptr& device_node_ref, LocalKVAllocator* local_kv_allocator, LocalMambaAllocator* local_mamba_allocator, std::int32_t chunk_begin, std::int32_t chunk_size, - std::int32_t page_size, const std::vector* prefix_pages_override = nullptr); + std::int32_t page_size, const std::vector* prefix_pages_override = nullptr, + bool enable_midflight_publish = true, std::int32_t max_publish_tokens = 0); struct SchedulePrefillFirstChunkEvent : InvalidTransitionHandler { using InvalidTransitionHandler::operator(); @@ -107,10 +115,14 @@ struct SchedulePrefillFirstChunkEvent : InvalidTransitionHandler { using InvalidTransitionHandler::operator(); SchedulePrefillEvent(std::int32_t tokens_this_round, std::int32_t reserve_num_tokens_in_next_schedule_event, - HybridPrefixCache* hybrid_prefix_cache = nullptr) + HybridPrefixCache* hybrid_prefix_cache = nullptr, KVPrefixCache* kv_prefix_cache = nullptr, + bool enable_midflight_publish = true, std::int32_t max_midflight_publish_tokens = 0) : tokens_this_round_(tokens_this_round), reserve_num_tokens_in_next_schedule_event_(reserve_num_tokens_in_next_schedule_event), - hybrid_prefix_cache_(hybrid_prefix_cache) {} + hybrid_prefix_cache_(hybrid_prefix_cache), + kv_prefix_cache_(kv_prefix_cache), + enable_midflight_publish_(enable_midflight_publish), + max_midflight_publish_tokens_(max_midflight_publish_tokens) {} // Returns PrefillDone (last chunk) or Prefilling (more chunks remain). std::variant operator()(Prefilling&& state); @@ -119,13 +131,22 @@ struct SchedulePrefillEvent : InvalidTransitionHandler { std::int32_t tokens_this_round_{}; std::int32_t reserve_num_tokens_in_next_schedule_event_{}; HybridPrefixCache* hybrid_prefix_cache_{}; + KVPrefixCache* kv_prefix_cache_{}; + bool enable_midflight_publish_{true}; + std::int32_t max_midflight_publish_tokens_{}; }; struct ScheduleDecodeEvent : InvalidTransitionHandler { using InvalidTransitionHandler::operator(); - ScheduleDecodeEvent(std::int32_t decode_input_tokens, HybridPrefixCache* hybrid_prefix_cache = nullptr) - : decode_input_tokens_(decode_input_tokens), hybrid_prefix_cache_(hybrid_prefix_cache) {} + ScheduleDecodeEvent(std::int32_t decode_input_tokens, HybridPrefixCache* hybrid_prefix_cache = nullptr, + KVPrefixCache* kv_prefix_cache = nullptr, bool enable_midflight_publish = true, + std::int32_t max_midflight_publish_tokens = 0) + : decode_input_tokens_(decode_input_tokens), + hybrid_prefix_cache_(hybrid_prefix_cache), + kv_prefix_cache_(kv_prefix_cache), + enable_midflight_publish_(enable_midflight_publish), + max_midflight_publish_tokens_(max_midflight_publish_tokens) {} Decoding operator()(PrefillDone&& state); Decoding operator()(Decoding&& state); @@ -133,6 +154,9 @@ struct ScheduleDecodeEvent : InvalidTransitionHandler { private: std::int32_t decode_input_tokens_; HybridPrefixCache* hybrid_prefix_cache_{}; + KVPrefixCache* kv_prefix_cache_{}; + bool enable_midflight_publish_{true}; + std::int32_t max_midflight_publish_tokens_{}; }; struct ScheduleDecodeFromRetractedEvent : InvalidTransitionHandler { @@ -298,10 +322,14 @@ struct ExtendResultEvent : InvalidTransitionHandler { ExtendResultEvent() = delete; ExtendResultEvent(std::string request_id, std::vector result_tokens, - HybridPrefixCache* hybrid_prefix_cache = nullptr) + KVPrefixCache* kv_prefix_cache = nullptr, HybridPrefixCache* hybrid_prefix_cache = nullptr, + bool enable_midflight_publish = true, std::int32_t max_midflight_publish_tokens = 0) : request_id_(std::move(request_id)), result_tokens_(std::move(result_tokens)), - hybrid_prefix_cache_(hybrid_prefix_cache) {} + kv_prefix_cache_(kv_prefix_cache), + hybrid_prefix_cache_(hybrid_prefix_cache), + enable_midflight_publish_(enable_midflight_publish), + max_midflight_publish_tokens_(max_midflight_publish_tokens) {} public: template @@ -318,7 +346,7 @@ struct ExtendResultEvent : InvalidTransitionHandler { const std::int32_t page_size = state.GetPageSize(); const std::int32_t reserve = state.GetReserveNumTokensInNextScheduleEvent(); - if (hybrid_prefix_cache_ == nullptr) { + if (kv_prefix_cache_ == nullptr && hybrid_prefix_cache_ == nullptr) { return std::move(state); } @@ -331,28 +359,38 @@ struct ExtendResultEvent : InvalidTransitionHandler { const std::int32_t new_publishable_pages = publishable_pages(accepted_token_size); if (new_publishable_pages <= old_publishable_pages) { - hybrid_prefix_cache_->RewindRequest(request_id_, accepted_token_size); + if (hybrid_prefix_cache_ != nullptr) { + hybrid_prefix_cache_->RewindRequest(request_id_, accepted_token_size); + } return std::move(state); } const std::int32_t chunk_begin = accepted_token_size - static_cast(result_tokens_.size()); auto full_paged_tokens = state.GetFullPagedTokens(/*except_last=*/true); std::vector prefix_pages = DevicePagesFromRoot(state.GetDeviceNode()); - const std::int32_t new_page_count = - static_cast(full_paged_tokens.size()) - static_cast(prefix_pages.size()); + std::int32_t publish_page_count = static_cast(full_paged_tokens.size()); + if (max_midflight_publish_tokens_ > 0 && page_size > 0) { + publish_page_count = std::min(publish_page_count, max_midflight_publish_tokens_ / page_size); + } + const std::int32_t new_page_count = publish_page_count - static_cast(prefix_pages.size()); auto local_kv_allocator = std::move(state).TakeLocalKVAllocator(); auto local_mamba_allocator = std::move(state).TakeLocalMambaAllocator(); auto device_node_ref = std::move(state).TakeDeviceNodeRef(); auto host_node_ref = std::move(state).TakeHostNodeRef(); - if (new_page_count > 0 && local_kv_allocator->PageCount() >= new_page_count) { - InsertHybridCache(hybrid_prefix_cache_, full_paged_tokens, device_node_ref, local_kv_allocator.get(), - local_mamba_allocator.get(), chunk_begin, - static_cast(result_tokens_.size()), page_size, &prefix_pages); - hybrid_prefix_cache_->CommitChunk(request_id_, device_node_ref->Node()); + if (enable_midflight_publish_ && new_page_count > 0 && local_kv_allocator->PageCount() >= new_page_count) { + InsertPrefixCache(kv_prefix_cache_, hybrid_prefix_cache_, full_paged_tokens, device_node_ref, + local_kv_allocator.get(), local_mamba_allocator.get(), chunk_begin, + static_cast(result_tokens_.size()), page_size, &prefix_pages, + enable_midflight_publish_, max_midflight_publish_tokens_); + if (hybrid_prefix_cache_ != nullptr) { + hybrid_prefix_cache_->CommitChunk(request_id_, device_node_ref->Node()); + } + } + if (hybrid_prefix_cache_ != nullptr) { + hybrid_prefix_cache_->RewindRequest(request_id_, accepted_token_size); } - hybrid_prefix_cache_->RewindRequest(request_id_, accepted_token_size); return Decoding{token_container, page_size, @@ -374,7 +412,10 @@ struct ExtendResultEvent : InvalidTransitionHandler { private: std::string request_id_; std::vector result_tokens_; + KVPrefixCache* kv_prefix_cache_{}; HybridPrefixCache* hybrid_prefix_cache_{}; + bool enable_midflight_publish_{true}; + std::int32_t max_midflight_publish_tokens_{}; }; } // namespace tokenspeed::fsm diff --git a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp index fd48962fe..960e803af 100644 --- a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp +++ b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp @@ -191,8 +191,14 @@ std::optional Scheduler::schedulePrefill( return {}; } - return fsm::SchedulePrefillEvent{tokens_this_round, reserve_num_tokens_in_next_schedule_event, - hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr}; + // Plain SWA models publish only up to the attention window; hybrid + // paged-cache SWA models publish when their adjunct carries window state. + return fsm::SchedulePrefillEvent{tokens_this_round, + reserve_num_tokens_in_next_schedule_event, + hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr, + &kv_prefix_cache_, + enableMidflightPrefixPublish(), + maxMidflightPrefixPublishTokens()}; } std::optional Scheduler::scheduleDecode(Request* request, @@ -218,7 +224,45 @@ std::optional Scheduler::scheduleDecode(Request* reque } return fsm::ScheduleDecodeEvent{config_.decode_input_tokens, - hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr}; + hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr, &kv_prefix_cache_, + enableMidflightPrefixPublish(), maxMidflightPrefixPublishTokens()}; +} + +bool Scheduler::enableMidflightPrefixPublish() const { + if (!config_.has_sliding_window) { + return true; + } + if (!hybrid_prefix_cache_ && config_.sliding_window_size > 0) { + return true; + } + if (!hybrid_prefix_cache_ || !hybrid_prefix_cache_->HasPagedCacheAdjunct()) { + return false; + } + if (!config_.prefix_cache_adjunct.has_value()) { + return false; + } + for (const auto& group : config_.paged_cache_groups) { + if (group.family != PagedCacheGroupFamily::State || + group.retention != PagedCacheGroupConfig::Retention::SlidingWindow) { + continue; + } + for (const auto& required_group : config_.prefix_cache_adjunct->required_groups) { + if (required_group == group.group_id) { + return true; + } + } + } + return false; +} + +std::int32_t Scheduler::maxMidflightPrefixPublishTokens() const { + if (!config_.has_sliding_window) { + return 0; + } + if (hybrid_prefix_cache_ && hybrid_prefix_cache_->HasPagedCacheAdjunct()) { + return 0; + } + return std::max(0, config_.sliding_window_size); } std::optional Scheduler::scheduleDecodeFromRetracted( diff --git a/tokenspeed-scheduler/csrc/scheduler/outside_event_handler.cpp b/tokenspeed-scheduler/csrc/scheduler/outside_event_handler.cpp index 02e520c45..408561e73 100644 --- a/tokenspeed-scheduler/csrc/scheduler/outside_event_handler.cpp +++ b/tokenspeed-scheduler/csrc/scheduler/outside_event_handler.cpp @@ -126,8 +126,9 @@ void Scheduler::handleEvent(const forward::UpdateReserveNumTokens& event) { } void Scheduler::handleEvent(const forward::ExtendResult& event) { if (auto req = find_request(event.request_id)) { - req->Apply(fsm::ExtendResultEvent{event.request_id, event.tokens, - hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr}); + req->Apply(fsm::ExtendResultEvent{event.request_id, event.tokens, &kv_prefix_cache_, + hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr, + enableMidflightPrefixPublish(), maxMidflightPrefixPublishTokens()}); } } diff --git a/tokenspeed-scheduler/csrc/scheduler/scheduler.h b/tokenspeed-scheduler/csrc/scheduler/scheduler.h index c36c3a413..492eebad9 100644 --- a/tokenspeed-scheduler/csrc/scheduler/scheduler.h +++ b/tokenspeed-scheduler/csrc/scheduler/scheduler.h @@ -103,6 +103,8 @@ class Scheduler { std::map& simulated_free); std::optional scheduleDecode(Request* request, std::map& simulated_free); + bool enableMidflightPrefixPublish() const; + std::int32_t maxMidflightPrefixPublishTokens() const; std::optional scheduleDecodeFromRetracted( Request* request, std::map& simulated_free); std::optional scheduleRetract(Request* request); diff --git a/tokenspeed-scheduler/csrc/scheduler/types.h b/tokenspeed-scheduler/csrc/scheduler/types.h index 10f723b05..e549e3944 100644 --- a/tokenspeed-scheduler/csrc/scheduler/types.h +++ b/tokenspeed-scheduler/csrc/scheduler/types.h @@ -101,6 +101,13 @@ struct SchedulerConfig { Role role{Role::kFused}; bool disable_prefix_cache{false}; + // Plain sliding-window-attention models cap mid-flight prefix publish to the + // full-history-equivalent prefix region. Mixed-layer models that need full + // history should leave this false and rely on per-layer attention windows. + bool has_sliding_window{false}; + // Token length of the SWA window. Plain SWA mid-flight publish is capped to + // this many tokens so reuse stays in the full-history-equivalent region. + std::int32_t sliding_window_size{0}; bool enable_mamba{false}; std::int32_t mamba_cache_chunk_size{64}; std::int32_t mamba_pool_total_chunks{0}; diff --git a/tokenspeed-scheduler/tests/cpp/test_paged_cache_replay.cpp b/tokenspeed-scheduler/tests/cpp/test_paged_cache_replay.cpp index ede2e3f35..d1e033bf9 100644 --- a/tokenspeed-scheduler/tests/cpp/test_paged_cache_replay.cpp +++ b/tokenspeed-scheduler/tests/cpp/test_paged_cache_replay.cpp @@ -124,6 +124,70 @@ class PagedCacheDecodePublishTest : public SchedulerTestSuite { } }; +class PagedCacheDecodePublishPlainSwaTest : public SchedulerTestSuite { +protected: + SchedulerConfig MakeConfig() override { + auto cfg = SchedulerTestSuite::MakeConfig(); + cfg.page_size = 1; + cfg.device_allocator.total_pages = 64; + cfg.host_allocator.total_pages = 64; + cfg.max_scheduled_tokens = 64; + cfg.max_batch_size = 8; + cfg.decode_input_tokens = 4; + cfg.enable_l3_storage = false; + cfg.has_sliding_window = true; + cfg.sliding_window_size = 4; + return cfg; + } + + static const FlatForwardOperation* GetForwardOp(const ExecutionPlan& plan) { + for (const auto& op : plan.Operations()) { + if (auto* f = std::get_if(&op)) return f; + } + return nullptr; + } +}; + +class PagedCacheDecodePublishPlainSwaNoWindowTest : public PagedCacheDecodePublishPlainSwaTest { +protected: + SchedulerConfig MakeConfig() override { + auto cfg = PagedCacheDecodePublishPlainSwaTest::MakeConfig(); + cfg.sliding_window_size = 0; + return cfg; + } +}; + +class PagedCacheDecodePublishHybridHistorySwaTest : public PagedCacheDecodePublishTest { +protected: + SchedulerConfig MakeConfig() override { + auto cfg = PagedCacheDecodePublishTest::MakeConfig(); + cfg.has_sliding_window = true; + cfg.sliding_window_size = 4; + return cfg; + } +}; + +class PagedCacheDecodePublishHybridSwaTest : public PagedCacheDecodePublishTest { +protected: + SchedulerConfig MakeConfig() override { + auto cfg = PagedCacheDecodePublishTest::MakeConfig(); + cfg.has_sliding_window = true; + cfg.sliding_window_size = 4; + + PagedCacheGroupConfig state{}; + state.group_id = "swa_state"; + state.rows_per_page = 1; + state.entry_stride_tokens = 1; + state.total_pages = 64; + state.retention = PagedCacheGroupConfig::Retention::SlidingWindow; + state.sliding_window_tokens = 4; + state.family = PagedCacheGroupFamily::State; + cfg.paged_cache_groups.push_back(state); + cfg.prefix_cache_adjunct->required_groups.push_back("swa_state"); + return cfg; + } +}; + class PagedCacheTerminalContinuationTest : public ::testing::Test { protected: static constexpr std::int32_t kPageSize = 64; @@ -446,6 +510,119 @@ TEST_F(PagedCacheDecodePublishTest, ContinuingDecodePublishesAcceptedPagesOnly) EXPECT_EQ(prefix_by_request.at("probe_tail"), 4); } +TEST_F(PagedCacheDecodePublishPlainSwaTest, MidflightPublishIsCappedToSlidingWindow) { + Submit(RequestSpec{.request_id = "r1", .tokens = {1, 2, 3, 4, 5, 6, 7, 8}}); + ASSERT_NE(GetForwardOp(PlanOnce()), nullptr); + ASSERT_NE(GetForwardOp(PlanOnce()), nullptr); + + Submit(RequestSpec{.request_id = "hit4", .tokens = {1, 2, 3, 4, 5, 6, 7, 8}}); + auto plan = PlanOnce(); + auto* fwd = GetForwardOp(plan); + ASSERT_NE(fwd, nullptr); + ASSERT_GE(fwd->extend_prefix_lens.size(), 1u); + + std::unordered_map prefix_by_request; + for (std::size_t row = 0; row < fwd->extend_prefix_lens.size(); ++row) { + ASSERT_LT(row, fwd->request_ids.size()); + prefix_by_request.emplace(fwd->request_ids[row], fwd->extend_prefix_lens[row]); + } + + ASSERT_TRUE(prefix_by_request.contains("hit4")); + EXPECT_EQ(prefix_by_request.at("hit4"), 4); +} + +TEST_F(PagedCacheDecodePublishPlainSwaTest, DecodeResultPublishesWindowCappedPrefix) { + Submit(RequestSpec{.request_id = "r1", .tokens = {1, 2}}); + ASSERT_NE(GetForwardOp(PlanOnce()), nullptr); + ASSERT_NE(GetForwardOp(PlanOnce()), nullptr); + + SendForwardDone("r1", {3, 4, 5, 6}); + + Submit(RequestSpec{.request_id = "hit4", .tokens = {1, 2, 3, 4, 5, 6}}); + auto plan = PlanOnce(); + auto* fwd = GetForwardOp(plan); + ASSERT_NE(fwd, nullptr); + ASSERT_GE(fwd->extend_prefix_lens.size(), 1u); + + std::unordered_map prefix_by_request; + for (std::size_t row = 0; row < fwd->extend_prefix_lens.size(); ++row) { + ASSERT_LT(row, fwd->request_ids.size()); + prefix_by_request.emplace(fwd->request_ids[row], fwd->extend_prefix_lens[row]); + } + + ASSERT_TRUE(prefix_by_request.contains("hit4")); + EXPECT_EQ(prefix_by_request.at("hit4"), 4); +} + +TEST_F(PagedCacheDecodePublishPlainSwaNoWindowTest, MidflightPublishIsDisabledWithoutSafeWindow) { + Submit(RequestSpec{.request_id = "r1", .tokens = {1, 2}}); + ASSERT_NE(GetForwardOp(PlanOnce()), nullptr); + ASSERT_NE(GetForwardOp(PlanOnce()), nullptr); + + SendForwardDone("r1", {3, 4, 5, 6}); + + Submit(RequestSpec{.request_id = "hit4", .tokens = {1, 2, 3, 4, 5, 6}}); + auto plan = PlanOnce(); + auto* fwd = GetForwardOp(plan); + ASSERT_NE(fwd, nullptr); + ASSERT_GE(fwd->extend_prefix_lens.size(), 1u); + + std::unordered_map prefix_by_request; + for (std::size_t row = 0; row < fwd->extend_prefix_lens.size(); ++row) { + ASSERT_LT(row, fwd->request_ids.size()); + prefix_by_request.emplace(fwd->request_ids[row], fwd->extend_prefix_lens[row]); + } + + ASSERT_TRUE(prefix_by_request.contains("hit4")); + EXPECT_EQ(prefix_by_request.at("hit4"), 0); +} + +TEST_F(PagedCacheDecodePublishHybridHistorySwaTest, MidflightPublishRequiresWindowState) { + Submit(RequestSpec{.request_id = "r1", .tokens = {1, 2, 3, 4, 5, 6, 7, 8}}); + ASSERT_NE(GetForwardOp(PlanOnce()), nullptr); + ASSERT_NE(GetForwardOp(PlanOnce()), nullptr); + + Submit(RequestSpec{.request_id = "hit4", .tokens = {1, 2, 3, 4, 5, 6, 7, 8}}); + auto plan = PlanOnce(); + auto* fwd = GetForwardOp(plan); + ASSERT_NE(fwd, nullptr); + ASSERT_GE(fwd->extend_prefix_lens.size(), 1u); + + std::unordered_map prefix_by_request; + for (std::size_t row = 0; row < fwd->extend_prefix_lens.size(); ++row) { + ASSERT_LT(row, fwd->request_ids.size()); + prefix_by_request.emplace(fwd->request_ids[row], fwd->extend_prefix_lens[row]); + } + + ASSERT_TRUE(prefix_by_request.contains("hit4")); + EXPECT_EQ(prefix_by_request.at("hit4"), 0); +} + +TEST_F(PagedCacheDecodePublishHybridSwaTest, ContinuingDecodePublishesWindowedHybridPrefixes) { + Submit(RequestSpec{.request_id = "r1", .tokens = {1, 2}}); + ASSERT_NE(GetForwardOp(PlanOnce()), nullptr); + + SendForwardDone("r1", {3}); + ASSERT_NE(GetForwardOp(PlanOnce()), nullptr); + + SendForwardDone("r1", {4, 5}); + + Submit(RequestSpec{.request_id = "hit4", .tokens = {1, 2, 3, 4, 5}}); + auto plan = PlanOnce(); + auto* fwd = GetForwardOp(plan); + ASSERT_NE(fwd, nullptr); + ASSERT_GE(fwd->extend_prefix_lens.size(), 1u); + + std::unordered_map prefix_by_request; + for (std::size_t row = 0; row < fwd->extend_prefix_lens.size(); ++row) { + ASSERT_LT(row, fwd->request_ids.size()); + prefix_by_request.emplace(fwd->request_ids[row], fwd->extend_prefix_lens[row]); + } + + ASSERT_TRUE(prefix_by_request.contains("hit4")); + EXPECT_EQ(prefix_by_request.at("hit4"), 4); +} + TEST_F(PagedCacheTerminalMixedSchedulerTest, MixedPrefillDecodePagedTablesCoverScheduledTokens) { std::vector decode_ids; for (int i = 0; i < 5; ++i) {