Skip to content

Commit 8373813

Browse files
committed
fix(scheduler): publish prefix to radix tree during prefill for non-hybrid models
For non-hybrid models, a request's prompt-prefix KV was inserted into the shared device radix tree only at FinishEvent, because the mid-flight InsertHybridCache returned early whenever hybrid_prefix_cache_ was null (that is, for every model other than DeepSeek-V4/Mamba). As a result, a burst of concurrent requests sharing a prefix (RL rollouts with N samples per prompt, or a shared chat-template/system prefix) all prefilled before any of them finished, giving ~0% prefix-cache reuse, versus ~26% for SGLang, which publishes during prefill (cache_unfinished_req). This commit renames InsertHybridCache to InsertPrefixCache and, when there is no hybrid cache, publishes the freshly-computed prefix through the base KV prefix cache (the hybrid path is unchanged: it still publishes via hybrid_cache->GetKVPrefixCache()). The published node is pinned via the request's DeviceNodeRef so that it is not evicted while in use; Mamba checkpoint publication stays hybrid-only. kv_prefix_cache_ is threaded into SchedulePrefillEvent and ScheduleDecodeEvent so that they can publish for non-hybrid models. Signed-off-by: Qingyang Wu <willqywu@gmail.com>
1 parent 4b87a50 commit 8373813

3 files changed

Lines changed: 32 additions & 14 deletions

File tree

tokenspeed-scheduler/csrc/fsm/forward_events.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,17 @@ bool ShouldPublishMambaCheckpoint(tokenspeed::HybridPrefixCache* hybrid_cache, s
102102

103103
namespace tokenspeed::fsm {
104104

105-
void InsertHybridCache(HybridPrefixCache* hybrid_cache,
105+
void InsertPrefixCache(KVPrefixCache* kv_prefix_cache, HybridPrefixCache* hybrid_cache,
106106
const std::vector<std::span<const std::int32_t>>& full_paged_tokens,
107107
std::unique_ptr<DeviceNodeRef>& device_node_ref, LocalKVAllocator* local_kv_allocator,
108108
LocalMambaAllocator* local_mamba_allocator, std::int32_t chunk_begin, std::int32_t chunk_size,
109109
std::int32_t page_size) {
110-
if (hybrid_cache == nullptr) return;
110+
// Hybrid models publish through the hybrid cache's wrapped KV cache (and additionally
111+
// track a Mamba checkpoint); plain models publish through the base KV prefix cache.
112+
// Either way the freshly-computed prefix becomes matchable by concurrent requests
113+
// now, rather than only when this request finishes (FinishEvent).
114+
KVPrefixCache* kv = (hybrid_cache != nullptr) ? &hybrid_cache->GetKVPrefixCache() : kv_prefix_cache;
115+
if (kv == nullptr) return;
111116

112117
std::vector<std::int32_t> prefix_pages = DevicePagesFromRoot(device_node_ref->Node());
113118
std::int32_t new_page_count =
@@ -120,10 +125,10 @@ void InsertHybridCache(HybridPrefixCache* hybrid_cache,
120125
}
121126

122127
OwnedPages pages_to_insert = local_kv_allocator->TakeFirst(new_page_count);
123-
auto insert_result = hybrid_cache->GetKVPrefixCache().Insert<ResourceType::Device>(full_paged_tokens, prefix_pages,
124-
std::move(pages_to_insert));
128+
auto insert_result = kv->Insert<ResourceType::Device>(full_paged_tokens, prefix_pages, std::move(pages_to_insert));
125129

126-
if (local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) {
130+
// Mamba checkpoint publication is hybrid-only.
131+
if (hybrid_cache != nullptr && local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) {
127132
if (ShouldPublishMambaCheckpoint(hybrid_cache, chunk_begin, chunk_size, page_size)) {
128133
hybrid_cache->InsertMamba(insert_result.last_node, local_mamba_allocator->DetachCheckpoint());
129134
} else {
@@ -213,7 +218,7 @@ std::variant<PrefillDone, Prefilling> SchedulePrefillEvent::operator()(Prefillin
213218
if (end_of_window_pages < static_cast<std::int32_t>(paged_tokens.size())) {
214219
paged_tokens.resize(end_of_window_pages);
215220
}
216-
InsertHybridCache(hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(),
221+
InsertPrefixCache(kv_prefix_cache_, hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(),
217222
local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize());
218223
// Allocate KV pages for the new chunk
219224
local_kv_allocator->Acquire(tokens_this_round_);
@@ -263,7 +268,7 @@ Decoding ScheduleDecodeEvent::operator()(PrefillDone&& state) {
263268
if (end_of_window_pages < static_cast<std::int32_t>(paged_tokens.size())) {
264269
paged_tokens.resize(end_of_window_pages);
265270
}
266-
InsertHybridCache(hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(),
271+
InsertPrefixCache(kv_prefix_cache_, hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(),
267272
local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize());
268273
// Allocate fresh checkpoint for decode-phase mamba state tracking
269274
if (hybrid_prefix_cache_ != nullptr && local_mamba_allocator != nullptr) {

tokenspeed-scheduler/csrc/fsm/forward_events.h

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,14 @@ namespace tokenspeed::fsm {
5252
struct PrefetchDone;
5353
struct Prefetching;
5454

55-
void InsertHybridCache(HybridPrefixCache* hybrid_prefix_cache,
55+
// Publish a request's freshly-computed prefix into the device radix tree *mid-flight*
56+
// (during prefill / at the prefill->decode transition) so other in-flight requests that
57+
// share the prefix can reuse it -- instead of only after the request finishes
58+
// (FinishEvent). Works for both the plain KV prefix cache (kv_prefix_cache) and, when
59+
// present, the hybrid cache (which additionally publishes the Mamba checkpoint). The
60+
// published node is pinned via the request's device_node_ref so it is not evicted while
61+
// the request is still using it.
62+
void InsertPrefixCache(KVPrefixCache* kv_prefix_cache, HybridPrefixCache* hybrid_prefix_cache,
5663
const std::vector<std::span<const std::int32_t>>& full_paged_tokens,
5764
std::unique_ptr<DeviceNodeRef>& device_node_ref, LocalKVAllocator* local_kv_allocator,
5865
LocalMambaAllocator* local_mamba_allocator, std::int32_t chunk_begin, std::int32_t chunk_size,
@@ -106,10 +113,11 @@ struct SchedulePrefillFirstChunkEvent : InvalidTransitionHandler<SchedulePrefill
106113
struct SchedulePrefillEvent : InvalidTransitionHandler<SchedulePrefillEvent> {
107114
using InvalidTransitionHandler<SchedulePrefillEvent>::operator();
108115
SchedulePrefillEvent(std::int32_t tokens_this_round, std::int32_t reserve_num_tokens_in_next_schedule_event,
109-
HybridPrefixCache* hybrid_prefix_cache = nullptr)
116+
HybridPrefixCache* hybrid_prefix_cache = nullptr, KVPrefixCache* kv_prefix_cache = nullptr)
110117
: tokens_this_round_(tokens_this_round),
111118
reserve_num_tokens_in_next_schedule_event_(reserve_num_tokens_in_next_schedule_event),
112-
hybrid_prefix_cache_(hybrid_prefix_cache) {}
119+
hybrid_prefix_cache_(hybrid_prefix_cache),
120+
kv_prefix_cache_(kv_prefix_cache) {}
113121

114122
// Returns PrefillDone (last chunk) or Prefilling (more chunks remain).
115123
std::variant<PrefillDone, Prefilling> operator()(Prefilling&& state);
@@ -118,20 +126,25 @@ struct SchedulePrefillEvent : InvalidTransitionHandler<SchedulePrefillEvent> {
118126
std::int32_t tokens_this_round_{};
119127
std::int32_t reserve_num_tokens_in_next_schedule_event_{};
120128
HybridPrefixCache* hybrid_prefix_cache_{};
129+
KVPrefixCache* kv_prefix_cache_{};
121130
};
122131

123132
struct ScheduleDecodeEvent : InvalidTransitionHandler<ScheduleDecodeEvent> {
124133
using InvalidTransitionHandler<ScheduleDecodeEvent>::operator();
125134

126-
ScheduleDecodeEvent(std::int32_t decode_input_tokens, HybridPrefixCache* hybrid_prefix_cache = nullptr)
127-
: decode_input_tokens_(decode_input_tokens), hybrid_prefix_cache_(hybrid_prefix_cache) {}
135+
ScheduleDecodeEvent(std::int32_t decode_input_tokens, HybridPrefixCache* hybrid_prefix_cache = nullptr,
136+
KVPrefixCache* kv_prefix_cache = nullptr)
137+
: decode_input_tokens_(decode_input_tokens),
138+
hybrid_prefix_cache_(hybrid_prefix_cache),
139+
kv_prefix_cache_(kv_prefix_cache) {}
128140

129141
Decoding operator()(PrefillDone&& state);
130142
Decoding operator()(Decoding&& state);
131143

132144
private:
133145
std::int32_t decode_input_tokens_;
134146
HybridPrefixCache* hybrid_prefix_cache_{};
147+
KVPrefixCache* kv_prefix_cache_{};
135148
};
136149

137150
struct ScheduleDecodeFromRetractedEvent : InvalidTransitionHandler<ScheduleDecodeFromRetractedEvent> {

tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ std::optional<fsm::SchedulePrefillEvent> Scheduler::schedulePrefill(
192192
}
193193

194194
return fsm::SchedulePrefillEvent{tokens_this_round, reserve_num_tokens_in_next_schedule_event,
195-
hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr};
195+
hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr, &kv_prefix_cache_};
196196
}
197197

198198
std::optional<fsm::ScheduleDecodeEvent> Scheduler::scheduleDecode(Request* request,
@@ -218,7 +218,7 @@ std::optional<fsm::ScheduleDecodeEvent> Scheduler::scheduleDecode(Request* reque
218218
}
219219

220220
return fsm::ScheduleDecodeEvent{config_.decode_input_tokens,
221-
hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr};
221+
hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr, &kv_prefix_cache_};
222222
}
223223

224224
std::optional<fsm::ScheduleDecodeFromRetractedEvent> Scheduler::scheduleDecodeFromRetracted(

0 commit comments

Comments
 (0)