-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathexport_voxlingua_to_onnx.py
More file actions
116 lines (91 loc) · 4.22 KB
/
Copy pathexport_voxlingua_to_onnx.py
File metadata and controls
116 lines (91 loc) · 4.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""Export speechbrain/lang-id-voxlingua107-ecapa to ONNX.
Produces:
<out-dir>/voxlingua107.onnx — end-to-end graph (raw audio → logits, embedding)
<out-dir>/lang_map.json — class index → ISO code + English name
Usage:
python export_voxlingua_to_onnx.py --out-dir ./voxlingua107
"""
from __future__ import annotations
import argparse
from pathlib import Path
import onnx
import torch
from speechbrain.inference.classifiers import EncoderClassifier
from src.conv_stft import replace_speechbrain_stft
from src.ecapa_wrapper import VoxLinguaONNX
from src.lang_map import write_lang_map
# Model identifier handed to EncoderClassifier.from_hparams as `source`.
MODEL_SOURCE = "speechbrain/lang-id-voxlingua107-ecapa"
# Length, in samples, of the random dummy waveform used both for the
# sanity-check forward pass and as the example input to torch.onnx.export.
# 3 seconds of 16 kHz audio — any non-trivial length works for tracing,
# but too short risks hitting ECAPA's internal pooling edge cases.
TRACE_SAMPLES = 16_000 * 3
def load_classifier(savedir: Path) -> EncoderClassifier:
    """Load the ``MODEL_SOURCE`` classifier, using *savedir* as its cache.

    The directory (and any missing parents) is created before SpeechBrain
    is asked to instantiate the model from it.
    """
    savedir.mkdir(parents=True, exist_ok=True)

    # Export always runs on CPU for determinism.
    cpu_only = {"device": "cpu"}
    loaded = EncoderClassifier.from_hparams(
        source=MODEL_SOURCE,
        savedir=str(savedir),
        run_opts=cpu_only,
    )
    return loaded
def export(out_dir: Path, savedir: Path) -> None:
    """Export the classifier to ONNX and write its language map.

    Steps, in order: load the SpeechBrain classifier (cached under
    *savedir*), swap its STFT module for the Conv1D replacement, verify a
    forward pass on random audio, export to ``<out_dir>/voxlingua107.onnx``,
    fold any external-data sidecar back into that file, run the ONNX
    checker, and write ``<out_dir>/lang_map.json``.

    Raises:
        AssertionError: if the sanity forward pass yields unexpected shapes.
    """
    print(f"[voxlingua] loading {MODEL_SOURCE} into {savedir}")
    classifier = load_classifier(savedir)
    classifier.eval()

    # SpeechBrain's STFT is built on torch.stft, which exports to the ONNX
    # `STFT` op. That op has no CUDA kernel, so preprocessing would fall
    # back to host memory. A Conv1D-based STFT has kernels on every EP and
    # keeps preprocessing on-device.
    feature_module = classifier.mods.compute_features
    conv_stft = replace_speechbrain_stft(feature_module.compute_STFT)
    feature_module.compute_STFT = conv_stft.eval()
    print("[voxlingua] replaced STFT op with Conv1D implementation")

    wrapper = VoxLinguaONNX(classifier).eval()
    trace_input = torch.randn(1, TRACE_SAMPLES, dtype=torch.float32)

    # Run once eagerly so shape problems surface before the export does.
    with torch.no_grad():
        logits, embedding = wrapper(trace_input)
    assert logits.shape == (1, 107), (
        f"unexpected logits shape: {logits.shape}"
    )
    assert embedding.dim() == 2 and embedding.size(0) == 1, (
        f"unexpected embedding shape: {embedding.shape}"
    )
    print(f"[voxlingua] forward ok: logits={tuple(logits.shape)}, "
          f"embedding={tuple(embedding.shape)}")

    out_dir.mkdir(parents=True, exist_ok=True)
    onnx_path = out_dir / "voxlingua107.onnx"
    print(f"[voxlingua] exporting to {onnx_path}")

    # Batch is dynamic everywhere; the audio length is dynamic as well.
    axes = {
        "audio": {0: "batch", 1: "samples"},
        "logits": {0: "batch"},
        "embedding": {0: "batch"},
    }
    torch.onnx.export(
        wrapper,
        (trace_input,),
        str(onnx_path),
        input_names=["audio"],
        output_names=["logits", "embedding"],
        dynamic_axes=axes,
        opset_version=17,
        do_constant_folding=True,
    )

    # The dynamo exporter writes weights to a sidecar <name>.onnx.data by
    # default. The model is small (~82 MB, far below the 2 GB protobuf
    # limit), so re-save with the weights inline to ship a single file.
    sidecar_path = onnx_path.with_name(onnx_path.name + ".data")
    if sidecar_path.exists():
        inlined = onnx.load(str(onnx_path), load_external_data=True)
        onnx.save(inlined, str(onnx_path), save_as_external_data=False)
        sidecar_path.unlink()
        print(f"[voxlingua] merged {sidecar_path.name} → single-file onnx")

    # Structural validation: catches malformed graphs early.
    onnx.checker.check_model(str(onnx_path))
    print("[voxlingua] onnx.checker passed")

    lang_map_path = out_dir / "lang_map.json"
    write_lang_map(classifier, lang_map_path)
    print(f"[voxlingua] wrote {lang_map_path} with 107 entries")
def main() -> None:
    """Parse the command-line options and run the export."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--out-dir",
        type=Path,
        required=True,
        help="Destination directory for voxlingua107.onnx and lang_map.json.",
    )
    parser.add_argument(
        "--savedir",
        type=Path,
        default=Path("./.voxlingua107-cache"),
        help="Where SpeechBrain caches the downloaded weights.",
    )
    cli = parser.parse_args()
    export(cli.out_dir, cli.savedir)
# Run the exporter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()