Skip to content

Commit 765883a

Browse files
committed
add SmoothQuant Support
1 parent bf25177 commit 765883a

2 files changed

Lines changed: 398 additions & 29 deletions

File tree

MiCoSmoothQuant.py

Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
import os
2+
from dataclasses import dataclass
3+
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple
4+
5+
import torch
6+
import torch.nn as nn
7+
8+
'''
9+
SmoothQuant Port for MiCo Attentions
10+
https://github.com/mit-han-lab/smoothquant/
11+
'''
12+
13+
@dataclass
14+
class SmoothQuantMapping:
15+
norm_name: str
16+
linear_names: Tuple[str, ...]
17+
scale_name: str
18+
reason: str
19+
20+
21+
def _move_to_device(obj, device):
22+
if torch.is_tensor(obj):
23+
return obj.to(device)
24+
if isinstance(obj, dict):
25+
return {key: _move_to_device(value, device) for key, value in obj.items()}
26+
if isinstance(obj, (tuple, list)):
27+
return type(obj)(_move_to_device(value, device) for value in obj)
28+
return obj
29+
30+
31+
def default_forward_batch(model: nn.Module, batch, device):
32+
batch = _move_to_device(batch, device)
33+
if torch.is_tensor(batch):
34+
return model(batch)
35+
if isinstance(batch, dict):
36+
return model(**batch)
37+
if isinstance(batch, (tuple, list)):
38+
if len(batch) == 0:
39+
raise ValueError("Empty calibration batch.")
40+
return model(batch[0])
41+
raise TypeError(f"Unsupported calibration batch type: {type(batch)!r}")
42+
43+
44+
def collect_smoothquant_act_scales(
45+
model: nn.Module,
46+
calib_loader: Iterable,
47+
num_batches: Optional[int] = None,
48+
device: Optional[torch.device] = None,
49+
forward_batch: Optional[Callable[[nn.Module, object, torch.device], object]] = None,
50+
) -> Dict[str, torch.Tensor]:
51+
"""
52+
Collect per-input-channel activation max values for every nn.Linear module.
53+
54+
The returned dict maps module names to tensors of shape [in_features]. This
55+
mirrors the core calibration used by SmoothQuant but is self-contained and
56+
works with this repository's local models.
57+
"""
58+
was_training = model.training
59+
model.eval()
60+
if device is None:
61+
device = next(model.parameters()).device
62+
if forward_batch is None:
63+
forward_batch = default_forward_batch
64+
65+
act_scales: Dict[str, torch.Tensor] = {}
66+
hooks = []
67+
68+
def stat_input(name, module, inputs):
69+
if not inputs:
70+
return
71+
x = inputs[0]
72+
if not torch.is_tensor(x) or x.numel() == 0:
73+
return
74+
if x.shape[-1] != module.in_features:
75+
return
76+
x_absmax = x.detach().reshape(-1, x.shape[-1]).abs().amax(dim=0).float().cpu()
77+
if name in act_scales:
78+
act_scales[name] = torch.maximum(act_scales[name], x_absmax)
79+
else:
80+
act_scales[name] = x_absmax
81+
82+
for name, module in model.named_modules():
83+
if isinstance(module, nn.Linear):
84+
hooks.append(module.register_forward_pre_hook(
85+
lambda m, inputs, name=name: stat_input(name, m, inputs)
86+
))
87+
88+
try:
89+
with torch.no_grad():
90+
for batch_idx, batch in enumerate(calib_loader):
91+
if num_batches is not None and batch_idx >= num_batches:
92+
break
93+
forward_batch(model, batch, device)
94+
finally:
95+
for hook in hooks:
96+
hook.remove()
97+
model.train(was_training)
98+
99+
return act_scales
100+
101+
102+
@torch.no_grad()
103+
def smooth_norm_linears(
104+
norm: nn.Module,
105+
linears: Sequence[nn.Linear],
106+
act_scales: torch.Tensor,
107+
alpha: float = 0.5,
108+
eps: float = 1e-5,
109+
) -> torch.Tensor:
110+
"""
111+
Fold SmoothQuant scales into a normalization module and one or more
112+
following linear layers.
113+
114+
For x_norm = norm(x), Linear(x_norm) is preserved by applying:
115+
norm.weight /= scale, norm.bias /= scale, linear.weight *= scale
116+
"""
117+
if not 0.0 <= alpha <= 1.0:
118+
raise ValueError(f"alpha must be in [0, 1], got {alpha}")
119+
if not linears:
120+
raise ValueError("linears must contain at least one module")
121+
if not hasattr(norm, "weight") or norm.weight is None:
122+
raise ValueError(f"{norm.__class__.__name__} has no affine weight to absorb SmoothQuant scales")
123+
124+
in_features = linears[0].in_features
125+
if norm.weight.numel() != in_features:
126+
raise ValueError(
127+
f"Norm/Linear shape mismatch: norm={norm.weight.numel()} linear.in_features={in_features}"
128+
)
129+
for linear in linears:
130+
if not isinstance(linear, nn.Linear):
131+
raise TypeError(f"Expected nn.Linear, got {type(linear)!r}")
132+
if linear.in_features != in_features:
133+
raise ValueError("All smoothed linears must share the same input feature size")
134+
135+
device = linears[0].weight.device
136+
dtype = linears[0].weight.dtype
137+
act_scales = act_scales.to(device=device, dtype=dtype).clamp(min=eps)
138+
if act_scales.numel() != in_features:
139+
raise ValueError(
140+
f"act_scales size mismatch: got {act_scales.numel()}, expected {in_features}"
141+
)
142+
143+
weight_scales = torch.stack(
144+
[linear.weight.detach().abs().amax(dim=0).to(device=device, dtype=dtype) for linear in linears],
145+
dim=0,
146+
).amax(dim=0).clamp(min=eps)
147+
scales = (act_scales.pow(alpha) / weight_scales.pow(1.0 - alpha)).clamp(min=eps)
148+
149+
norm.weight.div_(scales.to(device=norm.weight.device, dtype=norm.weight.dtype))
150+
if getattr(norm, "bias", None) is not None:
151+
norm.bias.div_(scales.to(device=norm.bias.device, dtype=norm.bias.dtype))
152+
for linear in linears:
153+
linear.weight.mul_(scales.to(device=linear.weight.device, dtype=linear.weight.dtype).view(1, -1))
154+
155+
return scales.detach().cpu()
156+
157+
158+
def _join_name(prefix: str, child: str) -> str:
159+
return child if prefix == "" else f"{prefix}.{child}"
160+
161+
162+
def _has_named_modules(modules: Dict[str, nn.Module], names: Sequence[str]) -> bool:
163+
return all(name in modules for name in names)
164+
165+
166+
def find_smoothquant_mappings(model: nn.Module) -> List[SmoothQuantMapping]:
167+
"""
168+
Find safe norm -> linear groups for local pre-norm Transformer-style models.
169+
170+
Supported automatic patterns:
171+
- models.LLaMa TransformerBlock
172+
- models.CCT TransformerEncoderLayer attention path
173+
- models.ViT TransformerEncoder
174+
"""
175+
modules = dict(model.named_modules())
176+
mappings: List[SmoothQuantMapping] = []
177+
178+
for prefix, module in model.named_modules():
179+
# LLaMa.TransformerBlock: attention_norm -> wq/wk/wv,
180+
# ffn_norm -> w1/w3. w2 is not directly fed by the norm output.
181+
llama_attn = [
182+
_join_name(prefix, "attention_norm"),
183+
_join_name(prefix, "attention.wq"),
184+
_join_name(prefix, "attention.wk"),
185+
_join_name(prefix, "attention.wv"),
186+
]
187+
llama_ffn = [
188+
_join_name(prefix, "ffn_norm"),
189+
_join_name(prefix, "feed_forward.w1"),
190+
_join_name(prefix, "feed_forward.w3"),
191+
]
192+
if _has_named_modules(modules, llama_attn):
193+
mappings.append(SmoothQuantMapping(
194+
norm_name=llama_attn[0],
195+
linear_names=tuple(llama_attn[1:]),
196+
scale_name=llama_attn[1],
197+
reason="llama_attention_qkv",
198+
))
199+
if _has_named_modules(modules, llama_ffn):
200+
mappings.append(SmoothQuantMapping(
201+
norm_name=llama_ffn[0],
202+
linear_names=tuple(llama_ffn[1:]),
203+
scale_name=llama_ffn[1],
204+
reason="llama_ffn_gate_up",
205+
))
206+
207+
# CCT.TransformerEncoderLayer: pre_norm -> q/k/v. Do not smooth
208+
# norm1 -> linear1 because norm1 is post-norm and its output is also
209+
# the residual base in the FFN block.
210+
cct_attn = [
211+
_join_name(prefix, "pre_norm"),
212+
_join_name(prefix, "self_attn.q"),
213+
_join_name(prefix, "self_attn.k"),
214+
_join_name(prefix, "self_attn.v"),
215+
]
216+
if _has_named_modules(modules, cct_attn):
217+
mappings.append(SmoothQuantMapping(
218+
norm_name=cct_attn[0],
219+
linear_names=tuple(cct_attn[1:]),
220+
scale_name=cct_attn[1],
221+
reason="cct_attention_qkv",
222+
))
223+
224+
# ViT.TransformerEncoder: la1 -> q/k/v, la2 -> first MLP linear.
225+
vit_attn = [
226+
_join_name(prefix, "la1"),
227+
_join_name(prefix, "msa.q"),
228+
_join_name(prefix, "msa.k"),
229+
_join_name(prefix, "msa.v"),
230+
]
231+
vit_ffn = [
232+
_join_name(prefix, "la2"),
233+
_join_name(prefix, "mlp.0"),
234+
]
235+
if _has_named_modules(modules, vit_attn):
236+
mappings.append(SmoothQuantMapping(
237+
norm_name=vit_attn[0],
238+
linear_names=tuple(vit_attn[1:]),
239+
scale_name=vit_attn[1],
240+
reason="vit_attention_qkv",
241+
))
242+
if _has_named_modules(modules, vit_ffn):
243+
mappings.append(SmoothQuantMapping(
244+
norm_name=vit_ffn[0],
245+
linear_names=(vit_ffn[1],),
246+
scale_name=vit_ffn[1],
247+
reason="vit_ffn_linear1",
248+
))
249+
250+
deduped = []
251+
seen = set()
252+
for mapping in mappings:
253+
key = (mapping.norm_name, mapping.linear_names, mapping.scale_name)
254+
if key not in seen:
255+
deduped.append(mapping)
256+
seen.add(key)
257+
return deduped
258+
259+
260+
@torch.no_grad()
261+
def apply_smoothquant(
262+
model: nn.Module,
263+
act_scales: Dict[str, torch.Tensor],
264+
alpha: float = 0.5,
265+
mappings: Optional[Sequence[SmoothQuantMapping]] = None,
266+
strict: bool = False,
267+
verbose: bool = False,
268+
) -> Dict[str, torch.Tensor]:
269+
"""
270+
Apply SmoothQuant smoothing to supported norm -> linear groups.
271+
272+
This should run before MiCo set_qscheme()/PTQ layer replacement.
273+
Returns the per-group smoothing scales keyed by scale_name.
274+
"""
275+
modules = dict(model.named_modules())
276+
if mappings is None:
277+
mappings = find_smoothquant_mappings(model)
278+
279+
applied: Dict[str, torch.Tensor] = {}
280+
for mapping in mappings:
281+
missing = [
282+
name for name in (mapping.norm_name, *mapping.linear_names, mapping.scale_name)
283+
if name not in modules and name not in act_scales
284+
]
285+
if missing:
286+
if strict:
287+
raise KeyError(f"Missing SmoothQuant modules/scales for {mapping}: {missing}")
288+
continue
289+
if mapping.norm_name not in modules or any(name not in modules for name in mapping.linear_names):
290+
if strict:
291+
raise KeyError(f"Missing module for SmoothQuant mapping: {mapping}")
292+
continue
293+
if mapping.scale_name not in act_scales:
294+
if strict:
295+
raise KeyError(f"Missing activation scale for {mapping.scale_name}")
296+
continue
297+
298+
norm = modules[mapping.norm_name]
299+
linears = [modules[name] for name in mapping.linear_names]
300+
try:
301+
applied[mapping.scale_name] = smooth_norm_linears(
302+
norm=norm,
303+
linears=linears,
304+
act_scales=act_scales[mapping.scale_name],
305+
alpha=alpha,
306+
)
307+
except (TypeError, ValueError) as exc:
308+
if strict:
309+
raise
310+
if verbose:
311+
print(f"[SmoothQuant] skipped {mapping.scale_name}: {exc}")
312+
continue
313+
if verbose:
314+
print(
315+
f"[SmoothQuant] {mapping.reason}: {mapping.norm_name} -> "
316+
f"{', '.join(mapping.linear_names)}"
317+
)
318+
319+
return applied
320+
321+
322+
def save_act_scales(act_scales: Dict[str, torch.Tensor], path: str):
323+
os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
324+
torch.save({key: value.cpu() for key, value in act_scales.items()}, path)
325+
326+
327+
def load_act_scales(path: str) -> Dict[str, torch.Tensor]:
328+
return torch.load(path, map_location="cpu")

0 commit comments

Comments
 (0)