Skip to content

Commit 4a8731c

Browse files
fix: register SongGen pipeline and wire server-side voice conditioning
Addresses the review on PR #4117 (issue #3388): - Register the 'songgen' model_type in the central pipeline_registry so the model can be resolved and served (it raised KeyError before). - Resolve request.ref_audio to a waveform in _build_songgen_params and pass it as ref_voice_array, the key the model reads; the old ref_voice_url was silently dropped, so server-side voice conditioning never worked. - Load the architecture named in the checkpoint config (Mixed vs DualTrack) and select bf16/fp16 on CUDA (fp32 on CPU), casting float inputs to the model dtype to avoid a generate() dtype mismatch. - Add tests: registry resolution, _build_songgen_params voice-conditioning unit tests, and offline/online e2e smoke tests. Signed-off-by: Arnav Nagzirkar <113314200+arnavnagzirkar@users.noreply.github.com>
1 parent 775205c commit 4a8731c

8 files changed

Lines changed: 306 additions & 9 deletions

File tree

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""E2E offline inference tests for the SongGen single-stage pipeline.
4+
5+
SongGen turns lyrics plus a music-style description into a 16 kHz mono song in
6+
one auto-regressive pass (the 1.3B AR LM and the X-Codec decoder both run
7+
inside ``SongGenForGeneration``). These mirror the offline example in
8+
``examples/offline_inference/text_to_speech/songgen/end2end.py``.
9+
10+
The model and its ``songgen`` package dependency are large, so these tests are
11+
gated behind the ``full_model`` / ``tts`` markers and only run in the model CI
12+
lane (the deploy config targets a single 80 GB GPU).
13+
"""
14+
15+
from __future__ import annotations
16+
17+
import pytest
18+
import torch
19+
from vllm import SamplingParams
20+
21+
from tests.helpers.mark import hardware_test
22+
from tests.helpers.runtime import OmniRunner
23+
from tests.helpers.stage_config import get_deploy_config_path
24+
from vllm_omni import Omni
25+
26+
MODEL_NAME = "LiuZH-19/SongGen_mixed_pro"
27+
STAGE_CONFIG = get_deploy_config_path("songgen.yaml")
28+
29+
# (model, stage_configs_path) for the ``omni_runner`` indirect parametrize.
30+
_OMNI_RUNNER_PARAM = (
31+
MODEL_NAME,
32+
STAGE_CONFIG,
33+
)
34+
35+
pytestmark = [
36+
pytest.mark.full_model,
37+
pytest.mark.tts,
38+
pytest.mark.parametrize("omni_runner", [_OMNI_RUNNER_PARAM], indirect=True),
39+
]
40+
41+
SAMPLE_RATE = 16000
42+
43+
DEFAULT_SAMPLING = SamplingParams(
44+
temperature=1.0,
45+
top_p=1.0,
46+
top_k=50,
47+
max_tokens=4096,
48+
seed=42,
49+
detokenize=False,
50+
)
51+
52+
53+
def _build_request(lyrics: str, description: str = "a pop song", seed: int = 42) -> dict:
54+
"""Build a SongGen offline request (lyrics + style description)."""
55+
return {
56+
"prompt": "<|im_start|>assistant\n",
57+
"additional_information": {
58+
"lyrics": [lyrics],
59+
"text_description": [description],
60+
"seed": [seed],
61+
},
62+
}
63+
64+
65+
def _collect_audio(omni: Omni, request: dict) -> tuple[torch.Tensor, int]:
66+
"""Run a single request and return (waveform, sample_rate)."""
67+
for stage_outputs in omni.generate(request, DEFAULT_SAMPLING):
68+
req_output = stage_outputs.request_output
69+
if req_output is not None:
70+
mm = req_output.outputs[0].multimodal_output
71+
assert mm is not None, "Expected multimodal_output to be non-None"
72+
audio = mm.get("audio")
73+
sr = mm.get("sr")
74+
assert audio is not None, "Expected 'audio' key in multimodal_output"
75+
assert isinstance(audio, torch.Tensor), f"audio should be Tensor, got {type(audio)}"
76+
return audio.cpu(), int(sr.item()) if sr is not None else SAMPLE_RATE
77+
raise AssertionError("No stage outputs received")
78+
79+
80+
@pytest.mark.advanced_model
81+
@hardware_test(res={"cuda": "H100"}, num_cards=1)
82+
def test_songgen_text_to_song(omni_runner: OmniRunner) -> None:
83+
"""Lyrics + description produce non-empty 16 kHz audio."""
84+
req = _build_request("Under the moonlight, we dance through the night.")
85+
audio, sr = _collect_audio(omni_runner.omni, req)
86+
87+
assert sr == SAMPLE_RATE, f"Expected sample_rate={SAMPLE_RATE}, got {sr}"
88+
assert audio.numel() > 0, "Audio tensor should not be empty"
89+
assert not torch.all(audio == 0), "Audio should not be all-zeros (silence)"
90+
91+
92+
@pytest.mark.advanced_model
93+
@hardware_test(res={"cuda": "H100"}, num_cards=1)
94+
def test_songgen_batch(omni_runner: OmniRunner) -> None:
95+
"""Batch of two requests returns audio for each."""
96+
requests = [
97+
_build_request("First verse under a quiet sky."),
98+
_build_request("Second verse as the morning breaks."),
99+
]
100+
results = []
101+
# Single-stage model (num_stages=1): one sampling param for all requests.
102+
for stage_outputs in omni_runner.omni.generate(requests, [DEFAULT_SAMPLING]):
103+
req_output = stage_outputs.request_output
104+
if req_output is not None:
105+
mm = req_output.outputs[0].multimodal_output
106+
assert mm is not None
107+
results.append(mm["audio"].cpu())
108+
109+
assert len(results) == 2, f"Expected 2 outputs, got {len(results)}"
110+
for i, audio in enumerate(results):
111+
assert audio.numel() > 0, f"Audio {i} is empty"
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""
4+
E2E online tests for SongGen via the /v1/audio/speech endpoint.
5+
6+
SongGen maps the OpenAI speech contract onto text-to-song generation:
7+
- ``input`` -> song lyrics (required)
8+
- ``instructions`` -> music style / genre description (optional)
9+
- ``ref_audio`` -> reference voice for timbre conditioning (optional)
10+
11+
The server resolves ``ref_audio`` to a waveform and forwards it to the model as
12+
``ref_voice_array`` (the key the model reads); a minimal non-streaming WAV case
13+
is enough to exercise the full serving path end to end. These tests are gated
14+
behind ``full_model`` / ``tts`` and run only in the model CI lane.
15+
"""
16+
17+
import os
18+
19+
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
20+
21+
import pytest
22+
23+
from tests.helpers.mark import hardware_test
24+
from tests.helpers.runtime import OmniServerParams
25+
from tests.helpers.stage_config import get_deploy_config_path
26+
27+
pytestmark = [pytest.mark.full_model, pytest.mark.tts]
28+
29+
MODEL = "LiuZH-19/SongGen_mixed_pro"
30+
LYRICS = "Under the moonlight, we dance through the night, stars above shining bright."
31+
DESCRIPTION = "dreamy pop ballad with piano and strings"
32+
33+
# A 16 kHz song clip is far larger than this floor; the check only guards
34+
# against an empty or truncated response, not audio quality.
35+
_MIN_AUDIO_BYTES = 20_000
36+
37+
songgen_server_params = [
38+
pytest.param(
39+
OmniServerParams(
40+
model=MODEL,
41+
stage_config_path=get_deploy_config_path("songgen.yaml"),
42+
server_args=["--disable-log-stats"],
43+
),
44+
id="songgen",
45+
)
46+
]
47+
48+
49+
@hardware_test(res={"cuda": "H100"}, num_cards=1)
50+
@pytest.mark.parametrize("omni_server", songgen_server_params, indirect=True)
51+
def test_text_to_song_001(omni_server, openai_client) -> None:
52+
"""
53+
Text-to-song via /v1/audio/speech (lyrics + style description).
54+
Deploy Setting: default yaml
55+
Input Modal: text (lyrics) + instructions (style description)
56+
Output Modal: audio (16 kHz, WAV)
57+
Input Setting: stream=False
58+
Datasets: single request
59+
"""
60+
request_config = {
61+
"model": omni_server.model,
62+
"input": LYRICS,
63+
"instructions": DESCRIPTION,
64+
"stream": False,
65+
"response_format": "wav",
66+
"min_audio_bytes": _MIN_AUDIO_BYTES,
67+
}
68+
69+
openai_client.send_audio_speech_request(request_config)

tests/entrypoints/openai_api/test_serving_speech.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2926,3 +2926,58 @@ def test_diffusion_instance_shutdown_safe(self, mocker: MockerFixture):
29262926
server = OmniOpenAIServingSpeech.for_diffusion(diffusion_engine=mocker.MagicMock(), model_name="test-model")
29272927
assert server._tts_executor is None
29282928
server.shutdown() # Should not raise
2929+
2930+
2931+
class TestSongGenParams:
2932+
"""Unit tests for ``_build_songgen_params`` (issue #3388 review).
2933+
2934+
These guard two regressions in the SongGen serving path:
2935+
* the talker must receive the resolved reference waveform under the
2936+
``ref_voice_array`` key the model actually reads, not a ``ref_voice_url``
2937+
key it silently ignores; and
2938+
* the no-reference path must not emit any voice-conditioning key.
2939+
2940+
A lightweight stub ``self`` (only ``_resolve_ref_audio`` is touched) keeps
2941+
the test on the CPU/core CI lane with no GPU, weights, or network.
2942+
"""
2943+
2944+
@staticmethod
2945+
def _build(request, resolved=([0.0, 0.25, -0.25, 0.5], 24000)):
2946+
calls: list[str] = []
2947+
2948+
async def _fake_resolve(ref_audio_str):
2949+
calls.append(ref_audio_str)
2950+
return resolved
2951+
2952+
stub = SimpleNamespace(_resolve_ref_audio=_fake_resolve)
2953+
params = asyncio.run(OmniOpenAIServingSpeech._build_songgen_params(stub, request))
2954+
return params, calls
2955+
2956+
def test_without_ref_audio_emits_no_voice_key(self):
2957+
request = OpenAICreateSpeechRequest(
2958+
input="la la la under the moonlight",
2959+
instructions="dreamy pop ballad with piano",
2960+
)
2961+
params, calls = self._build(request)
2962+
assert params["lyrics"] == ["la la la under the moonlight"]
2963+
assert params["text_description"] == ["dreamy pop ballad with piano"]
2964+
assert "ref_voice_array" not in params
2965+
assert "ref_voice_url" not in params
2966+
assert calls == []
2967+
2968+
def test_ref_audio_resolved_to_array_not_url(self):
2969+
request = OpenAICreateSpeechRequest(
2970+
input="sing me a song",
2971+
instructions="a pop song",
2972+
ref_audio="data:audio/wav;base64,AAAA",
2973+
)
2974+
params, calls = self._build(request)
2975+
# The model consumes ref_voice_array=[[wav, sr]]; ref_voice_url is dead.
2976+
assert "ref_voice_url" not in params
2977+
assert params["ref_voice_array"] == [[[0.0, 0.25, -0.25, 0.5], 24000]]
2978+
assert calls == ["data:audio/wav;base64,AAAA"]
2979+
2980+
def test_missing_instructions_default_empty_description(self):
2981+
request = OpenAICreateSpeechRequest(input="just lyrics")
2982+
params, _ = self._build(request)
2983+
assert params["text_description"] == [""]

tests/helpers/runtime.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1767,7 +1767,7 @@ def send_audio_speech_request(self, request_config: dict[str, Any], request_num:
17671767
# Qwen3-TTS custom fields, forwarded via extra_body.
17681768
extra_body: dict[str, Any] = {}
17691769
# Keep this list aligned with vllm_omni.entrypoints.openai.protocol.audio params.
1770-
for key in ("task_type", "ref_text", "ref_audio", "language", "max_new_tokens", "seed"):
1770+
for key in ("task_type", "ref_text", "ref_audio", "instructions", "language", "max_new_tokens", "seed"):
17711771
if key in request_config:
17721772
extra_body[key] = request_config[key]
17731773

tests/test_config_factory.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -913,6 +913,18 @@ def test_registry_loads_pipeline_on_getitem(self):
913913
assert pipeline.model_type == "qwen3_omni_moe"
914914
assert len(pipeline.stages) == 3 # thinker + talker + code2wav
915915

916+
def test_registry_has_songgen(self):
917+
"""SongGen's single-stage pipeline is registered and resolvable.
918+
919+
Without this central entry, ``model_type='songgen'`` raises KeyError and
920+
the model cannot be served at all (regression from issue #3388 review).
921+
"""
922+
assert "songgen" in _PIPELINE_REGISTRY
923+
pipeline = _PIPELINE_REGISTRY["songgen"]
924+
assert pipeline.model_type == "songgen"
925+
assert len(pipeline.stages) == 1 # single-stage AR generator
926+
assert pipeline.stages[0].final_output_type == "audio"
927+
916928
def test_registry_returns_none_for_unknown(self):
917929
"""Unknown model_types aren't found; ``get()`` returns None."""
918930
assert "definitely_not_a_real_model" not in _PIPELINE_REGISTRY

vllm_omni/config/pipeline_registry.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,11 @@
141141
"vllm_omni.model_executor.models.moss_tts.pipeline",
142142
"MOSS_TTS_REALTIME_PIPELINE",
143143
),
144+
# SongGen (text-to-song): single-stage AR generator, MOSS-TTS-Nano lineage.
145+
"songgen": (
146+
"vllm_omni.model_executor.models.songgen.pipeline",
147+
"SONGGEN_PIPELINE",
148+
),
144149
"minicpmo_4_5": (
145150
"vllm_omni.model_executor.models.minicpmo_4_5.pipeline",
146151
"MINICPMO_4_5_PIPELINE",

vllm_omni/entrypoints/openai/serving_speech.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1890,14 +1890,24 @@ async def _build_songgen_params(self, request: OpenAICreateSpeechRequest) -> dic
18901890
Returns a dict with keys expected by SongGenForGeneration._create_stream_gen():
18911891
lyrics : list[str] - song lyrics (from request.input)
18921892
text_description : list[str] - style / genre description
1893-
ref_voice_url : list[str] - reference voice audio URL (if provided)
1893+
ref_voice_array : list[[list[float], int]] - resolved reference
1894+
voice waveform + sample rate (only when ref_audio
1895+
is provided)
1896+
1897+
The model reads ``ref_voice_array`` (a resolved ``[wav_samples, sr]``
1898+
pair that it stages to a temp WAV), not a URL. We resolve
1899+
``request.ref_audio`` to a waveform here so the server path matches the
1900+
offline example (``ref_voice_array=[[wav, sr]]``) and the MOSS-TTS
1901+
``prompt_audio_array`` convention; emitting ``ref_voice_url`` instead
1902+
would be silently dropped by the model.
18941903
"""
18951904
params: dict = {
18961905
"lyrics": [request.input],
18971906
"text_description": [request.instructions or ""],
18981907
}
18991908
if request.ref_audio is not None:
1900-
params["ref_voice_url"] = [request.ref_audio]
1909+
wav_list, sr = await self._resolve_ref_audio(request.ref_audio)
1910+
params["ref_voice_array"] = [[wav_list, sr]]
19011911
return params
19021912

19031913
async def _build_higgs_audio_v2_params(self, request: OpenAICreateSpeechRequest):

vllm_omni/model_executor/models/songgen/modeling_songgen.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,20 +113,47 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
113113
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
114114
self._device = device
115115

116-
logger.info("Loading SongGen from %s on %s", self.model_path, device)
117-
116+
# Match the MOSS-TTS convention: bf16 on bf16-capable CUDA, fp16 on
117+
# older CUDA, fp32 on CPU. The deploy config targets H100/A100, so
118+
# the bf16 path is the common one in practice.
119+
if device.type == "cuda" and torch.cuda.is_bf16_supported():
120+
model_dtype = torch.bfloat16
121+
elif device.type == "cuda":
122+
model_dtype = torch.float16
123+
else:
124+
model_dtype = torch.float32
125+
126+
# SongGen ships two architectures (Mixed and DualTrack). Both are
127+
# served by this wrapper, but they are distinct upstream classes, so
128+
# load the one named in the checkpoint config instead of always
129+
# loading Mixed (which would load the wrong weights for a DualTrack
130+
# checkpoint).
131+
architectures = list(getattr(self.config, "architectures", None) or [])
132+
want_dualtrack = any("DualTrack" in arch for arch in architectures)
118133
try:
119-
from songgen import SongGenMixedForConditionalGeneration, SongGenProcessor
134+
from songgen import SongGenProcessor
135+
136+
if want_dualtrack:
137+
from songgen import SongGenDualTrackForConditionalGeneration as _SongGenModelClass
138+
else:
139+
from songgen import SongGenMixedForConditionalGeneration as _SongGenModelClass
120140
except ImportError as exc:
121141
raise ImportError(
122142
"SongGen requires the 'songgen' package. "
123143
"Install it from: pip install git+https://github.com/LiuZH-19/SongGen.git"
124144
) from exc
125145

126-
model = SongGenMixedForConditionalGeneration.from_pretrained(
146+
logger.info(
147+
"Loading SongGen (%s) from %s on %s (dtype=%s)",
148+
_SongGenModelClass.__name__,
149+
self.model_path,
150+
device,
151+
model_dtype,
152+
)
153+
model = _SongGenModelClass.from_pretrained(
127154
self.model_path,
128155
attn_implementation="sdpa",
129-
torch_dtype=torch.float32,
156+
torch_dtype=model_dtype,
130157
)
131158
model.to(device=device)
132159
model.eval()
@@ -200,8 +227,16 @@ def _create_stream_gen(self, info: dict[str, Any]):
200227
separate=False,
201228
return_tensors="pt",
202229
)
230+
# Move every tensor to the model device. Floating-point inputs (e.g.
231+
# reference-voice features) are also cast to the model dtype so the
232+
# bf16/fp16 weight path does not hit a dtype mismatch inside
233+
# generate(); integer token ids keep their dtype.
234+
model_dtype = next(self._model.parameters()).dtype
203235
model_inputs = {
204-
k: v.to(self._device) if isinstance(v, torch.Tensor) else v for k, v in model_inputs.items()
236+
k: (v.to(device=self._device, dtype=model_dtype) if v.is_floating_point() else v.to(self._device))
237+
if isinstance(v, torch.Tensor)
238+
else v
239+
for k, v in model_inputs.items()
205240
}
206241

207242
output = self._model.generate(**model_inputs, do_sample=_DEFAULT_DO_SAMPLE)

0 commit comments

Comments
 (0)