Skip to content

Commit d80d796

Browse files
committed
[Test][MIMO Audio] Update full_payload tests for async_chunk delegation
Provide a transfer_manager mock with code_prompt_token_ids, add req_id to request fixtures, and switch from dict access to OmniPayloadStruct attribute access to match the new return type. Co-Authored-By: Claude <noreply@anthropic.com> Signed-off-by: Nick Cao <ncao@redhat.com>
1 parent 304d4f9 commit d80d796

1 file changed

Lines changed: 23 additions & 11 deletions

File tree

tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,17 @@ def __init__(self, mm):
505505
assert out[0]["additional_information"] is None
506506

507507

508+
def _make_mimo_transfer_manager():
509+
"""Build a minimal transfer_manager mock for llm2code2wav_full_payload."""
510+
from collections import defaultdict
511+
from types import SimpleNamespace
512+
513+
return SimpleNamespace(
514+
connector=None,
515+
code_prompt_token_ids=defaultdict(list),
516+
)
517+
518+
508519
def test_mimo_audio_llm2code2wav_full_payload_smoke() -> None:
509520
"""Smoke: mimo_audio producer-side payload builder reads flat codes.audio + flattens."""
510521
from types import SimpleNamespace
@@ -520,19 +531,19 @@ def test_mimo_audio_llm2code2wav_full_payload_smoke() -> None:
520531
audio = torch.arange(2 * 1 * 8 * 4, dtype=torch.long).reshape(2, 1, 8, 4)
521532
audio = audio.clamp(min=1) # avoid zero-row drop
522533
pooling_output = {"codes.audio": audio}
523-
req = SimpleNamespace(output_token_ids=[])
524-
payload = llm2code2wav_full_payload(None, pooling_output, req)
534+
tm = _make_mimo_transfer_manager()
535+
req = SimpleNamespace(req_id="test-smoke")
536+
payload = llm2code2wav_full_payload(tm, pooling_output, req)
525537
assert payload is not None
526-
assert "codes" in payload and "audio" in payload["codes"]
527538
# Flattened length = numel + B*4 (per-batch pad_vec prepended by prepend_and_flatten_colmajor)
528539
batch_size = int(audio.shape[0])
529-
assert len(payload["codes"]["audio"]) == audio.numel() + batch_size * 4
540+
assert payload.codes.audio.numel() == audio.numel() + batch_size * 4
530541
# prepend_and_flatten_colmajor: PAD appears at column start in col-major flatten.
531542
# For shape [B=2, 1, 9, 4], each column has 1 PAD then 8 codec vals → PAD at indices 0, 9, 18, 27.
532-
out = payload["codes"]["audio"]
533-
assert out[0] == TALKER_CODEC_PAD_TOKEN_ID
534-
assert out[9] == TALKER_CODEC_PAD_TOKEN_ID
535-
assert payload["meta"]["finished"].item() is True
543+
out = payload.codes.audio
544+
assert out[0].item() == TALKER_CODEC_PAD_TOKEN_ID
545+
assert out[9].item() == TALKER_CODEC_PAD_TOKEN_ID
546+
assert payload.meta.finished.item() is True
536547

537548

538549
def test_mimo_audio_full_payload_nested_fallback() -> None:
@@ -548,10 +559,11 @@ def test_mimo_audio_full_payload_nested_fallback() -> None:
548559
audio = torch.arange(1 * 1 * 8 * 4, dtype=torch.long).reshape(1, 1, 8, 4)
549560
audio = audio.clamp(min=1)
550561
pooling_output = {"codes": {"audio": audio}} # nested, not flat
551-
req = SimpleNamespace(output_token_ids=[])
552-
payload = llm2code2wav_full_payload(None, pooling_output, req)
562+
tm = _make_mimo_transfer_manager()
563+
req = SimpleNamespace(req_id="test-nested")
564+
payload = llm2code2wav_full_payload(tm, pooling_output, req)
553565
assert payload is not None
554-
assert len(payload["codes"]["audio"]) == audio.numel() + int(audio.shape[0]) * 4
566+
assert payload.codes.audio.numel() == audio.numel() + int(audio.shape[0]) * 4
555567

556568

557569
def test_qwen3_tts_talker2code2wav_token_only_smoke() -> None:

0 commit comments

Comments
 (0)