Skip to content

Commit 1ba6d08

Browse files
committed
add kws with transformers, and bitnet train script update
1 parent 8c5cf87 commit 1ba6d08

9 files changed

Lines changed: 425 additions & 53 deletions

File tree

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ temp
33
data
44

55
.vscode
6-
6+
*.log
77
output/

MiCoCodeGen.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -497,8 +497,14 @@ def handle_call_method(self, n: torch.fx.node.Node, out: torch.Tensor):
497497
self.add_forward_call("MiCo_CONNECT", out, n.name, [src_name])
498498
elif method == "mean":
499499
src_name = input_names[0]
500-
dim = self._resolve_arg_value(n.args[1])
501-
keepdim = self._resolve_arg_value(n.args[2]) if len(n.args) > 2 else False
500+
dim = self._resolve_arg_value(n.args[1]) if len(n.args) > 1 else self._resolve_arg_value(n.kwargs.get("dim", None))
501+
keepdim = (
502+
self._resolve_arg_value(n.args[2])
503+
if len(n.args) > 2
504+
else self._resolve_arg_value(n.kwargs.get("keepdim", False))
505+
)
506+
if dim is None:
507+
raise NotImplementedError("Mean over all elements is not supported")
502508
self.add_uninitialized_tensor(n.name, out)
503509
if keepdim:
504510
self.add_forward_call(f"MiCo_meankp{out.dim()}d_{{dtype}}", out, n.name, [src_name], [dim])

MiCoRegistry.py

Lines changed: 68 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -233,24 +233,51 @@ def handle_tanh(codegen, n, out, input_names, input_args):
233233
codegen.add_forward_call("MiCo_tanh{dim}d_{dtype}", out, n.name, input_names)
234234

235235

236+
def _extract_scalar_param(param, param_name, default=None):
237+
"""Extract a scalar C API parameter from PyTorch int/tuple pooling args."""
238+
if param is None:
239+
if default is None:
240+
raise ValueError(f"{param_name} cannot be None")
241+
param = default
242+
243+
if isinstance(param, torch.fx.node.Node):
244+
raise ValueError(f"Unresolved FX node for {param_name}: {param}")
245+
246+
if isinstance(param, torch.Size):
247+
param = tuple(param)
248+
249+
if isinstance(param, (tuple, list)):
250+
if len(param) == 0:
251+
raise ValueError(f"{param_name} cannot be empty")
252+
first = param[0]
253+
if any(value != first for value in param):
254+
raise NotImplementedError(
255+
f"MiCo C pooling kernels only support scalar/symmetric {param_name}, got {param}"
256+
)
257+
param = first
258+
259+
if isinstance(param, bool) or not isinstance(param, int):
260+
raise ValueError(f"Unexpected {param_name} type: {type(param)}")
261+
return param
262+
263+
236264
def _extract_kernel_size(param):
237-
"""Helper to extract kernel size from tuple or int parameter."""
238-
if isinstance(param, Tuple):
239-
return param[0]
240-
elif isinstance(param, int):
241-
return param
242-
else:
243-
raise ValueError(f"Unexpected kernel_size type: {type(param)}")
265+
"""Helper to extract scalar kernel size for the C pooling API."""
266+
return _extract_scalar_param(param, "kernel_size")
244267

245268

246269
def _extract_output_size(param):
247-
"""Helper to extract output size from tuple or int parameter."""
248-
if isinstance(param, Tuple):
249-
return param[0]
250-
elif isinstance(param, int):
251-
return param
270+
"""Helper to extract scalar output size for the C adaptive pooling API."""
271+
return _extract_scalar_param(param, "output_size")
272+
273+
274+
def _pool_arg(n, input_args, index, name, default=None):
275+
"""Read pooling arg from positional or keyword FX args and normalize it."""
276+
if len(input_args) > index:
277+
value = input_args[index]
252278
else:
253-
raise ValueError(f"Unexpected output_size type: {type(param)}")
279+
value = n.kwargs.get(name, default)
280+
return _extract_scalar_param(value, name, default)
254281

255282

256283
@MiCoOpRegistry.register_function(torch.nn.functional.linear)
@@ -275,20 +302,22 @@ def handle_linear(codegen, n, out, input_names, input_args):
275302
def handle_avg_pool2d(codegen, n, out, input_names, input_args):
276303
"""Handler for 2D average pooling function."""
277304
codegen.add_uninitialized_tensor(n.name, out)
278-
kernel_size = _extract_kernel_size(input_args[1])
279-
stride = input_args[2] if len(input_args) > 2 else 1
305+
kernel_size = _pool_arg(n, input_args, 1, "kernel_size")
306+
stride = _pool_arg(n, input_args, 2, "stride", kernel_size)
307+
padding = _pool_arg(n, input_args, 3, "padding", 0)
280308
codegen.add_forward_call("MiCo_avgpool{dim}d_{dtype}", out, n.name, input_names,
281-
[kernel_size, stride])
309+
[kernel_size, stride, padding])
282310

283311

284312
@MiCoOpRegistry.register_function(torch.nn.functional.max_pool2d)
285313
def handle_max_pool2d(codegen, n, out, input_names, input_args):
286314
"""Handler for 2D max pooling function."""
287315
codegen.add_uninitialized_tensor(n.name, out)
288-
kernel_size = _extract_kernel_size(input_args[1])
289-
stride = input_args[2] if len(input_args) > 2 else 1
316+
kernel_size = _pool_arg(n, input_args, 1, "kernel_size")
317+
stride = _pool_arg(n, input_args, 2, "stride", kernel_size)
318+
padding = _pool_arg(n, input_args, 3, "padding", 0)
290319
codegen.add_forward_call("MiCo_maxpool{dim}d_{dtype}", out, n.name, input_names,
291-
[kernel_size, stride])
320+
[kernel_size, stride, padding])
292321

293322

294323
@MiCoOpRegistry.register_function(torch.nn.functional.adaptive_avg_pool2d)
@@ -303,20 +332,22 @@ def handle_adaptive_avg_pool2d(codegen, n, out, input_names, input_args):
303332
def handle_avg_pool1d(codegen, n, out, input_names, input_args):
304333
"""Handler for 1D average pooling function."""
305334
codegen.add_uninitialized_tensor(n.name, out)
306-
kernel_size = _extract_kernel_size(input_args[1])
307-
stride = input_args[2] if len(input_args) > 2 else 1
335+
kernel_size = _pool_arg(n, input_args, 1, "kernel_size")
336+
stride = _pool_arg(n, input_args, 2, "stride", kernel_size)
337+
padding = _pool_arg(n, input_args, 3, "padding", 0)
308338
codegen.add_forward_call("MiCo_avgpool{dim}d_{dtype}", out, n.name, input_names,
309-
[kernel_size, stride])
339+
[kernel_size, stride, padding])
310340

311341

312342
@MiCoOpRegistry.register_function(torch.nn.functional.max_pool1d)
313343
def handle_max_pool1d(codegen, n, out, input_names, input_args):
314344
"""Handler for 1D max pooling function."""
315345
codegen.add_uninitialized_tensor(n.name, out)
316-
kernel_size = _extract_kernel_size(input_args[1])
317-
stride = input_args[2] if len(input_args) > 2 else 1
346+
kernel_size = _pool_arg(n, input_args, 1, "kernel_size")
347+
stride = _pool_arg(n, input_args, 2, "stride", kernel_size)
348+
padding = _pool_arg(n, input_args, 3, "padding", 0)
318349
codegen.add_forward_call("MiCo_maxpool{dim}d_{dtype}", out, n.name, input_names,
319-
[kernel_size, stride])
350+
[kernel_size, stride, padding])
320351

321352

322353
@MiCoOpRegistry.register_function(torch.nn.functional.adaptive_avg_pool1d)
@@ -550,8 +581,10 @@ def handle_avgpool2d_module(codegen, n, out, module, input_names):
550581
layer_name = n.name
551582
codegen.add_uninitialized_tensor(layer_name, out)
552583
kernel_size = _extract_kernel_size(module.kernel_size)
584+
stride = _extract_scalar_param(module.stride, "stride", kernel_size)
585+
padding = _extract_scalar_param(module.padding, "padding", 0)
553586
codegen.add_forward_call("MiCo_avgpool{dim}d_{dtype}", out, layer_name, input_names,
554-
[kernel_size, module.stride, module.padding])
587+
[kernel_size, stride, padding])
555588

556589

557590
@MiCoOpRegistry.register_module(torch.nn.MaxPool2d)
@@ -560,8 +593,10 @@ def handle_maxpool2d_module(codegen, n, out, module, input_names):
560593
layer_name = n.name
561594
codegen.add_uninitialized_tensor(layer_name, out)
562595
kernel_size = _extract_kernel_size(module.kernel_size)
596+
stride = _extract_scalar_param(module.stride, "stride", kernel_size)
597+
padding = _extract_scalar_param(module.padding, "padding", 0)
563598
codegen.add_forward_call("MiCo_maxpool{dim}d_{dtype}", out, layer_name, input_names,
564-
[kernel_size, module.stride, module.padding])
599+
[kernel_size, stride, padding])
565600

566601

567602
@MiCoOpRegistry.register_module(torch.nn.AdaptiveAvgPool2d)
@@ -579,8 +614,10 @@ def handle_avgpool1d_module(codegen, n, out, module, input_names):
579614
layer_name = n.name
580615
codegen.add_uninitialized_tensor(layer_name, out)
581616
kernel_size = _extract_kernel_size(module.kernel_size)
617+
stride = _extract_scalar_param(module.stride, "stride", kernel_size)
618+
padding = _extract_scalar_param(module.padding, "padding", 0)
582619
codegen.add_forward_call("MiCo_avgpool{dim}d_{dtype}", out, layer_name, input_names,
583-
[kernel_size, module.stride, module.padding])
620+
[kernel_size, stride, padding])
584621

585622

586623
@MiCoOpRegistry.register_module(torch.nn.MaxPool1d)
@@ -589,8 +626,10 @@ def handle_maxpool1d_module(codegen, n, out, module, input_names):
589626
layer_name = n.name
590627
codegen.add_uninitialized_tensor(layer_name, out)
591628
kernel_size = _extract_kernel_size(module.kernel_size)
629+
stride = _extract_scalar_param(module.stride, "stride", kernel_size)
630+
padding = _extract_scalar_param(module.padding, "padding", 0)
592631
codegen.add_forward_call("MiCo_maxpool{dim}d_{dtype}", out, layer_name, input_names,
593-
[kernel_size, module.stride, module.padding])
632+
[kernel_size, stride, padding])
594633

595634

596635
@MiCoOpRegistry.register_module(torch.nn.AdaptiveAvgPool1d)

examples/mpq_train_bitnet.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
argsparse.add_argument("--lr", type=float, default=0.001)
1515
argsparse.add_argument("-q", "--weight_quant", type=float, choices=[1,1.5,2], default=1)
1616
argsparse.add_argument("-aq", "--act_quant", type=int, choices=[4,8], default=8)
17+
argsparse.add_argument("--use-norm", action="store_true", default=False)
1718
argsparse.add_argument("--keep-last", action="store_true", default=False)
1819
argsparse.add_argument("--keep-first", action="store_true", default=False)
1920
argsparse.add_argument("--scheduler", type=str, default="none")
@@ -26,6 +27,7 @@
2627
scheduler = args.scheduler
2728
weight_quant = args.weight_quant
2829
act_quant = args.act_quant
30+
use_norm = args.use_norm
2931
keep_last = args.keep_last
3032
keep_first = args.keep_first
3133

@@ -48,7 +50,7 @@
4850
qscheme[0][-1] = 8
4951
qscheme[1][-1] = 8
5052

51-
model.set_qscheme(qscheme, qat=True, use_norm=True)
53+
model.set_qscheme(qscheme, qat=True, use_norm=use_norm)
5254
print("Model Param Size:", sum(p.numel() for p in model.parameters()))
5355
# Detect if there is a full precision checkpoint
5456
# if os.path.exists(f"output/ckpt/{model_name}.pth"):
@@ -75,7 +77,7 @@
7577
torch.save(model.state_dict(), f"output/ckpt/{model_name}_bitnet.pth")
7678
print("Model Train Results: ", res)
7779

78-
model.set_qscheme(qscheme, qat=True, use_norm=True)
80+
model.set_qscheme(qscheme, qat=True, use_norm=use_norm)
7981

8082
res = model.test(test_loader)
8183

models/KWSTransformer.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import torch
2+
from torch import nn
3+
4+
from MiCoModel import MiCoModel
5+
from models.CCT import TransformerClassifier
6+
7+
8+
class MFCCTokenizer(nn.Module):
    """Convolutional tokenizer that turns a 2-D feature map into a token sequence.

    The network is ``n_conv_layers`` repetitions of
    ``Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d``. The two trailing spatial
    dimensions of the result are flattened into one token axis.
    (The class name suggests MFCC inputs; nothing in the code depends on that.)

    Args:
        in_channels: Channel count of the input feature map.
        embedding_dim: Output channels of every convolution stage, which is
            also the per-token feature size.
        n_conv_layers: Number of conv/norm/activation/pool stages.
        kernel_size, stride, padding: Passed to each ``nn.Conv2d``.
        pooling_kernel_size, pooling_stride, pooling_padding: Passed to each
            ``nn.MaxPool2d``.
    """

    def __init__(
        self,
        in_channels: int = 1,
        embedding_dim: int = 96,
        n_conv_layers: int = 2,
        kernel_size=(3, 3),
        stride=(1, 1),
        padding=(1, 1),
        pooling_kernel_size=(2, 2),
        pooling_stride=(2, 2),
        pooling_padding=(0, 0),
    ):
        super().__init__()

        # Kept so sequence_length() can build a probe with the right channels.
        self.in_channels = in_channels

        channels = [in_channels] + [embedding_dim] * n_conv_layers
        layers = []
        for c_in, c_out in zip(channels, channels[1:]):
            layers.extend(
                [
                    nn.Conv2d(
                        c_in,
                        c_out,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=padding,
                        bias=False,
                    ),
                    nn.BatchNorm2d(c_out),
                    nn.ReLU(),
                    nn.MaxPool2d(
                        kernel_size=pooling_kernel_size,
                        stride=pooling_stride,
                        padding=pooling_padding,
                    ),
                ]
            )

        self.layers = nn.Sequential(*layers)
        self.flatten = nn.Flatten(2, 3)

    def sequence_length(self, input_size):
        """Return the number of tokens produced for an input of ``input_size``.

        A zero tensor of shape ``(1, in_channels, input_size[0], input_size[1])``
        is pushed through the tokenizer and the token axis is measured.

        The probe runs in eval mode so that BatchNorm running statistics and
        ``num_batches_tracked`` are not updated by the dummy input. Every
        submodule's original train/eval flag is restored afterwards.
        """
        # Match the module's device and dtype so this also works after
        # .to(...) or .half().
        ref = next(self.parameters(), None)
        probe_kwargs = {}
        if ref is not None:
            probe_kwargs = {"device": ref.device, "dtype": ref.dtype}
        probe = torch.zeros(
            1, self.in_channels, input_size[0], input_size[1], **probe_kwargs
        )

        # Remember per-module flags so mixed train/eval setups are kept intact.
        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                return self.forward(probe).shape[1]
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

    def forward(self, x):
        """Map a 4-D input ``(N, C, H, W)`` to tokens ``(N, H' * W', embedding_dim)``."""
        x = self.layers(x)
        # Merge the spatial dims, then move the token axis before the channels.
        return self.flatten(x).transpose(1, 2)
57+
58+
59+
class KWSTransformer(MiCoModel):
    """Keyword-spotting model: MFCCTokenizer front end plus TransformerClassifier.

    Args:
        n_classes: Number of output classes, passed on as ``num_classes``.
        input_size: Two-element spatial size of the input; stored as a tuple
            and used to compute the classifier's sequence length.
        embedding_dim: Token feature size shared by tokenizer and classifier.
        n_conv_layers: Number of convolution stages in the tokenizer.
        num_layers, num_heads, mlp_ratio, dropout, attention_dropout,
        stochastic_depth, positional_embedding: Passed unchanged to
            ``TransformerClassifier``.
    """

    def __init__(
        self,
        n_classes: int = 35,
        input_size=(64, 81),
        embedding_dim: int = 96,
        n_conv_layers: int = 2,
        num_layers: int = 2,
        num_heads: int = 4,
        mlp_ratio: float = 2.0,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        stochastic_depth: float = 0.0,
        positional_embedding: str = "learnable",
    ):
        super().__init__()
        self.default_dataset = "SPEECHCOMMANDS_2D"
        self.input_size = tuple(input_size)

        # The tokenizer is built with a single input channel.
        self.tokenizer = MFCCTokenizer(
            in_channels=1,
            embedding_dim=embedding_dim,
            n_conv_layers=n_conv_layers,
        )

        # The classifier needs the token count for this input size.
        n_tokens = self.tokenizer.sequence_length(self.input_size)
        self.classifier = TransformerClassifier(
            sequence_length=n_tokens,
            embedding_dim=embedding_dim,
            seq_pool=True,
            dropout=dropout,
            attention_dropout=attention_dropout,
            stochastic_depth=stochastic_depth,
            num_layers=num_layers,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            num_classes=n_classes,
            positional_embedding=positional_embedding,
        )

        # Counted last so layers from both submodules are included.
        self.n_layers = len(self.get_qlayers())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Tokenize ``x`` and return the classifier output for the tokens."""
        tokens = self.tokenizer(x)
        return self.classifier(tokens)
101+
102+
103+
def tiny_kws_transformer(n_classes: int = 35):
    """Build the small KWSTransformer preset for ``n_classes`` output classes.

    The preset uses a 64-dim embedding, two conv stages, two transformer
    layers with four heads, an MLP ratio of 2.0, and 0.1 dropout for both the
    regular and attention dropout. All other options keep the
    ``KWSTransformer`` defaults.
    """
    preset = {
        "embedding_dim": 64,
        "n_conv_layers": 2,
        "num_layers": 2,
        "num_heads": 4,
        "mlp_ratio": 2.0,
        "dropout": 0.1,
        "attention_dropout": 0.1,
    }
    return KWSTransformer(n_classes=n_classes, **preset)

0 commit comments

Comments
 (0)