Skip to content
15 changes: 15 additions & 0 deletions python/tokenspeed/runtime/engine/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,19 @@ def __init__(
has_mamba = getattr(self.model_config, "mambaish_config", None) is not None or (
text_config is not None and hasattr(text_config, "mamba2_cache_params")
)
# Mirror ModelRunner's SWA detection (hf_config.sliding_window). Plain
# SWA mid-flight publish is capped to this window; hybrid paged-cache
# models use their windowed adjunct snapshots. gpt-oss stores an
# inclusive HF window and converts it to TokenSpeed's exclusive
# attention window inside the model.
sliding_window = getattr(hf_config, "sliding_window", None)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Detect nested sliding windows before enabling publish

Fresh evidence is that this new scheduler-side detection only reads hf_config.sliding_window, while the DeepSeek V4 cache spec resolves the same value from either hf_config or hf_config.text_config. For checkpoints that put sliding_window under text_config, has_sliding_window stays false, and Scheduler::enableMidflightPrefixPublish() returns true before checking for a paged-cache adjunct/window state, re-enabling mid-flight publication for hybrid SWA configurations that this change is trying to guard. Please mirror the nested lookup here before deriving has_sliding_window.

Useful? React with 👍 / 👎.

if (
getattr(hf_config, "model_type", None) == "gpt_oss"
and sliding_window is not None
):
sliding_window = max(0, int(sliding_window) - 1)
has_sliding_window = sliding_window is not None
sliding_window_size = int(sliding_window) if sliding_window is not None else 0

model_executor_config = ModelExecutorConfig.from_server_args(
server_args=server_args,
Expand Down Expand Up @@ -327,6 +340,8 @@ def __init__(
else 1
),
disable_prefix_cache=not server_args.enable_prefix_caching,
has_sliding_window=has_sliding_window,
sliding_window_size=sliding_window_size,
enable_mamba=has_mamba,
mamba_cache_chunk_size=server_args.mamba_cache_chunk_size,
mamba_pool_total_chunks=mamba_pool_total_chunks,
Expand Down
4 changes: 4 additions & 0 deletions python/tokenspeed/runtime/engine/scheduler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def make_config(
enable_kv_cache_events: bool = False,
decode_input_tokens: int = 1,
disable_prefix_cache: bool = False,
has_sliding_window: bool = False,
sliding_window_size: int = 0,
enable_mamba: bool = False,
mamba_cache_chunk_size: int = 64,
mamba_pool_total_chunks: int = 0,
Expand Down Expand Up @@ -93,6 +95,8 @@ def make_config(
cfg.num_device_pages = num_device_pages
cfg.decode_input_tokens = decode_input_tokens
cfg.disable_prefix_cache = disable_prefix_cache
cfg.has_sliding_window = has_sliding_window
cfg.sliding_window_size = sliding_window_size
cfg.disable_l2_cache = disable_l2_cache

cfg.enable_mamba = enable_mamba
Expand Down
12 changes: 9 additions & 3 deletions python/tokenspeed/runtime/layers/attention/backends/mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,13 @@ def forward_extend(

metadata = self.forward_prefill_metadata
if metadata.max_extend_prefix_len > 0:
if self.mha_extend_mode == "ragged":
sinks = kwargs.get("sinks")
# gpt-oss uses attention sinks. The ragged split path computes the
# cached-prefix and new-token attention states separately before
# merging; keep sink-attention requests on the unified KV-cache path
# so prefix-cache hits preserve the same numerical path as ordinary
# extend.
if self.mha_extend_mode == "ragged" and sinks is None:
return self._forward_extend_split(
q,
k,
Expand All @@ -355,7 +361,7 @@ def forward_extend(
token_to_kv_pool,
metadata,
save_kv_cache,
kwargs.get("sinks"),
sinks,
)
else:
return self._forward_extend(
Expand All @@ -367,7 +373,7 @@ def forward_extend(
token_to_kv_pool,
metadata,
save_kv_cache,
kwargs.get("sinks"),
sinks,
)
return self._forward_prefill(
q,
Expand Down
2 changes: 2 additions & 0 deletions tokenspeed-scheduler/bindings/python_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ NB_MODULE(tokenspeed_scheduler_ext, m) {
.def_rw("enable_kv_cache_events", &tokenspeed::SchedulerConfig::enable_kv_cache_events)
.def_rw("enable_mixed_prefill_decode", &tokenspeed::SchedulerConfig::enable_mixed_prefill_decode)
.def_rw("disable_prefix_cache", &tokenspeed::SchedulerConfig::disable_prefix_cache)
.def_rw("has_sliding_window", &tokenspeed::SchedulerConfig::has_sliding_window)
.def_rw("sliding_window_size", &tokenspeed::SchedulerConfig::sliding_window_size)
.def_rw("enable_mamba", &tokenspeed::SchedulerConfig::enable_mamba)
.def_rw("mamba_cache_chunk_size", &tokenspeed::SchedulerConfig::mamba_cache_chunk_size)
.def_rw("mamba_pool_total_chunks", &tokenspeed::SchedulerConfig::mamba_pool_total_chunks)
Expand Down
63 changes: 49 additions & 14 deletions tokenspeed-scheduler/csrc/fsm/forward_events.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,46 @@ bool ShouldPublishMambaCheckpoint(tokenspeed::HybridPrefixCache* hybrid_cache, s

namespace tokenspeed::fsm {

void InsertHybridCache(HybridPrefixCache* hybrid_cache,
void InsertPrefixCache(KVPrefixCache* kv_prefix_cache, HybridPrefixCache* hybrid_cache,
const std::vector<std::span<const std::int32_t>>& full_paged_tokens,
std::unique_ptr<DeviceNodeRef>& device_node_ref, LocalKVAllocator* local_kv_allocator,
LocalMambaAllocator* local_mamba_allocator, std::int32_t chunk_begin, std::int32_t chunk_size,
std::int32_t page_size, const std::vector<std::int32_t>* prefix_pages_override) {
if (hybrid_cache == nullptr) return;
std::int32_t page_size, const std::vector<std::int32_t>* prefix_pages_override,
bool enable_midflight_publish, std::int32_t max_publish_tokens) {
// Hybrid models publish through the hybrid cache's wrapped KV cache (and additionally
// track a Mamba checkpoint); plain (non-hybrid) models publish through the base KV
// prefix cache, making the freshly-computed prefix matchable by concurrent requests
// now rather than only at FinishEvent. Callers disable this explicitly for
// SWA configs that cannot restore the windowed state required for reuse.
auto detach_checkpoint = [&]() {
if (local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) {
local_mamba_allocator->DetachCheckpoint();
}
};
if (!enable_midflight_publish) {
detach_checkpoint();
return;
}
KVPrefixCache* kv = (hybrid_cache != nullptr) ? &hybrid_cache->GetKVPrefixCache() : kv_prefix_cache;

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Honor the SWA guard when a hybrid cache exists

For sliding-window models that also create a HybridPrefixCache (the scheduler does this whenever paged-cache groups or a prefix-cache adjunct are configured, e.g. the DeepSeek V4 SWA/state groups), schedulePrefill/scheduleDecode pass nullptr as the KV cache, but this line still selects hybrid_cache->GetKVPrefixCache() and publishes mid-flight. In the scheduler paths I checked, that leaves the newly documented corrupt SWA prefix-reuse path enabled for hybrid/SWA models; the helper needs an explicit skip signal rather than relying on a null base KV pointer.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be fixed now. The current implementation no longer relies on passing a null base KV cache to suppress publication. InsertPrefixCache() now takes an explicit enable_midflight_publish flag and returns before selecting hybrid_cache->GetKVPrefixCache() when the SWA path is not safe.

The scheduler policy is:

  • non-SWA: publish mid-flight as before
  • plain SWA / gpt-oss: publish only up to the configured sliding-window cap
  • hybrid history-only SWA: do not publish mid-flight
  • hybrid paged-cache SWA / DeepSeek V4: publish only when the paged-cache adjunct has sliding-window State groups, with HybridPrefixCache::Match() still enforcing snapshot/window correctness

I also added regression coverage for the unsafe hybrid-history SWA case and for capped plain-SWA decode-result publication.

if (kv == nullptr) return;

std::vector<std::span<const std::int32_t>> capped_paged_tokens;
const std::vector<std::span<const std::int32_t>>* publish_paged_tokens = &full_paged_tokens;
if (max_publish_tokens > 0) {
if (page_size <= 0) {
detach_checkpoint();
return;
}
const std::int32_t max_publish_pages = max_publish_tokens / page_size;
if (max_publish_pages <= 0) {
detach_checkpoint();
return;
}
const auto publish_pages =
std::min<std::int32_t>(max_publish_pages, static_cast<std::int32_t>(full_paged_tokens.size()));
capped_paged_tokens.assign(full_paged_tokens.begin(), full_paged_tokens.begin() + publish_pages);
publish_paged_tokens = &capped_paged_tokens;
}

std::vector<std::int32_t> computed_prefix_pages;
const std::vector<std::int32_t>* prefix_pages = prefix_pages_override;
Expand All @@ -116,19 +150,18 @@ void InsertHybridCache(HybridPrefixCache* hybrid_cache,
prefix_pages = &computed_prefix_pages;
}
std::int32_t new_page_count =
static_cast<std::int32_t>(full_paged_tokens.size()) - static_cast<std::int32_t>(prefix_pages->size());
static_cast<std::int32_t>(publish_paged_tokens->size()) - static_cast<std::int32_t>(prefix_pages->size());
if (new_page_count <= 0) {
if (local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) {
local_mamba_allocator->DetachCheckpoint();
}
detach_checkpoint();
return;
}

OwnedPages pages_to_insert = local_kv_allocator->TakeFirst(new_page_count);
auto insert_result = hybrid_cache->GetKVPrefixCache().Insert<ResourceType::Device>(full_paged_tokens, *prefix_pages,
std::move(pages_to_insert));
auto insert_result =
kv->Insert<ResourceType::Device>(*publish_paged_tokens, *prefix_pages, std::move(pages_to_insert));

if (local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) {
// Mamba checkpoint publication is hybrid-only.
if (hybrid_cache != nullptr && local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) {
if (ShouldPublishMambaCheckpoint(hybrid_cache, chunk_begin, chunk_size, page_size)) {
hybrid_cache->InsertMamba(insert_result.last_node, local_mamba_allocator->DetachCheckpoint());
} else {
Expand Down Expand Up @@ -218,8 +251,9 @@ std::variant<PrefillDone, Prefilling> SchedulePrefillEvent::operator()(Prefillin
if (end_of_window_pages < static_cast<std::int32_t>(paged_tokens.size())) {
paged_tokens.resize(end_of_window_pages);
}
InsertHybridCache(hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(),
local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize());
InsertPrefixCache(kv_prefix_cache_, hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(),
local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize(),
/*prefix_pages_override=*/nullptr, enable_midflight_publish_, max_midflight_publish_tokens_);
// Allocate KV pages for the new chunk
local_kv_allocator->Acquire(tokens_this_round_);

Expand Down Expand Up @@ -268,8 +302,9 @@ Decoding ScheduleDecodeEvent::operator()(PrefillDone&& state) {
if (end_of_window_pages < static_cast<std::int32_t>(paged_tokens.size())) {
paged_tokens.resize(end_of_window_pages);
}
InsertHybridCache(hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(),
local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize());
InsertPrefixCache(kv_prefix_cache_, hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(),
local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize(),
/*prefix_pages_override=*/nullptr, enable_midflight_publish_, max_midflight_publish_tokens_);
// Allocate fresh checkpoint for decode-phase mamba state tracking
if (hybrid_prefix_cache_ != nullptr && local_mamba_allocator != nullptr) {
if (!local_mamba_allocator->AllocateCheckpoint()) {
Expand Down
77 changes: 59 additions & 18 deletions tokenspeed-scheduler/csrc/fsm/forward_events.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,19 @@ namespace tokenspeed::fsm {
struct PrefetchDone;
struct Prefetching;

void InsertHybridCache(HybridPrefixCache* hybrid_prefix_cache,
// Publish a request's freshly-computed prefix into the device radix tree *mid-flight*
// (during prefill, at the prefill->decode transition, or after accepted decode
// tokens) so other in-flight requests that share the prefix can reuse it -- instead
// of only after the request finishes (FinishEvent). Works for both the plain KV prefix
// cache (kv_prefix_cache) and, when present, the hybrid cache (which additionally
// publishes the Mamba checkpoint). The published node is pinned via the request's
// device_node_ref so it is not evicted while the request is still using it.
void InsertPrefixCache(KVPrefixCache* kv_prefix_cache, HybridPrefixCache* hybrid_prefix_cache,
const std::vector<std::span<const std::int32_t>>& full_paged_tokens,
std::unique_ptr<DeviceNodeRef>& device_node_ref, LocalKVAllocator* local_kv_allocator,
LocalMambaAllocator* local_mamba_allocator, std::int32_t chunk_begin, std::int32_t chunk_size,
std::int32_t page_size, const std::vector<std::int32_t>* prefix_pages_override = nullptr);
std::int32_t page_size, const std::vector<std::int32_t>* prefix_pages_override = nullptr,
bool enable_midflight_publish = true, std::int32_t max_publish_tokens = 0);

struct SchedulePrefillFirstChunkEvent : InvalidTransitionHandler<SchedulePrefillFirstChunkEvent> {
using InvalidTransitionHandler<SchedulePrefillFirstChunkEvent>::operator();
Expand Down Expand Up @@ -107,10 +115,14 @@ struct SchedulePrefillFirstChunkEvent : InvalidTransitionHandler<SchedulePrefill
struct SchedulePrefillEvent : InvalidTransitionHandler<SchedulePrefillEvent> {
using InvalidTransitionHandler<SchedulePrefillEvent>::operator();
SchedulePrefillEvent(std::int32_t tokens_this_round, std::int32_t reserve_num_tokens_in_next_schedule_event,
HybridPrefixCache* hybrid_prefix_cache = nullptr)
HybridPrefixCache* hybrid_prefix_cache = nullptr, KVPrefixCache* kv_prefix_cache = nullptr,
bool enable_midflight_publish = true, std::int32_t max_midflight_publish_tokens = 0)
: tokens_this_round_(tokens_this_round),
reserve_num_tokens_in_next_schedule_event_(reserve_num_tokens_in_next_schedule_event),
hybrid_prefix_cache_(hybrid_prefix_cache) {}
hybrid_prefix_cache_(hybrid_prefix_cache),
kv_prefix_cache_(kv_prefix_cache),
enable_midflight_publish_(enable_midflight_publish),
max_midflight_publish_tokens_(max_midflight_publish_tokens) {}

// Returns PrefillDone (last chunk) or Prefilling (more chunks remain).
std::variant<PrefillDone, Prefilling> operator()(Prefilling&& state);
Expand All @@ -119,20 +131,32 @@ struct SchedulePrefillEvent : InvalidTransitionHandler<SchedulePrefillEvent> {
std::int32_t tokens_this_round_{};
std::int32_t reserve_num_tokens_in_next_schedule_event_{};
HybridPrefixCache* hybrid_prefix_cache_{};
KVPrefixCache* kv_prefix_cache_{};
bool enable_midflight_publish_{true};
std::int32_t max_midflight_publish_tokens_{};
};

struct ScheduleDecodeEvent : InvalidTransitionHandler<ScheduleDecodeEvent> {
using InvalidTransitionHandler<ScheduleDecodeEvent>::operator();

ScheduleDecodeEvent(std::int32_t decode_input_tokens, HybridPrefixCache* hybrid_prefix_cache = nullptr)
: decode_input_tokens_(decode_input_tokens), hybrid_prefix_cache_(hybrid_prefix_cache) {}
ScheduleDecodeEvent(std::int32_t decode_input_tokens, HybridPrefixCache* hybrid_prefix_cache = nullptr,
KVPrefixCache* kv_prefix_cache = nullptr, bool enable_midflight_publish = true,
std::int32_t max_midflight_publish_tokens = 0)
: decode_input_tokens_(decode_input_tokens),
hybrid_prefix_cache_(hybrid_prefix_cache),
kv_prefix_cache_(kv_prefix_cache),
enable_midflight_publish_(enable_midflight_publish),
max_midflight_publish_tokens_(max_midflight_publish_tokens) {}

Decoding operator()(PrefillDone&& state);
Decoding operator()(Decoding&& state);

private:
std::int32_t decode_input_tokens_;
HybridPrefixCache* hybrid_prefix_cache_{};
KVPrefixCache* kv_prefix_cache_{};
bool enable_midflight_publish_{true};
std::int32_t max_midflight_publish_tokens_{};
};

struct ScheduleDecodeFromRetractedEvent : InvalidTransitionHandler<ScheduleDecodeFromRetractedEvent> {
Expand Down Expand Up @@ -298,10 +322,14 @@ struct ExtendResultEvent : InvalidTransitionHandler<ExtendResultEvent> {
ExtendResultEvent() = delete;

ExtendResultEvent(std::string request_id, std::vector<std::int32_t> result_tokens,
HybridPrefixCache* hybrid_prefix_cache = nullptr)
KVPrefixCache* kv_prefix_cache = nullptr, HybridPrefixCache* hybrid_prefix_cache = nullptr,
bool enable_midflight_publish = true, std::int32_t max_midflight_publish_tokens = 0)
: request_id_(std::move(request_id)),
result_tokens_(std::move(result_tokens)),
hybrid_prefix_cache_(hybrid_prefix_cache) {}
kv_prefix_cache_(kv_prefix_cache),
hybrid_prefix_cache_(hybrid_prefix_cache),
enable_midflight_publish_(enable_midflight_publish),
max_midflight_publish_tokens_(max_midflight_publish_tokens) {}

public:
template <typename S>
Expand All @@ -318,7 +346,7 @@ struct ExtendResultEvent : InvalidTransitionHandler<ExtendResultEvent> {
const std::int32_t page_size = state.GetPageSize();
const std::int32_t reserve = state.GetReserveNumTokensInNextScheduleEvent();

if (hybrid_prefix_cache_ == nullptr) {
if (kv_prefix_cache_ == nullptr && hybrid_prefix_cache_ == nullptr) {
return std::move(state);
}

Expand All @@ -331,28 +359,38 @@ struct ExtendResultEvent : InvalidTransitionHandler<ExtendResultEvent> {
const std::int32_t new_publishable_pages = publishable_pages(accepted_token_size);

if (new_publishable_pages <= old_publishable_pages) {
hybrid_prefix_cache_->RewindRequest(request_id_, accepted_token_size);
if (hybrid_prefix_cache_ != nullptr) {
hybrid_prefix_cache_->RewindRequest(request_id_, accepted_token_size);
}
return std::move(state);
}

const std::int32_t chunk_begin = accepted_token_size - static_cast<std::int32_t>(result_tokens_.size());
auto full_paged_tokens = state.GetFullPagedTokens(/*except_last=*/true);
std::vector<std::int32_t> prefix_pages = DevicePagesFromRoot(state.GetDeviceNode());
const std::int32_t new_page_count =
static_cast<std::int32_t>(full_paged_tokens.size()) - static_cast<std::int32_t>(prefix_pages.size());
std::int32_t publish_page_count = static_cast<std::int32_t>(full_paged_tokens.size());
if (max_midflight_publish_tokens_ > 0 && page_size > 0) {
publish_page_count = std::min(publish_page_count, max_midflight_publish_tokens_ / page_size);
}
const std::int32_t new_page_count = publish_page_count - static_cast<std::int32_t>(prefix_pages.size());

auto local_kv_allocator = std::move(state).TakeLocalKVAllocator();
auto local_mamba_allocator = std::move(state).TakeLocalMambaAllocator();
auto device_node_ref = std::move(state).TakeDeviceNodeRef();
auto host_node_ref = std::move(state).TakeHostNodeRef();

if (new_page_count > 0 && local_kv_allocator->PageCount() >= new_page_count) {
InsertHybridCache(hybrid_prefix_cache_, full_paged_tokens, device_node_ref, local_kv_allocator.get(),
local_mamba_allocator.get(), chunk_begin,
static_cast<std::int32_t>(result_tokens_.size()), page_size, &prefix_pages);
hybrid_prefix_cache_->CommitChunk(request_id_, device_node_ref->Node());
if (enable_midflight_publish_ && new_page_count > 0 && local_kv_allocator->PageCount() >= new_page_count) {
InsertPrefixCache(kv_prefix_cache_, hybrid_prefix_cache_, full_paged_tokens, device_node_ref,
local_kv_allocator.get(), local_mamba_allocator.get(), chunk_begin,
static_cast<std::int32_t>(result_tokens_.size()), page_size, &prefix_pages,
enable_midflight_publish_, max_midflight_publish_tokens_);
if (hybrid_prefix_cache_ != nullptr) {
hybrid_prefix_cache_->CommitChunk(request_id_, device_node_ref->Node());
}
}
if (hybrid_prefix_cache_ != nullptr) {
hybrid_prefix_cache_->RewindRequest(request_id_, accepted_token_size);
}
hybrid_prefix_cache_->RewindRequest(request_id_, accepted_token_size);

return Decoding{token_container,
page_size,
Expand All @@ -374,7 +412,10 @@ struct ExtendResultEvent : InvalidTransitionHandler<ExtendResultEvent> {
private:
std::string request_id_;
std::vector<std::int32_t> result_tokens_;
KVPrefixCache* kv_prefix_cache_{};
HybridPrefixCache* hybrid_prefix_cache_{};
bool enable_midflight_publish_{true};
std::int32_t max_midflight_publish_tokens_{};
};

} // namespace tokenspeed::fsm
Loading
Loading