Skip to content

Commit c774bf0

Browse files
HJSangclaude
andcommitted
feat(logprobs): vLLM-style logprob API (LogprobParams + Logprob output shape)
Replace the scattered return_logprob/logprob_start_len/top_logprobs_num/ token_ids_logprob/return_text_in_logprobs request fields with a dedicated LogprobParams structure (logprobs, prompt_logprobs, logprob_token_ids, return_text), and adopt a vLLM-style Logprob{logprob, rank, decoded_token} output shape (logprobs + prompt_logprobs as dict[int, Logprob] per position). - LogprobParams with num_logprobs/num_prompt_logprobs/requested/verify (rejects -1 full-vocab until supported); SamplingParams stays sampling-only. - C++ scheduler: prompt_logprobs + logprob_token_ids on RequestSpec / ForwardOperation / FlatForwardOperation + pybind; MatchIntent::SkipRead bypasses prefix reuse for prompt-logprob requests (cache hits otherwise truncate prompt logprobs). - Forward path: input/prompt-logprob path (base + top-k + specific token-ids) activated for pure-extend batches, routed through ModelExecutionResult and assembled per-request; output (sampled-token) logprobs driven by LogprobParams. - Hard cut of the old fields. Validated on nv2 B200 vs HF log_softmax: output logprobs=0 (Qwen2-1.5B & Qwen3-0.6B), prompt_logprobs=0 (vLLM convention), prompt top-k (=5), prompt logprob_token_ids, and prefix-cache correctness. Follow-ups: OUTPUT top-k / token-id logprobs (needs CUDA-graph capture), mixed prefill+decode prompt logprobs, scheduler/PD internal renames, gateway (smg proto + SMG param mapping). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Hejian Sang <sanghj0923@gmail.com>
1 parent 1780377 commit c774bf0

27 files changed

Lines changed: 915 additions & 327 deletions

python/tokenspeed/runtime/engine/collector.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,24 @@
2525
from collections.abc import Sequence
2626
from typing import Any
2727

28+
# Streaming merge policy for the logprob meta_info fields.
29+
#
30+
# ``logprobs`` and ``prompt_logprobs`` are each a list[dict[int, Logprob]] (one
31+
# entry per token position) that GROWS as frames arrive, so they must be
32+
# appended rather than overwritten -- hence they are listed here and merged by
33+
# ``_extend_sequence``. That helper is dict-safe: its prefix check
34+
# (``_is_prefix``) compares elements only with ``==``, which both ``dict`` and
35+
# the ``Logprob`` dataclass support, so the entries never need to be hashed or
36+
# ordered.
37+
#
38+
# ``cumulative_logprob`` is deliberately NOT listed: it is a scalar, so it takes
39+
# the default overwrite path (latest frame wins) instead. Under streaming each
40+
# frame recomputes it from a fresh dict, so the emitted value reflects only that
41+
# frame's positions; clients that need the running total should sum the
42+
# per-position ``logprobs`` entries themselves.
2843
_APPEND_META_KEYS = {
29-
"input_token_logprobs",
30-
"output_token_logprobs",
31-
"input_top_logprobs",
32-
"output_top_logprobs",
33-
"input_token_ids_logprobs",
34-
"output_token_ids_logprobs",
44+
"logprobs",
45+
"prompt_logprobs",
3546
}
3647

3748

python/tokenspeed/runtime/engine/generation_output_processor.py

Lines changed: 176 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def __init__(
7474
token_ids_logprob: list[int] | None = None,
7575
multimodal_inputs=None,
7676
prompt_input_ids_unpadded: list[int] | None = None,
77+
return_prompt_logprob: bool = False,
7778
) -> None:
7879
# --- Extracted from recv_req (immutable) ---
7980
self.prompt_input_ids: list[int] = prompt_input_ids
@@ -91,6 +92,7 @@ def __init__(
9192
self.return_logprob = return_logprob
9293
self.top_logprobs_num = top_logprobs_num
9394
self.token_ids_logprob = token_ids_logprob
95+
self.return_prompt_logprob = return_prompt_logprob
9496

9597
# --- generation state (updated with forward step) ---
9698
self.output_ids: list[int] = []
@@ -107,6 +109,30 @@ def __init__(
107109
self.output_token_logprobs_idx: list[int] | None = (
108110
[] if return_logprob else None
109111
)
112+
# Prompt/input-token logprobs, accumulated across (possibly chunked)
113+
# prefill ops. None unless the request asked for prompt_logprobs.
114+
self.input_token_logprobs_val: list[float] | None = (
115+
[] if return_prompt_logprob else None
116+
)
117+
self.input_token_logprobs_idx: list[int] | None = (
118+
[] if return_prompt_logprob else None
119+
)
120+
# Per-position top-k prompt logprobs (prompt_logprobs>0). Each entry is
121+
# a [k] list parallel to the base prompt-logprob accumulators above.
122+
self.input_top_logprobs_val: list | None = [] if return_prompt_logprob else None
123+
self.input_top_logprobs_idx: list | None = [] if return_prompt_logprob else None
124+
# Per-position token-id prompt logprobs (logprob_token_ids requested).
125+
# Each entry is a [num_token_ids] list parallel to the base prompt-logprob
126+
# accumulators above.
127+
self.input_token_ids_logprobs_val: list | None = (
128+
[] if return_prompt_logprob else None
129+
)
130+
self.input_token_ids_logprobs_idx: list | None = (
131+
[] if return_prompt_logprob else None
132+
)
133+
# Number of prompt-logprob entries already shipped to the detokenizer,
134+
# so multi-chunk / multi-step streaming only sends the un-shipped tail.
135+
self.sent_prompt_logprob_offset: int = 0
110136

111137
# --- Streaming bookkeeping (internal) ---
112138
self._surr_offset: int | None = None
@@ -144,17 +170,27 @@ def from_recv_req(
144170
tokenizer,
145171
eos_token_ids: list[int],
146172
) -> RequestState:
173+
# Translate the LogprobParams into the scheduler-internal
174+
# scalar attrs the rest of the pipeline already understands. Output-token
175+
# logprob collection is gated on ``return_logprob`` below; the prompt
176+
# (off-policy) path is wired separately in full Phase B.
177+
lp = getattr(recv_req, "logprob_params", None)
178+
return_logprob = lp is not None and lp.num_logprobs() is not None
179+
top_logprobs_num = lp.num_logprobs() or 0 if return_logprob else 0
180+
token_ids_logprob = lp.logprob_token_ids if lp is not None else None
181+
return_prompt_logprob = lp is not None and lp.num_prompt_logprobs() is not None
147182
return cls(
148183
prompt_input_ids=recv_req.input_ids,
149184
sampling_params=recv_req.sampling_params,
150185
stream=recv_req.stream,
151186
tokenizer=tokenizer,
152187
eos_token_ids=eos_token_ids,
153-
return_logprob=getattr(recv_req, "return_logprob", False),
154-
top_logprobs_num=getattr(recv_req, "top_logprobs_num", 0),
155-
token_ids_logprob=getattr(recv_req, "token_ids_logprob", None),
188+
return_logprob=return_logprob,
189+
top_logprobs_num=top_logprobs_num,
190+
token_ids_logprob=token_ids_logprob,
156191
multimodal_inputs=getattr(recv_req, "multimodal_inputs", None),
157192
prompt_input_ids_unpadded=getattr(recv_req, "input_ids_unpadded", None),
193+
return_prompt_logprob=return_prompt_logprob,
158194
)
159195

160196
@property
@@ -521,6 +557,31 @@ def post_process_forward_op(
521557
if model_execution_results.output_logprobs is not None
522558
else None
523559
)
560+
_input_lp = getattr(model_execution_results, "input_token_logprobs", None)
561+
input_logprobs_list = _input_lp.tolist() if _input_lp is not None else None
562+
_input_lp_ids = getattr(
563+
model_execution_results, "input_logprob_token_ids", None
564+
)
565+
input_logprob_ids_list = (
566+
_input_lp_ids.tolist() if _input_lp_ids is not None else None
567+
)
568+
# Top-k prompt logprobs are already CPU Python lists partitioned per
569+
# extend request (index by extend request i, NOT the flat ilp_pt).
570+
input_top_val_list = getattr(
571+
model_execution_results, "input_top_logprobs_val", None
572+
)
573+
input_top_idx_list = getattr(
574+
model_execution_results, "input_top_logprobs_idx", None
575+
)
576+
# Token-id prompt logprobs are likewise CPU Python lists partitioned per
577+
# extend request (index by extend request i, NOT the flat ilp_pt).
578+
input_tid_val_list = getattr(
579+
model_execution_results, "input_token_ids_logprobs_val", None
580+
)
581+
input_tid_idx_list = getattr(
582+
model_execution_results, "input_token_ids_logprobs_idx", None
583+
)
584+
ilp_pt = 0
524585
pt = 0
525586
for i, rid in enumerate(forward_op.request_ids):
526587
output_length = model_execution_results.output_lengths[i].item()
@@ -538,12 +599,62 @@ def post_process_forward_op(
538599
else:
539600
pt += output_length
540601

602+
# Prompt/input-token logprobs (pure-extend prompt_logprobs path).
603+
# The flat array carries one entry per scored prefill position for
604+
# every i < num_extends (the ctx activates input logprobs for the
605+
# whole pure-extend batch), so advance ilp_pt unconditionally here —
606+
# before the rid_to_state / prefill-finished guards — to keep the
607+
# pointer aligned with the flat array even for delayed/chunked rows.
608+
seg_val = None
609+
seg_idx = None
610+
pl_req = -1
611+
if input_logprobs_list is not None and i < num_extends:
612+
plen = int(forward_op.input_lengths[i])
613+
pl_req = int(forward_op.prompt_logprobs[i])
614+
seg_val = input_logprobs_list[ilp_pt : ilp_pt + plen]
615+
seg_idx = (
616+
input_logprob_ids_list[ilp_pt : ilp_pt + plen]
617+
if input_logprob_ids_list is not None
618+
else []
619+
)
620+
ilp_pt += plen
621+
541622
if rid not in self.rid_to_state:
542623
# means it's delayed token, do not process
543624
continue
544625

545626
request_state: RequestState = self.rid_to_state[rid]
546627

628+
if (
629+
seg_val is not None
630+
and pl_req >= 0
631+
and request_state.input_token_logprobs_val is not None
632+
):
633+
request_state.input_token_logprobs_val.extend(seg_val)
634+
request_state.input_token_logprobs_idx.extend(seg_idx)
635+
if (
636+
input_top_val_list is not None
637+
and i < len(input_top_val_list)
638+
and request_state.input_top_logprobs_val is not None
639+
):
640+
request_state.input_top_logprobs_val.extend(input_top_val_list[i])
641+
if input_top_idx_list is not None and i < len(input_top_idx_list):
642+
request_state.input_top_logprobs_idx.extend(
643+
input_top_idx_list[i]
644+
)
645+
if (
646+
input_tid_val_list is not None
647+
and i < len(input_tid_val_list)
648+
and request_state.input_token_ids_logprobs_val is not None
649+
):
650+
request_state.input_token_ids_logprobs_val.extend(
651+
input_tid_val_list[i]
652+
)
653+
if input_tid_idx_list is not None and i < len(input_tid_idx_list):
654+
request_state.input_token_ids_logprobs_idx.extend(
655+
input_tid_idx_list[i]
656+
)
657+
547658
# Do not output chunking result
548659
if not request_state.prefill_finished:
549660
continue
@@ -719,6 +830,12 @@ def stream_output(
719830
output_extra_infos: list[dict] = []
720831
output_token_logprobs_val: list[list[float]] = []
721832
output_token_logprobs_idx: list[list[int]] = []
833+
input_token_logprobs_val: list[list[float]] = []
834+
input_token_logprobs_idx: list[list[int]] = []
835+
input_top_logprobs_val: list = []
836+
input_top_logprobs_idx: list = []
837+
input_token_ids_logprobs_val: list = []
838+
input_token_ids_logprobs_idx: list = []
722839

723840
for i, rs in enumerate(output_states):
724841
# For finished requests, always output (unless already output)
@@ -798,6 +915,49 @@ def stream_output(
798915
output_token_logprobs_val.append([])
799916
output_token_logprobs_idx.append([])
800917

918+
# Prompt/input-token logprobs: ship the un-shipped tail of the
919+
# accumulated prompt logprobs (accumulated across chunked prefill in
920+
# post_process_forward_op). Tracked with sent_prompt_logprob_offset
921+
# so later decode-step streams don't resend the prompt logprobs.
922+
if rs.return_prompt_logprob and rs.input_token_logprobs_val is not None:
923+
off = rs.sent_prompt_logprob_offset
924+
input_token_logprobs_val.append(rs.input_token_logprobs_val[off:])
925+
input_token_logprobs_idx.append(rs.input_token_logprobs_idx[off:])
926+
# Top-k prompt logprobs are parallel (one [k] list per position),
927+
# so ship the same un-shipped tail using the same offset.
928+
if rs.input_top_logprobs_val is not None:
929+
input_top_logprobs_val.append(rs.input_top_logprobs_val[off:])
930+
input_top_logprobs_idx.append(
931+
rs.input_top_logprobs_idx[off:]
932+
if rs.input_top_logprobs_idx is not None
933+
else []
934+
)
935+
else:
936+
input_top_logprobs_val.append([])
937+
input_top_logprobs_idx.append([])
938+
# Token-id prompt logprobs are parallel (one [num_token_ids] list
939+
# per position), so ship the same un-shipped tail using off.
940+
if rs.input_token_ids_logprobs_val is not None:
941+
input_token_ids_logprobs_val.append(
942+
rs.input_token_ids_logprobs_val[off:]
943+
)
944+
input_token_ids_logprobs_idx.append(
945+
rs.input_token_ids_logprobs_idx[off:]
946+
if rs.input_token_ids_logprobs_idx is not None
947+
else []
948+
)
949+
else:
950+
input_token_ids_logprobs_val.append([])
951+
input_token_ids_logprobs_idx.append([])
952+
rs.sent_prompt_logprob_offset = len(rs.input_token_logprobs_val)
953+
else:
954+
input_token_logprobs_val.append([])
955+
input_token_logprobs_idx.append([])
956+
input_top_logprobs_val.append([])
957+
input_top_logprobs_idx.append([])
958+
input_token_ids_logprobs_val.append([])
959+
input_token_ids_logprobs_idx.append([])
960+
801961
# Don't send empty batch to detokenizer
802962
if len(rids_to_send) == 0:
803963
return
@@ -817,16 +977,23 @@ def stream_output(
817977
completion_tokens=completion_tokens,
818978
cached_tokens=cached_tokens,
819979
spec_verify_ct=spec_verify_ct,
820-
input_token_logprobs_val=[],
821-
input_token_logprobs_idx=[],
980+
input_token_logprobs_val=input_token_logprobs_val,
981+
input_token_logprobs_idx=input_token_logprobs_idx,
822982
output_token_logprobs_val=output_token_logprobs_val,
823983
output_token_logprobs_idx=output_token_logprobs_idx,
824-
input_top_logprobs_val=[],
825-
input_top_logprobs_idx=[],
984+
input_top_logprobs_val=input_top_logprobs_val,
985+
input_top_logprobs_idx=input_top_logprobs_idx,
986+
# TODO(logprobs): OUTPUT top-k / token-id logprobs (logprobs=N>0 or
987+
# logprob_token_ids on output) are not populated yet. Output logprobs
988+
# flow through the captured CUDA-graph decode path, so top-k must be
989+
# computed in the sampler and captured into the graph output buffers
990+
# (unlike the prompt path, which is eager/prefill). Until then,
991+
# logprobs=N>0 returns only the sampled token's logprob per position
992+
# (its value is correct; the top-N alternatives are absent).
826993
output_top_logprobs_val=[],
827994
output_top_logprobs_idx=[],
828-
input_token_ids_logprobs_val=[],
829-
input_token_ids_logprobs_idx=[],
995+
input_token_ids_logprobs_val=input_token_ids_logprobs_val,
996+
input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
830997
output_token_ids_logprobs_val=[],
831998
output_token_ids_logprobs_idx=[],
832999
output_hidden_states=[],

python/tokenspeed/runtime/engine/input_processor.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,12 @@ async def tokenize_one_request(
173173
input_ids = pad_input_tokens(list(input_ids), multimodal_inputs)
174174

175175
if self.engine.is_generation:
176-
return_logprob = obj.return_logprob
177-
logprob_start_len = obj.logprob_start_len
178-
top_logprobs_num = obj.top_logprobs_num
179-
token_ids_logprob = obj.token_ids_logprob
176+
logprob_params = obj.logprob_params
177+
if logprob_params is not None:
178+
logprob_params.verify(
179+
vocab_size=self.engine.model_config.vocab_size,
180+
max_logprobs=self.engine.model_config.vocab_size,
181+
)
180182
session_params = (
181183
SessionParams(**obj.session_params) if obj.session_params else None
182184
)
@@ -218,10 +220,7 @@ async def tokenize_one_request(
218220
input_text,
219221
input_ids,
220222
sampling_params,
221-
return_logprob,
222-
logprob_start_len,
223-
top_logprobs_num,
224-
token_ids_logprob,
223+
logprob_params,
225224
obj.stream,
226225
bootstrap_host=obj.bootstrap_host,
227226
bootstrap_port=obj.bootstrap_port,

0 commit comments

Comments
 (0)