Skip to content

Commit 740d514

Browse files
HJSang and claude committed
feat(logprobs): vLLM-style output logprobs (LogprobParams), spec-decode support
Add a dedicated LogprobParams request struct and a vLLM-style Logprob output shape (per-position {token_id: Logprob}), kept separate from SamplingParams. Scope to OUTPUT logprobs only for now: LogprobParams.verify() is the single gate and loudly rejects the not-yet-correct surface — prompt_logprobs and logprob_token_ids (prompt path is only valid for single-chunk pure-extend prefill; chunked/mixed/prefix-cache paths would be silently wrong), output top-k (logprobs>0, only the sampled token is materialized), and full-vocab (-1). Only logprobs in {None, 0} are honored. GPU parity runner updated to request logprobs=0.

Also enable output logprobs under speculative decoding: Engine.generate / async_generate previously nulled all logprob requests whenever a spec algorithm was set, silently dropping them. The engine computes correct, accept-length-aligned output logprobs on the spec verify path, so the guard was overly conservative; remove it and rely on verify() as the gate.

Back-compat / streaming hardening:
- Legacy field coercion (io_struct) now builds a per-row LogprobParams for batched list inputs (e.g. return_logprob=[False, True]) instead of collapsing to row 0, and clamps legacy top_logprobs_num>0 to the sampled-token logprob (logprobs=0) rather than erroring.
- RequestOutputCollector now sums cumulative_logprob across coalesced streamed frames (each frame's value is a per-frame delta) so it stays consistent with the appended per-position logprobs.

Validated on B200 (Qwen2-1.5B and Qwen3.5-397B-A17B-NVFP4 MTP/tp4): output logprobs=0 returns finite, <=0 per-token logprobs; prompt/top-k surface rejected at the request entrypoint. Rebased branch rebuilds (scheduler C++ + Python import + engine run) on top of current main.

Note: the OpenAI HTTP serving path (ts serve -> smg_grpc_servicer) maps logprobs in a separate, out-of-repo package and needs the matching change there; this covers the in-repo Engine/SDK path.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Hejian Sang <sanghj0923@gmail.com>
1 parent 4d0d32c commit 740d514

27 files changed

Lines changed: 987 additions & 339 deletions

python/tokenspeed/runtime/engine/collector.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,29 @@
2525
from collections.abc import Sequence
2626
from typing import Any
2727

28+
# Streaming merge policy for the logprob meta_info fields.
29+
#
30+
# ``logprobs`` and ``prompt_logprobs`` are each a list[dict[int, Logprob]] (one
31+
# entry per token position) that GROWS as frames arrive, so they must be
32+
# appended rather than overwritten -- hence they are listed here and merged by
33+
# ``_extend_sequence``. That helper is dict-safe: its prefix check
34+
# (``_is_prefix``) compares elements only with ``==``, which both ``dict`` and
35+
# the ``Logprob`` dataclass support, so the entries never need to be hashed or
36+
# ordered.
37+
#
38+
# ``cumulative_logprob`` is a scalar handled separately (see _SUM_META_KEYS):
39+
# under streaming each frame recomputes it from a fresh dict, so each frame's
40+
# value is the sum of only that frame's positions (a delta). Since the
41+
# per-position ``logprobs`` are appended across frames, the scalar must be
42+
# *summed* (not overwritten) to stay consistent with the appended list.
2843
_APPEND_META_KEYS = {
29-
"input_token_logprobs",
30-
"output_token_logprobs",
31-
"input_top_logprobs",
32-
"output_top_logprobs",
33-
"input_token_ids_logprobs",
34-
"output_token_ids_logprobs",
44+
"logprobs",
45+
"prompt_logprobs",
46+
}
47+
48+
# Scalar logprob metadata accumulated by addition across coalesced frames.
49+
_SUM_META_KEYS = {
50+
"cumulative_logprob",
3551
}
3652

3753

@@ -115,6 +131,10 @@ def _merge_meta_info_into(
115131
if key in _APPEND_META_KEYS:
116132
self._extend_sequence(pending, key, value)
117133
continue
134+
if key in _SUM_META_KEYS:
135+
if value is not None:
136+
pending[key] = (pending.get(key) or 0.0) + value
137+
continue
118138
pending[key] = value
119139

120140
def _extend_sequence(self, container: dict[str, Any], key: str, value: Any) -> None:

python/tokenspeed/runtime/engine/generation_output_processor.py

Lines changed: 176 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def __init__(
7474
token_ids_logprob: list[int] | None = None,
7575
multimodal_inputs=None,
7676
prompt_input_ids_unpadded: list[int] | None = None,
77+
return_prompt_logprob: bool = False,
7778
) -> None:
7879
# --- Extracted from recv_req (immutable) ---
7980
self.prompt_input_ids: list[int] = prompt_input_ids
@@ -91,6 +92,7 @@ def __init__(
9192
self.return_logprob = return_logprob
9293
self.top_logprobs_num = top_logprobs_num
9394
self.token_ids_logprob = token_ids_logprob
95+
self.return_prompt_logprob = return_prompt_logprob
9496

9597
# --- generation state (updated with forward step) ---
9698
self.output_ids: list[int] = []
@@ -107,6 +109,30 @@ def __init__(
107109
self.output_token_logprobs_idx: list[int] | None = (
108110
[] if return_logprob else None
109111
)
112+
# Prompt/input-token logprobs, accumulated across (possibly chunked)
113+
# prefill ops. None unless the request asked for prompt_logprobs.
114+
self.input_token_logprobs_val: list[float] | None = (
115+
[] if return_prompt_logprob else None
116+
)
117+
self.input_token_logprobs_idx: list[int] | None = (
118+
[] if return_prompt_logprob else None
119+
)
120+
# Per-position top-k prompt logprobs (prompt_logprobs>0). Each entry is
121+
# a [k] list parallel to the base prompt-logprob accumulators above.
122+
self.input_top_logprobs_val: list | None = [] if return_prompt_logprob else None
123+
self.input_top_logprobs_idx: list | None = [] if return_prompt_logprob else None
124+
# Per-position token-id prompt logprobs (logprob_token_ids requested).
125+
# Each entry is a [num_token_ids] list parallel to the base prompt-logprob
126+
# accumulators above.
127+
self.input_token_ids_logprobs_val: list | None = (
128+
[] if return_prompt_logprob else None
129+
)
130+
self.input_token_ids_logprobs_idx: list | None = (
131+
[] if return_prompt_logprob else None
132+
)
133+
# Number of prompt-logprob entries already shipped to the detokenizer,
134+
# so multi-chunk / multi-step streaming only sends the un-shipped tail.
135+
self.sent_prompt_logprob_offset: int = 0
110136

111137
# --- Streaming bookkeeping (internal) ---
112138
self._surr_offset: int | None = None
@@ -152,17 +178,27 @@ def from_recv_req(
152178
tokenizer,
153179
eos_token_ids: list[int],
154180
) -> RequestState:
181+
# Translate the LogprobParams into the scheduler-internal
182+
# scalar attrs the rest of the pipeline already understands. Output-token
183+
# logprob collection is gated on ``return_logprob`` below; the prompt
184+
# (off-policy) path is wired separately in full Phase B.
185+
lp = getattr(recv_req, "logprob_params", None)
186+
return_logprob = lp is not None and lp.num_logprobs() is not None
187+
top_logprobs_num = lp.num_logprobs() or 0 if return_logprob else 0
188+
token_ids_logprob = lp.logprob_token_ids if lp is not None else None
189+
return_prompt_logprob = lp is not None and lp.num_prompt_logprobs() is not None
155190
return cls(
156191
prompt_input_ids=recv_req.input_ids,
157192
sampling_params=recv_req.sampling_params,
158193
stream=recv_req.stream,
159194
tokenizer=tokenizer,
160195
eos_token_ids=eos_token_ids,
161-
return_logprob=getattr(recv_req, "return_logprob", False),
162-
top_logprobs_num=getattr(recv_req, "top_logprobs_num", 0),
163-
token_ids_logprob=getattr(recv_req, "token_ids_logprob", None),
196+
return_logprob=return_logprob,
197+
top_logprobs_num=top_logprobs_num,
198+
token_ids_logprob=token_ids_logprob,
164199
multimodal_inputs=getattr(recv_req, "multimodal_inputs", None),
165200
prompt_input_ids_unpadded=getattr(recv_req, "input_ids_unpadded", None),
201+
return_prompt_logprob=return_prompt_logprob,
166202
)
167203

168204
@property
@@ -534,6 +570,31 @@ def post_process_forward_op(
534570
if model_execution_results.output_logprobs is not None
535571
else None
536572
)
573+
_input_lp = getattr(model_execution_results, "input_token_logprobs", None)
574+
input_logprobs_list = _input_lp.tolist() if _input_lp is not None else None
575+
_input_lp_ids = getattr(
576+
model_execution_results, "input_logprob_token_ids", None
577+
)
578+
input_logprob_ids_list = (
579+
_input_lp_ids.tolist() if _input_lp_ids is not None else None
580+
)
581+
# Top-k prompt logprobs are already CPU Python lists partitioned per
582+
# extend request (index by extend request i, NOT the flat ilp_pt).
583+
input_top_val_list = getattr(
584+
model_execution_results, "input_top_logprobs_val", None
585+
)
586+
input_top_idx_list = getattr(
587+
model_execution_results, "input_top_logprobs_idx", None
588+
)
589+
# Token-id prompt logprobs are likewise CPU Python lists partitioned per
590+
# extend request (index by extend request i, NOT the flat ilp_pt).
591+
input_tid_val_list = getattr(
592+
model_execution_results, "input_token_ids_logprobs_val", None
593+
)
594+
input_tid_idx_list = getattr(
595+
model_execution_results, "input_token_ids_logprobs_idx", None
596+
)
597+
ilp_pt = 0
537598
pt = 0
538599
for i, rid in enumerate(forward_op.request_ids):
539600
output_length = model_execution_results.output_lengths[i].item()
@@ -551,12 +612,62 @@ def post_process_forward_op(
551612
else:
552613
pt += output_length
553614

615+
# Prompt/input-token logprobs (pure-extend prompt_logprobs path).
616+
# The flat array carries one entry per scored prefill position for
617+
# every i < num_extends (the ctx activates input logprobs for the
618+
# whole pure-extend batch), so advance ilp_pt unconditionally here —
619+
# before the rid_to_state / prefill-finished guards — to keep the
620+
# pointer aligned with the flat array even for delayed/chunked rows.
621+
seg_val = None
622+
seg_idx = None
623+
pl_req = -1
624+
if input_logprobs_list is not None and i < num_extends:
625+
plen = int(forward_op.input_lengths[i])
626+
pl_req = int(forward_op.prompt_logprobs[i])
627+
seg_val = input_logprobs_list[ilp_pt : ilp_pt + plen]
628+
seg_idx = (
629+
input_logprob_ids_list[ilp_pt : ilp_pt + plen]
630+
if input_logprob_ids_list is not None
631+
else []
632+
)
633+
ilp_pt += plen
634+
554635
if rid not in self.rid_to_state:
555636
# means it's delayed token, do not process
556637
continue
557638

558639
request_state: RequestState = self.rid_to_state[rid]
559640

641+
if (
642+
seg_val is not None
643+
and pl_req >= 0
644+
and request_state.input_token_logprobs_val is not None
645+
):
646+
request_state.input_token_logprobs_val.extend(seg_val)
647+
request_state.input_token_logprobs_idx.extend(seg_idx)
648+
if (
649+
input_top_val_list is not None
650+
and i < len(input_top_val_list)
651+
and request_state.input_top_logprobs_val is not None
652+
):
653+
request_state.input_top_logprobs_val.extend(input_top_val_list[i])
654+
if input_top_idx_list is not None and i < len(input_top_idx_list):
655+
request_state.input_top_logprobs_idx.extend(
656+
input_top_idx_list[i]
657+
)
658+
if (
659+
input_tid_val_list is not None
660+
and i < len(input_tid_val_list)
661+
and request_state.input_token_ids_logprobs_val is not None
662+
):
663+
request_state.input_token_ids_logprobs_val.extend(
664+
input_tid_val_list[i]
665+
)
666+
if input_tid_idx_list is not None and i < len(input_tid_idx_list):
667+
request_state.input_token_ids_logprobs_idx.extend(
668+
input_tid_idx_list[i]
669+
)
670+
560671
# Do not output chunking result
561672
if not request_state.prefill_finished:
562673
continue
@@ -738,6 +849,12 @@ def stream_output(
738849
output_extra_infos: list[dict] = []
739850
output_token_logprobs_val: list[list[float]] = []
740851
output_token_logprobs_idx: list[list[int]] = []
852+
input_token_logprobs_val: list[list[float]] = []
853+
input_token_logprobs_idx: list[list[int]] = []
854+
input_top_logprobs_val: list = []
855+
input_top_logprobs_idx: list = []
856+
input_token_ids_logprobs_val: list = []
857+
input_token_ids_logprobs_idx: list = []
741858

742859
for i, rs in enumerate(output_states):
743860
# For finished requests, always output (unless already output)
@@ -817,6 +934,49 @@ def stream_output(
817934
output_token_logprobs_val.append([])
818935
output_token_logprobs_idx.append([])
819936

937+
# Prompt/input-token logprobs: ship the un-shipped tail of the
938+
# accumulated prompt logprobs (accumulated across chunked prefill in
939+
# post_process_forward_op). Tracked with sent_prompt_logprob_offset
940+
# so later decode-step streams don't resend the prompt logprobs.
941+
if rs.return_prompt_logprob and rs.input_token_logprobs_val is not None:
942+
off = rs.sent_prompt_logprob_offset
943+
input_token_logprobs_val.append(rs.input_token_logprobs_val[off:])
944+
input_token_logprobs_idx.append(rs.input_token_logprobs_idx[off:])
945+
# Top-k prompt logprobs are parallel (one [k] list per position),
946+
# so ship the same un-shipped tail using the same offset.
947+
if rs.input_top_logprobs_val is not None:
948+
input_top_logprobs_val.append(rs.input_top_logprobs_val[off:])
949+
input_top_logprobs_idx.append(
950+
rs.input_top_logprobs_idx[off:]
951+
if rs.input_top_logprobs_idx is not None
952+
else []
953+
)
954+
else:
955+
input_top_logprobs_val.append([])
956+
input_top_logprobs_idx.append([])
957+
# Token-id prompt logprobs are parallel (one [num_token_ids] list
958+
# per position), so ship the same un-shipped tail using off.
959+
if rs.input_token_ids_logprobs_val is not None:
960+
input_token_ids_logprobs_val.append(
961+
rs.input_token_ids_logprobs_val[off:]
962+
)
963+
input_token_ids_logprobs_idx.append(
964+
rs.input_token_ids_logprobs_idx[off:]
965+
if rs.input_token_ids_logprobs_idx is not None
966+
else []
967+
)
968+
else:
969+
input_token_ids_logprobs_val.append([])
970+
input_token_ids_logprobs_idx.append([])
971+
rs.sent_prompt_logprob_offset = len(rs.input_token_logprobs_val)
972+
else:
973+
input_token_logprobs_val.append([])
974+
input_token_logprobs_idx.append([])
975+
input_top_logprobs_val.append([])
976+
input_top_logprobs_idx.append([])
977+
input_token_ids_logprobs_val.append([])
978+
input_token_ids_logprobs_idx.append([])
979+
820980
# Don't send empty batch to detokenizer
821981
if len(rids_to_send) == 0:
822982
return
@@ -836,16 +996,23 @@ def stream_output(
836996
completion_tokens=completion_tokens,
837997
cached_tokens=cached_tokens,
838998
spec_verify_ct=spec_verify_ct,
839-
input_token_logprobs_val=[],
840-
input_token_logprobs_idx=[],
999+
input_token_logprobs_val=input_token_logprobs_val,
1000+
input_token_logprobs_idx=input_token_logprobs_idx,
8411001
output_token_logprobs_val=output_token_logprobs_val,
8421002
output_token_logprobs_idx=output_token_logprobs_idx,
843-
input_top_logprobs_val=[],
844-
input_top_logprobs_idx=[],
1003+
input_top_logprobs_val=input_top_logprobs_val,
1004+
input_top_logprobs_idx=input_top_logprobs_idx,
1005+
# TODO(logprobs): OUTPUT top-k / token-id logprobs (logprobs=N>0 or
1006+
# logprob_token_ids on output) are not populated yet. Output logprobs
1007+
# flow through the captured CUDA-graph decode path, so top-k must be
1008+
# computed in the sampler and captured into the graph output buffers
1009+
# (unlike the prompt path, which is eager/prefill). Until then,
1010+
# logprobs=N>0 returns only the sampled token's logprob per position
1011+
# (its value is correct; the top-N alternatives are absent).
8451012
output_top_logprobs_val=[],
8461013
output_top_logprobs_idx=[],
847-
input_token_ids_logprobs_val=[],
848-
input_token_ids_logprobs_idx=[],
1014+
input_token_ids_logprobs_val=input_token_ids_logprobs_val,
1015+
input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
8491016
output_token_ids_logprobs_val=[],
8501017
output_token_ids_logprobs_idx=[],
8511018
output_hidden_states=[],

python/tokenspeed/runtime/engine/input_processor.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,12 @@ async def tokenize_one_request(
173173
input_ids = pad_input_tokens(list(input_ids), multimodal_inputs)
174174

175175
if self.engine.is_generation:
176-
return_logprob = obj.return_logprob
177-
logprob_start_len = obj.logprob_start_len
178-
top_logprobs_num = obj.top_logprobs_num
179-
token_ids_logprob = obj.token_ids_logprob
176+
logprob_params = obj.logprob_params
177+
if logprob_params is not None:
178+
logprob_params.verify(
179+
vocab_size=self.engine.model_config.vocab_size,
180+
max_logprobs=self.engine.model_config.vocab_size,
181+
)
180182
session_params = (
181183
SessionParams(**obj.session_params) if obj.session_params else None
182184
)
@@ -218,10 +220,7 @@ async def tokenize_one_request(
218220
input_text,
219221
input_ids,
220222
sampling_params,
221-
return_logprob,
222-
logprob_start_len,
223-
top_logprobs_num,
224-
token_ids_logprob,
223+
logprob_params,
225224
obj.stream,
226225
bootstrap_host=obj.bootstrap_host,
227226
bootstrap_port=obj.bootstrap_port,

0 commit comments

Comments
 (0)