@@ -74,6 +74,7 @@ def __init__(
7474 token_ids_logprob : list [int ] | None = None ,
7575 multimodal_inputs = None ,
7676 prompt_input_ids_unpadded : list [int ] | None = None ,
77+ return_prompt_logprob : bool = False ,
7778 ) -> None :
7879 # --- Extracted from recv_req (immutable) ---
7980 self .prompt_input_ids : list [int ] = prompt_input_ids
@@ -91,6 +92,7 @@ def __init__(
9192 self .return_logprob = return_logprob
9293 self .top_logprobs_num = top_logprobs_num
9394 self .token_ids_logprob = token_ids_logprob
95+ self .return_prompt_logprob = return_prompt_logprob
9496
9597 # --- generation state (updated with forward step) ---
9698 self .output_ids : list [int ] = []
@@ -107,6 +109,30 @@ def __init__(
107109 self .output_token_logprobs_idx : list [int ] | None = (
108110 [] if return_logprob else None
109111 )
112+ # Prompt/input-token logprobs, accumulated across (possibly chunked)
113+ # prefill ops. None unless the request asked for prompt_logprobs.
114+ self .input_token_logprobs_val : list [float ] | None = (
115+ [] if return_prompt_logprob else None
116+ )
117+ self .input_token_logprobs_idx : list [int ] | None = (
118+ [] if return_prompt_logprob else None
119+ )
120+ # Per-position top-k prompt logprobs (prompt_logprobs>0). Each entry is
121+ # a [k] list parallel to the base prompt-logprob accumulators above.
122+ self .input_top_logprobs_val : list | None = [] if return_prompt_logprob else None
123+ self .input_top_logprobs_idx : list | None = [] if return_prompt_logprob else None
124+ # Per-position token-id prompt logprobs (logprob_token_ids requested).
125+ # Each entry is a [num_token_ids] list parallel to the base prompt-logprob
126+ # accumulators above.
127+ self .input_token_ids_logprobs_val : list | None = (
128+ [] if return_prompt_logprob else None
129+ )
130+ self .input_token_ids_logprobs_idx : list | None = (
131+ [] if return_prompt_logprob else None
132+ )
133+ # Number of prompt-logprob entries already shipped to the detokenizer,
134+ # so multi-chunk / multi-step streaming only sends the un-shipped tail.
135+ self .sent_prompt_logprob_offset : int = 0
110136
111137 # --- Streaming bookkeeping (internal) ---
112138 self ._surr_offset : int | None = None
@@ -152,17 +178,27 @@ def from_recv_req(
152178 tokenizer ,
153179 eos_token_ids : list [int ],
154180 ) -> RequestState :
181+ # Translate the LogprobParams into the scheduler-internal
182+ # scalar attrs the rest of the pipeline already understands. Output-token
183+ # logprob collection is gated on ``return_logprob`` below; the prompt
184+ # (off-policy) path is wired separately in full Phase B.
185+ lp = getattr (recv_req , "logprob_params" , None )
186+ return_logprob = lp is not None and lp .num_logprobs () is not None
187+ top_logprobs_num = lp .num_logprobs () or 0 if return_logprob else 0
188+ token_ids_logprob = lp .logprob_token_ids if lp is not None else None
189+ return_prompt_logprob = lp is not None and lp .num_prompt_logprobs () is not None
155190 return cls (
156191 prompt_input_ids = recv_req .input_ids ,
157192 sampling_params = recv_req .sampling_params ,
158193 stream = recv_req .stream ,
159194 tokenizer = tokenizer ,
160195 eos_token_ids = eos_token_ids ,
161- return_logprob = getattr ( recv_req , " return_logprob" , False ) ,
162- top_logprobs_num = getattr ( recv_req , " top_logprobs_num" , 0 ) ,
163- token_ids_logprob = getattr ( recv_req , " token_ids_logprob" , None ) ,
196+ return_logprob = return_logprob ,
197+ top_logprobs_num = top_logprobs_num ,
198+ token_ids_logprob = token_ids_logprob ,
164199 multimodal_inputs = getattr (recv_req , "multimodal_inputs" , None ),
165200 prompt_input_ids_unpadded = getattr (recv_req , "input_ids_unpadded" , None ),
201+ return_prompt_logprob = return_prompt_logprob ,
166202 )
167203
168204 @property
@@ -534,6 +570,31 @@ def post_process_forward_op(
534570 if model_execution_results .output_logprobs is not None
535571 else None
536572 )
573+ _input_lp = getattr (model_execution_results , "input_token_logprobs" , None )
574+ input_logprobs_list = _input_lp .tolist () if _input_lp is not None else None
575+ _input_lp_ids = getattr (
576+ model_execution_results , "input_logprob_token_ids" , None
577+ )
578+ input_logprob_ids_list = (
579+ _input_lp_ids .tolist () if _input_lp_ids is not None else None
580+ )
581+ # Top-k prompt logprobs are already CPU Python lists partitioned per
582+ # extend request (index by extend request i, NOT the flat ilp_pt).
583+ input_top_val_list = getattr (
584+ model_execution_results , "input_top_logprobs_val" , None
585+ )
586+ input_top_idx_list = getattr (
587+ model_execution_results , "input_top_logprobs_idx" , None
588+ )
589+ # Token-id prompt logprobs are likewise CPU Python lists partitioned per
590+ # extend request (index by extend request i, NOT the flat ilp_pt).
591+ input_tid_val_list = getattr (
592+ model_execution_results , "input_token_ids_logprobs_val" , None
593+ )
594+ input_tid_idx_list = getattr (
595+ model_execution_results , "input_token_ids_logprobs_idx" , None
596+ )
597+ ilp_pt = 0
537598 pt = 0
538599 for i , rid in enumerate (forward_op .request_ids ):
539600 output_length = model_execution_results .output_lengths [i ].item ()
@@ -551,12 +612,62 @@ def post_process_forward_op(
551612 else :
552613 pt += output_length
553614
615+ # Prompt/input-token logprobs (pure-extend prompt_logprobs path).
616+ # The flat array carries one entry per scored prefill position for
617+ # every i < num_extends (the ctx activates input logprobs for the
618+ # whole pure-extend batch), so advance ilp_pt unconditionally here —
619+ # before the rid_to_state / prefill-finished guards — to keep the
620+ # pointer aligned with the flat array even for delayed/chunked rows.
621+ seg_val = None
622+ seg_idx = None
623+ pl_req = - 1
624+ if input_logprobs_list is not None and i < num_extends :
625+ plen = int (forward_op .input_lengths [i ])
626+ pl_req = int (forward_op .prompt_logprobs [i ])
627+ seg_val = input_logprobs_list [ilp_pt : ilp_pt + plen ]
628+ seg_idx = (
629+ input_logprob_ids_list [ilp_pt : ilp_pt + plen ]
630+ if input_logprob_ids_list is not None
631+ else []
632+ )
633+ ilp_pt += plen
634+
554635 if rid not in self .rid_to_state :
555636 # means it's delayed token, do not process
556637 continue
557638
558639 request_state : RequestState = self .rid_to_state [rid ]
559640
641+ if (
642+ seg_val is not None
643+ and pl_req >= 0
644+ and request_state .input_token_logprobs_val is not None
645+ ):
646+ request_state .input_token_logprobs_val .extend (seg_val )
647+ request_state .input_token_logprobs_idx .extend (seg_idx )
648+ if (
649+ input_top_val_list is not None
650+ and i < len (input_top_val_list )
651+ and request_state .input_top_logprobs_val is not None
652+ ):
653+ request_state .input_top_logprobs_val .extend (input_top_val_list [i ])
654+ if input_top_idx_list is not None and i < len (input_top_idx_list ):
655+ request_state .input_top_logprobs_idx .extend (
656+ input_top_idx_list [i ]
657+ )
658+ if (
659+ input_tid_val_list is not None
660+ and i < len (input_tid_val_list )
661+ and request_state .input_token_ids_logprobs_val is not None
662+ ):
663+ request_state .input_token_ids_logprobs_val .extend (
664+ input_tid_val_list [i ]
665+ )
666+ if input_tid_idx_list is not None and i < len (input_tid_idx_list ):
667+ request_state .input_token_ids_logprobs_idx .extend (
668+ input_tid_idx_list [i ]
669+ )
670+
560671 # Do not output chunking result
561672 if not request_state .prefill_finished :
562673 continue
@@ -738,6 +849,12 @@ def stream_output(
738849 output_extra_infos : list [dict ] = []
739850 output_token_logprobs_val : list [list [float ]] = []
740851 output_token_logprobs_idx : list [list [int ]] = []
852+ input_token_logprobs_val : list [list [float ]] = []
853+ input_token_logprobs_idx : list [list [int ]] = []
854+ input_top_logprobs_val : list = []
855+ input_top_logprobs_idx : list = []
856+ input_token_ids_logprobs_val : list = []
857+ input_token_ids_logprobs_idx : list = []
741858
742859 for i , rs in enumerate (output_states ):
743860 # For finished requests, always output (unless already output)
@@ -817,6 +934,49 @@ def stream_output(
817934 output_token_logprobs_val .append ([])
818935 output_token_logprobs_idx .append ([])
819936
937+ # Prompt/input-token logprobs: ship the un-shipped tail of the
938+ # accumulated prompt logprobs (accumulated across chunked prefill in
939+ # post_process_forward_op). Tracked with sent_prompt_logprob_offset
940+ # so later decode-step streams don't resend the prompt logprobs.
941+ if rs .return_prompt_logprob and rs .input_token_logprobs_val is not None :
942+ off = rs .sent_prompt_logprob_offset
943+ input_token_logprobs_val .append (rs .input_token_logprobs_val [off :])
944+ input_token_logprobs_idx .append (rs .input_token_logprobs_idx [off :])
945+ # Top-k prompt logprobs are parallel (one [k] list per position),
946+ # so ship the same un-shipped tail using the same offset.
947+ if rs .input_top_logprobs_val is not None :
948+ input_top_logprobs_val .append (rs .input_top_logprobs_val [off :])
949+ input_top_logprobs_idx .append (
950+ rs .input_top_logprobs_idx [off :]
951+ if rs .input_top_logprobs_idx is not None
952+ else []
953+ )
954+ else :
955+ input_top_logprobs_val .append ([])
956+ input_top_logprobs_idx .append ([])
957+ # Token-id prompt logprobs are parallel (one [num_token_ids] list
958+ # per position), so ship the same un-shipped tail using off.
959+ if rs .input_token_ids_logprobs_val is not None :
960+ input_token_ids_logprobs_val .append (
961+ rs .input_token_ids_logprobs_val [off :]
962+ )
963+ input_token_ids_logprobs_idx .append (
964+ rs .input_token_ids_logprobs_idx [off :]
965+ if rs .input_token_ids_logprobs_idx is not None
966+ else []
967+ )
968+ else :
969+ input_token_ids_logprobs_val .append ([])
970+ input_token_ids_logprobs_idx .append ([])
971+ rs .sent_prompt_logprob_offset = len (rs .input_token_logprobs_val )
972+ else :
973+ input_token_logprobs_val .append ([])
974+ input_token_logprobs_idx .append ([])
975+ input_top_logprobs_val .append ([])
976+ input_top_logprobs_idx .append ([])
977+ input_token_ids_logprobs_val .append ([])
978+ input_token_ids_logprobs_idx .append ([])
979+
820980 # Don't send empty batch to detokenizer
821981 if len (rids_to_send ) == 0 :
822982 return
@@ -836,16 +996,23 @@ def stream_output(
836996 completion_tokens = completion_tokens ,
837997 cached_tokens = cached_tokens ,
838998 spec_verify_ct = spec_verify_ct ,
839- input_token_logprobs_val = [] ,
840- input_token_logprobs_idx = [] ,
999+ input_token_logprobs_val = input_token_logprobs_val ,
1000+ input_token_logprobs_idx = input_token_logprobs_idx ,
8411001 output_token_logprobs_val = output_token_logprobs_val ,
8421002 output_token_logprobs_idx = output_token_logprobs_idx ,
843- input_top_logprobs_val = [],
844- input_top_logprobs_idx = [],
1003+ input_top_logprobs_val = input_top_logprobs_val ,
1004+ input_top_logprobs_idx = input_top_logprobs_idx ,
1005+ # TODO(logprobs): OUTPUT top-k / token-id logprobs (logprobs=N>0 or
1006+ # logprob_token_ids on output) are not populated yet. Output logprobs
1007+ # flow through the captured CUDA-graph decode path, so top-k must be
1008+ # computed in the sampler and captured into the graph output buffers
1009+ # (unlike the prompt path, which is eager/prefill). Until then,
1010+ # logprobs=N>0 returns only the sampled token's logprob per position
1011+ # (its value is correct; the top-N alternatives are absent).
8451012 output_top_logprobs_val = [],
8461013 output_top_logprobs_idx = [],
847- input_token_ids_logprobs_val = [] ,
848- input_token_ids_logprobs_idx = [] ,
1014+ input_token_ids_logprobs_val = input_token_ids_logprobs_val ,
1015+ input_token_ids_logprobs_idx = input_token_ids_logprobs_idx ,
8491016 output_token_ids_logprobs_val = [],
8501017 output_token_ids_logprobs_idx = [],
8511018 output_hidden_states = [],
0 commit comments