Skip to content

Commit ccfc811

Browse files
Merge pull request #2 from hubertsiuzdak/v1.0
v1.0
2 parents ccd0844 + 85247e0 commit ccfc811

5 files changed

Lines changed: 110 additions & 17 deletions

File tree

README.md

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
# [WIP] SNAC 🍿
1+
# SNAC 🍿
22

3-
Multi-**S**cale **N**eural **A**udio **C**odec (SNAC) compresses 44.1 kHz audio into discrete codes at a low bitrate.
3+
Multi-**S**cale **N**eural **A**udio **C**odec (SNAC) compresses audio into discrete codes at a low bitrate.
44

55
## Overview
66

@@ -14,6 +14,15 @@ consistent structure of an audio track for ~3 minutes.
1414

1515
![snac.png](img%2Fsnac.png)
1616

17+
## Pretrained models
18+
19+
| Model | Bitrate | Sample Rate |
20+
|-----------------------------------------------------------------------------|----------|-------------|
21+
| [hubertsiuzdak/snac_32khz](https://huggingface.co/hubertsiuzdak/snac_32khz) | 1.9 kbps | 32 kHz |
22+
| [hubertsiuzdak/snac_44khz](https://huggingface.co/hubertsiuzdak/snac_44khz) | 2.6 kbps | 44.1 kHz |
23+
24+
These models were trained mostly on music.
25+
1726
## Usage
1827

1928
Install it using:
@@ -22,18 +31,14 @@ Install it using:
2231
pip install snac
2332
```
2433

25-
A pretrained model that compresses audio into discrete codes at a 2.2 kbps bitrate is available
26-
at [Hugging Face](https://huggingface.co/hubertsiuzdak/snac). It uses 4 RVQ levels with token rates of 12.5, 25, 50, and
27-
100 Hz.
28-
2934
To encode (and reconstruct) audio with SNAC in Python, use the following code:
3035

3136
```python
3237
import torch
3338
from snac import SNAC
3439

35-
model = SNAC.from_pretrained("hubertsiuzdak/snac").eval().cuda()
36-
audio = torch.randn(1, 1, 44100).cuda() # B, 1, T
40+
model = SNAC.from_pretrained("hubertsiuzdak/snac_32khz").eval().cuda()
41+
audio = torch.randn(1, 1, 32000).cuda() # B, 1, T
3742

3843
with torch.inference_mode():
3944
audio_hat, _, codes, _, _ = model(audio)
@@ -44,7 +49,7 @@ resolution.
4449

4550
```
4651
>>> [code.shape[1] for code in codes]
47-
[13, 26, 52, 104]
52+
[12, 24, 48, 96]
4853
```
4954

5055
## Acknowledgements

snac/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .snac import SNAC
22

3-
__version__ = "0.1.0"
3+
__version__ = "1.0.0"

snac/attention.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import torch
2+
from einops import rearrange
3+
from torch import nn
4+
5+
6+
class LocalMHA(nn.Module):
    """Pre-norm multi-head self-attention over non-overlapping time windows.

    ``forward`` takes and returns a ``(B, C, T)`` tensor; the attention output
    is added back onto the input as a residual.

    Args:
        dim: Channel count ``C``; width of the LayerNorm and of the QKV and
            output projections.
        window_size: Nominal window length. ``forward`` splits the time axis
            into ``T // window_size`` windows.
        dim_head: Width of one head; the head count is ``dim // dim_head``.
        use_rotary_pos_emb: When true, queries and keys are rotated with
            ``SinusoidalEmbeddings`` angles before attention.
    """

    def __init__(self, dim=1024, window_size=32, dim_head=64, use_rotary_pos_emb=True):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.heads = dim // dim_head
        self.window_size = window_size
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.rel_pos = (
            SinusoidalEmbeddings(dim_head, scale_base=window_size // 2) if use_rotary_pos_emb else None
        )
        self.to_out = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        """Attend within each window and return ``x`` plus the projected result.

        The per-window length is inferred from ``T`` and the window count, so
        ``T`` has to be divisible by ``T // window_size``; otherwise the
        reshape raises.
        """
        _, _, frames = x.shape
        n_windows = frames // self.window_size
        normed = self.norm(x.transpose(1, 2))

        def split_heads_and_windows(t):
            # (B, T, H*D) -> (B, T, H, D) -> (B, W, N, H, D) -> (B, H, W, N, D)
            t = t.unflatten(-1, (self.heads, -1)).unflatten(1, (n_windows, -1))
            return t.permute(0, 3, 1, 2, 4)

        q, k, v = (split_heads_and_windows(t) for t in self.to_qkv(normed).chunk(3, dim=-1))
        if self.rel_pos is not None:
            pos_emb, scale = self.rel_pos(k)
            q, k = apply_rotary_pos_emb(q, k, pos_emb, scale)

        # Attention runs over the last two dims, i.e. independently per window.
        attended = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        # (B, H, W, N, D) -> (B, W, N, H, D) -> (B, T, H*D)
        merged = attended.permute(0, 2, 3, 1, 4).flatten(3).flatten(1, 2)
        return self.to_out(merged).transpose(1, 2) + x
33+
34+
35+
class SinusoidalEmbeddings(nn.Module):
    """Produce rotary angle tables and, optionally, xpos-style scale tables.

    Args:
        dim: Feature width the tables are built for; ``dim / 2`` frequencies
            are generated and then duplicated to width ``dim``.
        scale_base: Divisor for the position offset in the xpos exponent.
            Required when ``use_xpos`` is true.
        use_xpos: When true, ``forward`` also returns a per-position scale.
    """

    def __init__(self, dim, scale_base=None, use_xpos=False):
        super().__init__()
        even_idx = torch.arange(0, dim, 2)
        self.register_buffer("inv_freq", 1.0 / (10000 ** (even_idx.float() / dim)))
        # xpos related
        self.use_xpos = use_xpos
        self.scale_base = scale_base
        assert not (use_xpos and scale_base is None), "scale base must be defined if using xpos"
        # Non-persistent: kept out of the state dict.
        self.register_buffer("scale", (even_idx + 0.4 * dim) / (1.4 * dim), persistent=False)

    def forward(self, x):
        """Return ``(freqs, scale)`` for the sequence length ``x.shape[-2]``.

        Only the shape and device of ``x`` are read. ``freqs`` has shape
        ``(seq_len, dim)``. ``scale`` is a one-element tensor of ones unless
        xpos is enabled, in which case it has shape ``(seq_len, dim)``.
        """
        seq_len = x.shape[-2]
        positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        angles = torch.outer(positions, self.inv_freq)
        freqs = torch.cat((angles, angles), dim=-1)
        if self.use_xpos:
            # Exponent is the position offset from the sequence centre.
            power = (positions - (seq_len // 2)) / self.scale_base
            decay = self.scale ** power.unsqueeze(-1)
            return freqs, torch.cat((decay, decay), dim=-1)
        return freqs, torch.ones(1, device=x.device)
59+
60+
61+
def rotate_half(x):
    """Split the last dim into halves ``(a, b)`` and return ``(-b, a)`` concatenated.

    The last dimension must have even length.
    """
    first, second = x.unflatten(-1, (2, -1)).unbind(dim=-2)
    return torch.cat((-second, first), dim=-1)
65+
66+
67+
def apply_rotary_pos_emb(q, k, freqs, scale=1):
    """Rotate queries and keys by the given angles, with optional scaling.

    Args:
        q: Query tensor; positions on dim ``-2``, features on dim ``-1``.
        k: Key tensor laid out like ``q``.
        freqs: Rotation angles, broadcast against ``k``. Only the last
            ``q.shape[-2]`` rows are used for ``q``.
        scale: Multiplier applied to ``q``; its reciprocal is applied to
            ``k``. May be a plain number or a tensor. A 2-D value is treated
            as per-position and trimmed to the last ``q.shape[-2]`` rows for
            ``q``.

    Returns:
        The rotated ``(q, k)`` pair.
    """
    q_len = q.shape[-2]
    q_freqs = freqs[..., -q_len:, :]
    # The reciprocal is taken from the untrimmed scale because keys use every position.
    inv_scale = scale**-1
    # Plain numbers (including the default) have no ``ndim``; treat them as 0-d.
    if getattr(scale, "ndim", 0) == 2:
        scale = scale[-q_len:, :]
    q = (q * q_freqs.cos() * scale) + (rotate_half(q) * q_freqs.sin() * scale)
    k = (k * freqs.cos() * inv_scale) + (rotate_half(k) * freqs.sin() * inv_scale)
    return q, k

snac/layers.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,27 @@
44
import torch.nn as nn
55
from torch.nn.utils.parametrizations import weight_norm
66

7+
from .attention import LocalMHA
8+
79

810
class Encoder(nn.Module):
911
def __init__(
1012
self,
1113
d_model=64,
1214
strides=[3, 3, 7, 7],
1315
depthwise=False,
16+
attn_window_size=32,
1417
):
1518
super().__init__()
1619
layers = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
1720
for stride in strides:
1821
d_model *= 2
1922
groups = d_model // 2 if depthwise else 1
2023
layers += [EncoderBlock(output_dim=d_model, stride=stride, groups=groups)]
24+
if attn_window_size is not None:
25+
layers += [LocalMHA(dim=d_model, window_size=attn_window_size)]
2126
groups = d_model if depthwise else 1
2227
layers += [
23-
Snake1d(d_model),
2428
WNConv1d(d_model, d_model, kernel_size=7, padding=3, groups=groups),
2529
]
2630
self.block = nn.Sequential(*layers)
@@ -37,18 +41,21 @@ def __init__(
3741
rates,
3842
noise=False,
3943
depthwise=False,
44+
attn_window_size=32,
4045
d_out=1,
4146
):
4247
super().__init__()
4348
if depthwise:
4449
layers = [
50+
WNConv1d(input_channel, input_channel, kernel_size=7, padding=3, groups=input_channel),
4551
WNConv1d(input_channel, channels, kernel_size=1),
46-
Snake1d(channels),
47-
WNConv1d(channels, channels, kernel_size=7, padding=3, groups=channels),
4852
]
4953
else:
5054
layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
5155

56+
if attn_window_size is not None:
57+
layers += [LocalMHA(dim=channels, window_size=attn_window_size)]
58+
5259
for i, stride in enumerate(rates):
5360
input_dim = channels // 2**i
5461
output_dim = channels // 2 ** (i + 1)
@@ -111,13 +118,14 @@ def forward(self, x):
111118
class NoiseBlock(nn.Module):
    """Add input-dependent noise: ``x + noise * linear(x)``.

    ``linear`` is a bias-free, weight-normalised 1x1 convolution over the
    ``dim`` channels. The noise is drawn with shape ``(B, 1, T)``, so a single
    noise channel is broadcast across all channels of ``x``.
    """

    def __init__(self, dim):
        super().__init__()
        self.linear = WNConv1d(dim, dim, kernel_size=1, bias=False)

    def forward(self, x):
        """Return ``x`` with fresh gated Gaussian noise added; ``x`` is ``(B, C, T)``."""
        batch, _, frames = x.shape
        noise = torch.randn((batch, 1, frames), device=x.device, dtype=x.dtype)
        gate = self.linear(x)
        return x + noise * gate
122130

123131

snac/snac.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,21 @@
1212
class SNAC(nn.Module):
1313
def __init__(
1414
self,
15+
sampling_rate=44100,
1516
encoder_dim=64,
1617
encoder_rates=[3, 3, 7, 7],
1718
latent_dim=None,
1819
decoder_dim=1536,
1920
decoder_rates=[7, 7, 3, 3],
21+
attn_window_size=32,
2022
codebook_size=4096,
2123
codebook_dim=8,
2224
vq_strides=[8, 4, 2, 1],
2325
noise=True,
2426
depthwise=True,
2527
):
2628
super().__init__()
29+
self.sampling_rate = sampling_rate
2730
self.encoder_dim = encoder_dim
2831
self.encoder_rates = encoder_rates
2932
self.decoder_dim = decoder_dim
@@ -37,6 +40,7 @@ def __init__(
3740
self.codebook_size = codebook_size
3841
self.codebook_dim = codebook_dim
3942
self.vq_strides = vq_strides
43+
self.attn_window_size = attn_window_size
4044
self.quantizer = ResidualVectorQuantize(
4145
input_dim=latent_dim,
4246
codebook_size=codebook_size,
@@ -53,7 +57,7 @@ def __init__(
5357

5458
def preprocess(self, audio_data):
    """Zero-pad audio on the right to a length the model can process.

    The padded length is a multiple of ``hop_length * m``, where ``m`` is the
    least common multiple of the coarsest quantizer stride
    (``vq_strides[0]``) and the attention window. Using the LCM keeps the
    stride alignment that was used before local attention was added, and it
    also works when ``attn_window_size`` is ``None`` (attention disabled).

    Args:
        audio_data: Tensor whose last dimension is time.

    Returns:
        ``audio_data`` padded with zeros at the end of the last dimension;
        unchanged in length when already aligned.
    """
    length = audio_data.shape[-1]
    # ``None`` means no attention layers, hence no window constraint.
    window = self.attn_window_size or 1
    coarsest_stride = self.vq_strides[0]
    multiple = coarsest_stride * window // math.gcd(coarsest_stride, window)
    pad_to = self.hop_length * multiple
    # Integer arithmetic avoids float rounding on very long inputs.
    right_pad = -length % pad_to
    audio_data = nn.functional.pad(audio_data, (0, right_pad))
    return audio_data
@@ -76,6 +80,7 @@ def from_config(cls, config_path):
7680
@classmethod
7781
def from_pretrained(cls, repo_id, **kwargs):
7882
from huggingface_hub import hf_hub_download
83+
7984
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", **kwargs)
8085
model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin", **kwargs)
8186
model = cls.from_config(config_path)

0 commit comments

Comments
 (0)