-
Notifications
You must be signed in to change notification settings - Fork 166
fix(scheduler): publish prefix to radix tree during prefill for non-hybrid models #381
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
d5d3fc1
bee0920
70cd61a
4deb07e
0bc9c7b
c08c1a8
c9809f7
bd9ad99
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -102,12 +102,46 @@ bool ShouldPublishMambaCheckpoint(tokenspeed::HybridPrefixCache* hybrid_cache, s | |
|
|
||
| namespace tokenspeed::fsm { | ||
|
|
||
| void InsertHybridCache(HybridPrefixCache* hybrid_cache, | ||
| void InsertPrefixCache(KVPrefixCache* kv_prefix_cache, HybridPrefixCache* hybrid_cache, | ||
| const std::vector<std::span<const std::int32_t>>& full_paged_tokens, | ||
| std::unique_ptr<DeviceNodeRef>& device_node_ref, LocalKVAllocator* local_kv_allocator, | ||
| LocalMambaAllocator* local_mamba_allocator, std::int32_t chunk_begin, std::int32_t chunk_size, | ||
| std::int32_t page_size, const std::vector<std::int32_t>* prefix_pages_override) { | ||
| if (hybrid_cache == nullptr) return; | ||
| std::int32_t page_size, const std::vector<std::int32_t>* prefix_pages_override, | ||
| bool enable_midflight_publish, std::int32_t max_publish_tokens) { | ||
| // Hybrid models publish through the hybrid cache's wrapped KV cache (and additionally | ||
| // track a Mamba checkpoint); plain (non-hybrid) models publish through the base KV | ||
| // prefix cache, making the freshly-computed prefix matchable by concurrent requests | ||
| // now rather than only at FinishEvent. Callers disable this explicitly for | ||
| // SWA configs that cannot restore the windowed state required for reuse. | ||
| auto detach_checkpoint = [&]() { | ||
| if (local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) { | ||
| local_mamba_allocator->DetachCheckpoint(); | ||
| } | ||
| }; | ||
| if (!enable_midflight_publish) { | ||
| detach_checkpoint(); | ||
| return; | ||
| } | ||
| KVPrefixCache* kv = (hybrid_cache != nullptr) ? &hybrid_cache->GetKVPrefixCache() : kv_prefix_cache; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
For sliding-window models that also create a Useful? React with 👍 / 👎.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This should be fixed now. The current implementation no longer relies on passing a null base KV cache to suppress publication. The scheduler policy is:
I also added regression coverage for the unsafe hybrid-history SWA case and for capped plain-SWA decode-result publication. |
||
| if (kv == nullptr) return; | ||
|
|
||
| std::vector<std::span<const std::int32_t>> capped_paged_tokens; | ||
| const std::vector<std::span<const std::int32_t>>* publish_paged_tokens = &full_paged_tokens; | ||
| if (max_publish_tokens > 0) { | ||
| if (page_size <= 0) { | ||
| detach_checkpoint(); | ||
| return; | ||
| } | ||
| const std::int32_t max_publish_pages = max_publish_tokens / page_size; | ||
| if (max_publish_pages <= 0) { | ||
| detach_checkpoint(); | ||
| return; | ||
| } | ||
| const auto publish_pages = | ||
| std::min<std::int32_t>(max_publish_pages, static_cast<std::int32_t>(full_paged_tokens.size())); | ||
| capped_paged_tokens.assign(full_paged_tokens.begin(), full_paged_tokens.begin() + publish_pages); | ||
| publish_paged_tokens = &capped_paged_tokens; | ||
| } | ||
|
|
||
| std::vector<std::int32_t> computed_prefix_pages; | ||
| const std::vector<std::int32_t>* prefix_pages = prefix_pages_override; | ||
|
|
@@ -116,19 +150,18 @@ void InsertHybridCache(HybridPrefixCache* hybrid_cache, | |
| prefix_pages = &computed_prefix_pages; | ||
| } | ||
| std::int32_t new_page_count = | ||
| static_cast<std::int32_t>(full_paged_tokens.size()) - static_cast<std::int32_t>(prefix_pages->size()); | ||
| static_cast<std::int32_t>(publish_paged_tokens->size()) - static_cast<std::int32_t>(prefix_pages->size()); | ||
| if (new_page_count <= 0) { | ||
| if (local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) { | ||
| local_mamba_allocator->DetachCheckpoint(); | ||
| } | ||
| detach_checkpoint(); | ||
| return; | ||
| } | ||
|
|
||
| OwnedPages pages_to_insert = local_kv_allocator->TakeFirst(new_page_count); | ||
| auto insert_result = hybrid_cache->GetKVPrefixCache().Insert<ResourceType::Device>(full_paged_tokens, *prefix_pages, | ||
| std::move(pages_to_insert)); | ||
| auto insert_result = | ||
| kv->Insert<ResourceType::Device>(*publish_paged_tokens, *prefix_pages, std::move(pages_to_insert)); | ||
|
|
||
| if (local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) { | ||
| // Mamba checkpoint publication is hybrid-only. | ||
| if (hybrid_cache != nullptr && local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) { | ||
| if (ShouldPublishMambaCheckpoint(hybrid_cache, chunk_begin, chunk_size, page_size)) { | ||
| hybrid_cache->InsertMamba(insert_result.last_node, local_mamba_allocator->DetachCheckpoint()); | ||
| } else { | ||
|
|
@@ -218,8 +251,9 @@ std::variant<PrefillDone, Prefilling> SchedulePrefillEvent::operator()(Prefillin | |
| if (end_of_window_pages < static_cast<std::int32_t>(paged_tokens.size())) { | ||
| paged_tokens.resize(end_of_window_pages); | ||
| } | ||
| InsertHybridCache(hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(), | ||
| local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize()); | ||
| InsertPrefixCache(kv_prefix_cache_, hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(), | ||
| local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize(), | ||
| /*prefix_pages_override=*/nullptr, enable_midflight_publish_, max_midflight_publish_tokens_); | ||
| // Allocate KV pages for the new chunk | ||
| local_kv_allocator->Acquire(tokens_this_round_); | ||
|
|
||
|
|
@@ -268,8 +302,9 @@ Decoding ScheduleDecodeEvent::operator()(PrefillDone&& state) { | |
| if (end_of_window_pages < static_cast<std::int32_t>(paged_tokens.size())) { | ||
| paged_tokens.resize(end_of_window_pages); | ||
| } | ||
| InsertHybridCache(hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(), | ||
| local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize()); | ||
| InsertPrefixCache(kv_prefix_cache_, hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(), | ||
| local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize(), | ||
| /*prefix_pages_override=*/nullptr, enable_midflight_publish_, max_midflight_publish_tokens_); | ||
| // Allocate fresh checkpoint for decode-phase mamba state tracking | ||
| if (hybrid_prefix_cache_ != nullptr && local_mamba_allocator != nullptr) { | ||
| if (!local_mamba_allocator->AllocateCheckpoint()) { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fresh evidence is that this new scheduler-side detection only reads
`hf_config.sliding_window`, while the DeepSeek V4 cache spec resolves the same value from either `hf_config` or `hf_config.text_config`. For checkpoints that put `sliding_window` under `text_config`, `has_sliding_window` stays false, and `Scheduler::enableMidflightPrefixPublish()` returns true before checking for a paged-cache adjunct/window state, re-enabling mid-flight publication for hybrid SWA configurations that this change is trying to guard. Please mirror the nested lookup here before deriving `has_sliding_window`. Useful? React with 👍 / 👎.