Skip to content

Commit ca4162c

Browse files
committed
[Bugfix][MIMO Audio] Restore MAX_CODE2WAV_TOKENS cap and zero-row filter in async_chunk
The old llm2code2wav_full_payload truncated flat_codes at MAX_CODE2WAV_TOKENS and filtered zero-padded codec rows via _filter_zero_codec_rows before flattening. Both guards were lost when the function was replaced with a delegation to llm2code2wav_async_chunk. Restore them and drop the tensor-list-tensor round-trip. Co-Authored-By: Claude <noreply@anthropic.com> Signed-off-by: Nick Cao <ncao@redhat.com>
1 parent 3647324 commit ca4162c

1 file changed

Lines changed: 11 additions & 15 deletions

File tree

vllm_omni/model_executor/stage_input_processors/mimo_audio.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -114,18 +114,6 @@ def _flush_remaining_codes(
114114
)
115115

116116

117-
def _is_codes_empty(codes: Any) -> bool:
118-
"""Check whether code_predictor_codes should be treated as empty / invalid."""
119-
if codes is None:
120-
return True
121-
if isinstance(codes, torch.Tensor):
122-
return codes.numel() == 0 or not codes.any()
123-
if hasattr(codes, "__len__") and len(codes) == 0:
124-
return True
125-
t = torch.tensor(codes, dtype=torch.long) if not isinstance(codes, torch.Tensor) else codes
126-
return not t.any()
127-
128-
129117
def _to_code_tensor(codes: Any) -> torch.Tensor | None:
130118
"""Convert codes to a (B, 1, 8, 4) long tensor, or return None if shape is invalid."""
131119
code_tensor = codes.to(torch.long) if isinstance(codes, torch.Tensor) else torch.tensor(codes, dtype=torch.long)
@@ -208,6 +196,12 @@ def llm2code2wav_async_chunk(
208196
return _flush_remaining_codes(transfer_manager, request_id, chunk_size, left_context_size)
209197
return None
210198

199+
code_tensor = _filter_zero_codec_rows(code_tensor)
200+
if code_tensor.numel() == 0:
201+
if is_finished:
202+
return _flush_remaining_codes(transfer_manager, request_id, chunk_size, left_context_size)
203+
return None
204+
211205
pad_vec = torch.tensor([TALKER_CODEC_PAD_TOKEN_ID] * 4, device=code_tensor.device, dtype=code_tensor.dtype)
212206
code_list = prepend_and_flatten_colmajor(code_tensor, pad_vec).tolist()
213207

@@ -223,15 +217,17 @@ def llm2code2wav_async_chunk(
223217
context_length = chunk_length if chunk_length != 0 else chunk_size
224218
end_index = min(length, left_context_size + context_length)
225219
left_ctx_frames = max(0, min(length - context_length, left_context_size))
226-
flat_codes = torch.tensor(transfer_manager.code_prompt_token_ids[request_id][-end_index:]).reshape(-1).tolist()
220+
flat_codes = torch.tensor(transfer_manager.code_prompt_token_ids[request_id][-end_index:]).reshape(-1)
221+
if flat_codes.numel() > MAX_CODE2WAV_TOKENS:
222+
flat_codes = flat_codes[:MAX_CODE2WAV_TOKENS]
227223

228224
return OmniPayloadStruct(
229-
codes=CodesStruct(audio=torch.tensor(flat_codes)),
225+
codes=CodesStruct(audio=flat_codes),
230226
meta=MetaStruct(
231227
left_context_size=left_ctx_frames,
232228
codec_chunk_frames=chunk_size,
233229
codec_left_context_frames=left_context_size,
234-
code_flat_numel=len(flat_codes),
230+
code_flat_numel=int(flat_codes.numel()),
235231
finished=torch.tensor(is_finished, dtype=torch.bool),
236232
),
237233
)

0 commit comments

Comments
 (0)