@@ -74,6 +74,7 @@ def __init__(
7474 token_ids_logprob : list [int ] | None = None ,
7575 multimodal_inputs = None ,
7676 prompt_input_ids_unpadded : list [int ] | None = None ,
77+ return_prompt_logprob : bool = False ,
7778 ) -> None :
7879 # --- Extracted from recv_req (immutable) ---
7980 self .prompt_input_ids : list [int ] = prompt_input_ids
@@ -91,6 +92,7 @@ def __init__(
9192 self .return_logprob = return_logprob
9293 self .top_logprobs_num = top_logprobs_num
9394 self .token_ids_logprob = token_ids_logprob
95+ self .return_prompt_logprob = return_prompt_logprob
9496
9597 # --- generation state (updated with forward step) ---
9698 self .output_ids : list [int ] = []
@@ -107,6 +109,30 @@ def __init__(
107109 self .output_token_logprobs_idx : list [int ] | None = (
108110 [] if return_logprob else None
109111 )
112+ # Prompt/input-token logprobs, accumulated across (possibly chunked)
113+ # prefill ops. None unless the request asked for prompt_logprobs.
114+ self .input_token_logprobs_val : list [float ] | None = (
115+ [] if return_prompt_logprob else None
116+ )
117+ self .input_token_logprobs_idx : list [int ] | None = (
118+ [] if return_prompt_logprob else None
119+ )
120+ # Per-position top-k prompt logprobs (prompt_logprobs>0). Each entry is
121+ # a [k] list parallel to the base prompt-logprob accumulators above.
122+ self .input_top_logprobs_val : list | None = [] if return_prompt_logprob else None
123+ self .input_top_logprobs_idx : list | None = [] if return_prompt_logprob else None
124+ # Per-position token-id prompt logprobs (logprob_token_ids requested).
125+ # Each entry is a [num_token_ids] list parallel to the base prompt-logprob
126+ # accumulators above.
127+ self .input_token_ids_logprobs_val : list | None = (
128+ [] if return_prompt_logprob else None
129+ )
130+ self .input_token_ids_logprobs_idx : list | None = (
131+ [] if return_prompt_logprob else None
132+ )
133+ # Number of prompt-logprob entries already shipped to the detokenizer,
134+ # so multi-chunk / multi-step streaming only sends the un-shipped tail.
135+ self .sent_prompt_logprob_offset : int = 0
110136
111137 # --- Streaming bookkeeping (internal) ---
112138 self ._surr_offset : int | None = None
@@ -144,17 +170,27 @@ def from_recv_req(
144170 tokenizer ,
145171 eos_token_ids : list [int ],
146172 ) -> RequestState :
173+ # Translate the LogprobParams into the scheduler-internal
174+ # scalar attrs the rest of the pipeline already understands. Output-token
175+ # logprob collection is gated on ``return_logprob`` below; the prompt
176+ # (off-policy) path is gated separately on ``return_prompt_logprob``.
177+ lp = getattr (recv_req , "logprob_params" , None )
178+ return_logprob = lp is not None and lp .num_logprobs () is not None
179+ top_logprobs_num = lp .num_logprobs () or 0 if return_logprob else 0
180+ token_ids_logprob = lp .logprob_token_ids if lp is not None else None
181+ return_prompt_logprob = lp is not None and lp .num_prompt_logprobs () is not None
147182 return cls (
148183 prompt_input_ids = recv_req .input_ids ,
149184 sampling_params = recv_req .sampling_params ,
150185 stream = recv_req .stream ,
151186 tokenizer = tokenizer ,
152187 eos_token_ids = eos_token_ids ,
153- return_logprob = getattr ( recv_req , " return_logprob" , False ) ,
154- top_logprobs_num = getattr ( recv_req , " top_logprobs_num" , 0 ) ,
155- token_ids_logprob = getattr ( recv_req , " token_ids_logprob" , None ) ,
188+ return_logprob = return_logprob ,
189+ top_logprobs_num = top_logprobs_num ,
190+ token_ids_logprob = token_ids_logprob ,
156191 multimodal_inputs = getattr (recv_req , "multimodal_inputs" , None ),
157192 prompt_input_ids_unpadded = getattr (recv_req , "input_ids_unpadded" , None ),
193+ return_prompt_logprob = return_prompt_logprob ,
158194 )
159195
160196 @property
@@ -521,6 +557,31 @@ def post_process_forward_op(
521557 if model_execution_results .output_logprobs is not None
522558 else None
523559 )
560+ _input_lp = getattr (model_execution_results , "input_token_logprobs" , None )
561+ input_logprobs_list = _input_lp .tolist () if _input_lp is not None else None
562+ _input_lp_ids = getattr (
563+ model_execution_results , "input_logprob_token_ids" , None
564+ )
565+ input_logprob_ids_list = (
566+ _input_lp_ids .tolist () if _input_lp_ids is not None else None
567+ )
568+ # Top-k prompt logprobs are already CPU Python lists partitioned per
569+ # extend request (index by extend request i, NOT the flat ilp_pt).
570+ input_top_val_list = getattr (
571+ model_execution_results , "input_top_logprobs_val" , None
572+ )
573+ input_top_idx_list = getattr (
574+ model_execution_results , "input_top_logprobs_idx" , None
575+ )
576+ # Token-id prompt logprobs are likewise CPU Python lists partitioned per
577+ # extend request (index by extend request i, NOT the flat ilp_pt).
578+ input_tid_val_list = getattr (
579+ model_execution_results , "input_token_ids_logprobs_val" , None
580+ )
581+ input_tid_idx_list = getattr (
582+ model_execution_results , "input_token_ids_logprobs_idx" , None
583+ )
584+ ilp_pt = 0
524585 pt = 0
525586 for i , rid in enumerate (forward_op .request_ids ):
526587 output_length = model_execution_results .output_lengths [i ].item ()
@@ -538,12 +599,62 @@ def post_process_forward_op(
538599 else :
539600 pt += output_length
540601
602+ # Prompt/input-token logprobs (pure-extend prompt_logprobs path).
603+ # The flat array carries one entry per scored prefill position for
604+ # every i < num_extends (the ctx activates input logprobs for the
605+ # whole pure-extend batch), so advance ilp_pt unconditionally here —
606+ # before the rid_to_state / prefill-finished guards — to keep the
607+ # pointer aligned with the flat array even for delayed/chunked rows.
608+ seg_val = None
609+ seg_idx = None
610+ pl_req = - 1
611+ if input_logprobs_list is not None and i < num_extends :
612+ plen = int (forward_op .input_lengths [i ])
613+ pl_req = int (forward_op .prompt_logprobs [i ])
614+ seg_val = input_logprobs_list [ilp_pt : ilp_pt + plen ]
615+ seg_idx = (
616+ input_logprob_ids_list [ilp_pt : ilp_pt + plen ]
617+ if input_logprob_ids_list is not None
618+ else []
619+ )
620+ ilp_pt += plen
621+
541622 if rid not in self .rid_to_state :
542623 # means it's delayed token, do not process
543624 continue
544625
545626 request_state : RequestState = self .rid_to_state [rid ]
546627
628+ if (
629+ seg_val is not None
630+ and pl_req >= 0
631+ and request_state .input_token_logprobs_val is not None
632+ ):
633+ request_state .input_token_logprobs_val .extend (seg_val )
634+ request_state .input_token_logprobs_idx .extend (seg_idx )
635+ if (
636+ input_top_val_list is not None
637+ and i < len (input_top_val_list )
638+ and request_state .input_top_logprobs_val is not None
639+ ):
640+ request_state .input_top_logprobs_val .extend (input_top_val_list [i ])
641+ if input_top_idx_list is not None and i < len (input_top_idx_list ):
642+ request_state .input_top_logprobs_idx .extend (
643+ input_top_idx_list [i ]
644+ )
645+ if (
646+ input_tid_val_list is not None
647+ and i < len (input_tid_val_list )
648+ and request_state .input_token_ids_logprobs_val is not None
649+ ):
650+ request_state .input_token_ids_logprobs_val .extend (
651+ input_tid_val_list [i ]
652+ )
653+ if input_tid_idx_list is not None and i < len (input_tid_idx_list ):
654+ request_state .input_token_ids_logprobs_idx .extend (
655+ input_tid_idx_list [i ]
656+ )
657+
547658 # Do not output chunking result
548659 if not request_state .prefill_finished :
549660 continue
@@ -719,6 +830,12 @@ def stream_output(
719830 output_extra_infos : list [dict ] = []
720831 output_token_logprobs_val : list [list [float ]] = []
721832 output_token_logprobs_idx : list [list [int ]] = []
833+ input_token_logprobs_val : list [list [float ]] = []
834+ input_token_logprobs_idx : list [list [int ]] = []
835+ input_top_logprobs_val : list = []
836+ input_top_logprobs_idx : list = []
837+ input_token_ids_logprobs_val : list = []
838+ input_token_ids_logprobs_idx : list = []
722839
723840 for i , rs in enumerate (output_states ):
724841 # For finished requests, always output (unless already output)
@@ -798,6 +915,49 @@ def stream_output(
798915 output_token_logprobs_val .append ([])
799916 output_token_logprobs_idx .append ([])
800917
918+ # Prompt/input-token logprobs: ship the un-shipped tail of the
919+ # accumulated prompt logprobs (accumulated across chunked prefill in
920+ # post_process_forward_op). Tracked with sent_prompt_logprob_offset
921+ # so later decode-step streams don't resend the prompt logprobs.
922+ if rs .return_prompt_logprob and rs .input_token_logprobs_val is not None :
923+ off = rs .sent_prompt_logprob_offset
924+ input_token_logprobs_val .append (rs .input_token_logprobs_val [off :])
925+ input_token_logprobs_idx .append (rs .input_token_logprobs_idx [off :])
926+ # Top-k prompt logprobs are parallel (one [k] list per position),
927+ # so ship the same un-shipped tail using the same offset.
928+ if rs .input_top_logprobs_val is not None :
929+ input_top_logprobs_val .append (rs .input_top_logprobs_val [off :])
930+ input_top_logprobs_idx .append (
931+ rs .input_top_logprobs_idx [off :]
932+ if rs .input_top_logprobs_idx is not None
933+ else []
934+ )
935+ else :
936+ input_top_logprobs_val .append ([])
937+ input_top_logprobs_idx .append ([])
938+ # Token-id prompt logprobs are parallel (one [num_token_ids] list
939+ # per position), so ship the same un-shipped tail using off.
940+ if rs .input_token_ids_logprobs_val is not None :
941+ input_token_ids_logprobs_val .append (
942+ rs .input_token_ids_logprobs_val [off :]
943+ )
944+ input_token_ids_logprobs_idx .append (
945+ rs .input_token_ids_logprobs_idx [off :]
946+ if rs .input_token_ids_logprobs_idx is not None
947+ else []
948+ )
949+ else :
950+ input_token_ids_logprobs_val .append ([])
951+ input_token_ids_logprobs_idx .append ([])
952+ rs .sent_prompt_logprob_offset = len (rs .input_token_logprobs_val )
953+ else :
954+ input_token_logprobs_val .append ([])
955+ input_token_logprobs_idx .append ([])
956+ input_top_logprobs_val .append ([])
957+ input_top_logprobs_idx .append ([])
958+ input_token_ids_logprobs_val .append ([])
959+ input_token_ids_logprobs_idx .append ([])
960+
801961 # Don't send empty batch to detokenizer
802962 if len (rids_to_send ) == 0 :
803963 return
@@ -817,16 +977,23 @@ def stream_output(
817977 completion_tokens = completion_tokens ,
818978 cached_tokens = cached_tokens ,
819979 spec_verify_ct = spec_verify_ct ,
820- input_token_logprobs_val = [] ,
821- input_token_logprobs_idx = [] ,
980+ input_token_logprobs_val = input_token_logprobs_val ,
981+ input_token_logprobs_idx = input_token_logprobs_idx ,
822982 output_token_logprobs_val = output_token_logprobs_val ,
823983 output_token_logprobs_idx = output_token_logprobs_idx ,
824- input_top_logprobs_val = [],
825- input_top_logprobs_idx = [],
984+ input_top_logprobs_val = input_top_logprobs_val ,
985+ input_top_logprobs_idx = input_top_logprobs_idx ,
986+ # TODO(logprobs): OUTPUT top-k / token-id logprobs (logprobs=N>0 or
987+ # logprob_token_ids on output) are not populated yet. Output logprobs
988+ # flow through the captured CUDA-graph decode path, so top-k must be
989+ # computed in the sampler and captured into the graph output buffers
990+ # (unlike the prompt path, which is eager/prefill). Until then,
991+ # logprobs=N>0 returns only the sampled token's logprob per position
992+ # (its value is correct; the top-N alternatives are absent).
826993 output_top_logprobs_val = [],
827994 output_top_logprobs_idx = [],
828- input_token_ids_logprobs_val = [] ,
829- input_token_ids_logprobs_idx = [] ,
995+ input_token_ids_logprobs_val = input_token_ids_logprobs_val ,
996+ input_token_ids_logprobs_idx = input_token_ids_logprobs_idx ,
830997 output_token_ids_logprobs_val = [],
831998 output_token_ids_logprobs_idx = [],
832999 output_hidden_states = [],
0 commit comments