|
| 1 | +import os |
| 2 | +from dataclasses import dataclass |
| 3 | +from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple |
| 4 | + |
| 5 | +import torch |
| 6 | +import torch.nn as nn |
| 7 | + |
'''
SmoothQuant port for MiCo attention-based models.
Upstream reference implementation: https://github.com/mit-han-lab/smoothquant/
'''
| 12 | + |
@dataclass
class SmoothQuantMapping:
    """Describe one norm -> linear(s) group that is smoothed as a unit.

    Every name is a dotted module path as produced by ``named_modules()``.
    """

    # Normalization module whose affine parameters absorb the inverse scales.
    norm_name: str
    # Linear modules whose weights are multiplied by the scales.
    linear_names: Tuple[str, ...]
    # Key used to look up the activation statistics for this group; also the
    # key under which the resulting smoothing scales are reported.
    scale_name: str
    # Short tag naming the architecture pattern that produced this mapping.
    reason: str
| 19 | + |
| 20 | + |
def _move_to_device(obj, device):
    """Recursively move every tensor found inside *obj* to *device*.

    Tensors are moved with ``Tensor.to``. Dicts are rebuilt as plain dicts
    with converted values. Lists and tuples are rebuilt with their own type,
    including namedtuples. Any other object is returned unchanged.
    """
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {key: _move_to_device(value, device) for key, value in obj.items()}
    if isinstance(obj, (tuple, list)):
        moved = [_move_to_device(value, device) for value in obj]
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            # namedtuple constructors take one positional argument per field,
            # not a single iterable, so the values must be unpacked.
            return type(obj)(*moved)
        return type(obj)(moved)
    return obj
| 29 | + |
| 30 | + |
def default_forward_batch(model: nn.Module, batch, device):
    """Run one calibration batch through *model* and return the model output.

    The batch is moved to *device* first, then dispatched by its type:

    - a tensor is passed as the single positional argument;
    - a dict is expanded into keyword arguments;
    - for a tuple or list only the first element is passed to the model.

    Raises:
        ValueError: if the batch is an empty tuple or list.
        TypeError: if the batch is none of the supported types.
    """
    moved = _move_to_device(batch, device)
    if torch.is_tensor(moved):
        return model(moved)
    if isinstance(moved, dict):
        return model(**moved)
    if not isinstance(moved, (tuple, list)):
        raise TypeError(f"Unsupported calibration batch type: {type(moved)!r}")
    if len(moved) == 0:
        raise ValueError("Empty calibration batch.")
    return model(moved[0])
| 42 | + |
| 43 | + |
def collect_smoothquant_act_scales(
    model: nn.Module,
    calib_loader: Iterable,
    num_batches: Optional[int] = None,
    device: Optional[torch.device] = None,
    forward_batch: Optional[Callable[[nn.Module, object, torch.device], object]] = None,
) -> Dict[str, torch.Tensor]:
    """
    Collect per-input-channel absolute activation maxima for every nn.Linear.

    A forward pre-hook on each nn.Linear records the running maximum of
    ``|x|`` over all leading dimensions of its first positional input. The
    model is put in eval mode and run under ``torch.no_grad()``. Its previous
    training mode is restored and all hooks are removed on exit, even when a
    forward pass raises.

    Args:
        model: module to calibrate.
        calib_loader: iterable yielding calibration batches.
        num_batches: maximum number of batches to run; ``None`` runs them all.
            No batch beyond this limit is fetched from the loader.
        device: device handed to ``forward_batch``. Defaults to the device of
            the model's first parameter, or CPU if the model has none.
        forward_batch: callable ``(model, batch, device)`` that runs one
            batch. Defaults to ``default_forward_batch``.

    Returns:
        Dict mapping module names to float32 CPU tensors of shape
        ``[in_features]``. Linear layers that never received a usable input
        are absent from the dict.
    """
    if device is None:
        # A model without parameters has no device of its own; use CPU rather
        # than letting a bare next() raise StopIteration.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    if forward_batch is None:
        forward_batch = default_forward_batch

    act_scales: Dict[str, torch.Tensor] = {}
    hooks = []

    def stat_input(name, module, inputs):
        if not inputs:
            return
        x = inputs[0]
        if not torch.is_tensor(x) or x.numel() == 0:
            return
        if x.shape[-1] != module.in_features:
            return
        # Flatten every leading dimension so the reduction runs per channel.
        x_absmax = x.detach().reshape(-1, x.shape[-1]).abs().amax(dim=0).float().cpu()
        if name in act_scales:
            act_scales[name] = torch.maximum(act_scales[name], x_absmax)
        else:
            act_scales[name] = x_absmax

    was_training = model.training
    try:
        model.eval()
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                # Bind `name` as a default argument so each hook keeps its own
                # module name instead of the loop's final value.
                hooks.append(module.register_forward_pre_hook(
                    lambda m, inputs, name=name: stat_input(name, m, inputs)
                ))

        if num_batches is None or num_batches > 0:
            with torch.no_grad():
                for batch_idx, batch in enumerate(calib_loader):
                    forward_batch(model, batch, device)
                    # Stop before asking the loader for a batch we would
                    # only throw away.
                    if num_batches is not None and batch_idx + 1 >= num_batches:
                        break
    finally:
        for hook in hooks:
            hook.remove()
        model.train(was_training)

    return act_scales
| 100 | + |
| 101 | + |
@torch.no_grad()
def smooth_norm_linears(
    norm: nn.Module,
    linears: Sequence[nn.Linear],
    act_scales: torch.Tensor,
    alpha: float = 0.5,
    eps: float = 1e-5,
) -> torch.Tensor:
    """
    Fold SmoothQuant scales into a normalization module and one or more
    following linear layers, modifying all of them in place.

    For x_norm = norm(x), Linear(x_norm) is preserved by applying:
        norm.weight /= scale, norm.bias /= scale, linear.weight *= scale
    where, per input channel,
        scale = act_scales ** alpha / weight_absmax ** (1 - alpha)
    and weight_absmax is the largest absolute weight of that channel across
    all given linears.

    Args:
        norm: normalization module with an affine ``weight`` (and optionally
            ``bias``) holding ``in_features`` elements.
        linears: nn.Linear modules that all share the same ``in_features``.
        act_scales: per-channel activation maxima with ``in_features``
            elements; any shape with that element count is flattened.
        alpha: migration strength in ``[0, 1]``.
        eps: lower clamp applied to the statistics and the final scales.

    Returns:
        The applied scales as a detached 1-D CPU tensor.

    Raises:
        TypeError: if an entry of ``linears`` is not an nn.Linear.
        ValueError: if ``alpha`` is out of range, ``linears`` is empty, the
            norm has no affine weight, or any size does not match.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")
    if not linears:
        raise ValueError("linears must contain at least one module")
    # Check types before reading Linear-only attributes, so a non-Linear
    # module surfaces as TypeError instead of a stray AttributeError.
    for linear in linears:
        if not isinstance(linear, nn.Linear):
            raise TypeError(f"Expected nn.Linear, got {type(linear)!r}")
    if not hasattr(norm, "weight") or norm.weight is None:
        raise ValueError(f"{norm.__class__.__name__} has no affine weight to absorb SmoothQuant scales")

    in_features = linears[0].in_features
    if norm.weight.numel() != in_features:
        raise ValueError(
            f"Norm/Linear shape mismatch: norm={norm.weight.numel()} linear.in_features={in_features}"
        )
    if any(linear.in_features != in_features for linear in linears):
        raise ValueError("All smoothed linears must share the same input feature size")
    if act_scales.numel() != in_features:
        raise ValueError(
            f"act_scales size mismatch: got {act_scales.numel()}, expected {in_features}"
        )

    device = linears[0].weight.device
    dtype = linears[0].weight.dtype
    # Flatten so statistics stored as e.g. [1, C] cannot broadcast the scales
    # to a shape that the in-place updates below would reject.
    act_scales = act_scales.reshape(-1).to(device=device, dtype=dtype).clamp(min=eps)

    # Per-input-channel absolute maximum across every linear in the group.
    weight_scales = torch.stack(
        [linear.weight.detach().abs().amax(dim=0).to(device=device, dtype=dtype) for linear in linears],
        dim=0,
    ).amax(dim=0).clamp(min=eps)
    scales = (act_scales.pow(alpha) / weight_scales.pow(1.0 - alpha)).clamp(min=eps)

    norm.weight.div_(scales.to(device=norm.weight.device, dtype=norm.weight.dtype))
    if getattr(norm, "bias", None) is not None:
        norm.bias.div_(scales.to(device=norm.bias.device, dtype=norm.bias.dtype))
    for linear in linears:
        # Scales act on input channels, i.e. the columns of the weight matrix.
        linear.weight.mul_(scales.to(device=linear.weight.device, dtype=linear.weight.dtype).view(1, -1))

    return scales.detach().cpu()
| 156 | + |
| 157 | + |
def _join_name(prefix: str, child: str) -> str:
    """Return the dotted path of *child* under *prefix*.

    An empty prefix yields *child* unchanged.
    """
    if prefix == "":
        return child
    return f"{prefix}.{child}"
| 160 | + |
| 161 | + |
def _has_named_modules(modules: Dict[str, nn.Module], names: Sequence[str]) -> bool:
    """Return True when every entry of *names* is a key of *modules*.

    An empty *names* sequence yields True.
    """
    for wanted in names:
        if wanted not in modules:
            return False
    return True
| 164 | + |
| 165 | + |
def find_smoothquant_mappings(model: nn.Module) -> List[SmoothQuantMapping]:
    """
    Find safe norm -> linear groups for local pre-norm Transformer-style models.

    Every module name in the model is tried as a prefix. A pattern matches
    when the norm and all of its linears exist under that prefix by name;
    module types are not inspected here. The first linear of each group is
    used as its ``scale_name``.

    Supported automatic patterns:
    - models.LLaMa TransformerBlock
    - models.CCT TransformerEncoderLayer attention path
    - models.ViT TransformerEncoder
    """
    # (reason, norm child, linear children), tried in this order per prefix.
    patterns = (
        # LLaMa.TransformerBlock: attention_norm -> wq/wk/wv,
        # ffn_norm -> w1/w3. w2 is not directly fed by the norm output.
        ("llama_attention_qkv", "attention_norm", ("attention.wq", "attention.wk", "attention.wv")),
        ("llama_ffn_gate_up", "ffn_norm", ("feed_forward.w1", "feed_forward.w3")),
        # CCT.TransformerEncoderLayer: pre_norm -> q/k/v. Do not smooth
        # norm1 -> linear1 because norm1 is post-norm and its output is also
        # the residual base in the FFN block.
        ("cct_attention_qkv", "pre_norm", ("self_attn.q", "self_attn.k", "self_attn.v")),
        # ViT.TransformerEncoder: la1 -> q/k/v, la2 -> first MLP linear.
        ("vit_attention_qkv", "la1", ("msa.q", "msa.k", "msa.v")),
        ("vit_ffn_linear1", "la2", ("mlp.0",)),
    )

    modules = dict(model.named_modules())
    found: List[SmoothQuantMapping] = []

    for prefix, _ in model.named_modules():
        for reason, norm_child, linear_children in patterns:
            norm_name = _join_name(prefix, norm_child)
            linear_names = tuple(_join_name(prefix, child) for child in linear_children)
            if not _has_named_modules(modules, [norm_name, *linear_names]):
                continue
            found.append(SmoothQuantMapping(
                norm_name=norm_name,
                linear_names=linear_names,
                scale_name=linear_names[0],
                reason=reason,
            ))

    # Keep only the first occurrence of each (norm, linears, scale) group.
    unique: List[SmoothQuantMapping] = []
    seen_keys = set()
    for candidate in found:
        key = (candidate.norm_name, candidate.linear_names, candidate.scale_name)
        if key in seen_keys:
            continue
        seen_keys.add(key)
        unique.append(candidate)
    return unique
| 258 | + |
| 259 | + |
@torch.no_grad()
def apply_smoothquant(
    model: nn.Module,
    act_scales: Dict[str, torch.Tensor],
    alpha: float = 0.5,
    mappings: Optional[Sequence[SmoothQuantMapping]] = None,
    strict: bool = False,
    verbose: bool = False,
) -> Dict[str, torch.Tensor]:
    """
    Apply SmoothQuant smoothing to supported norm -> linear groups.

    This should run before MiCo set_qscheme()/PTQ layer replacement.

    Args:
        model: module whose norm and linear weights are rewritten in place.
        act_scales: activation statistics keyed by module name.
        alpha: migration strength forwarded to ``smooth_norm_linears``.
        mappings: groups to smooth; discovered with
            ``find_smoothquant_mappings`` when omitted.
        strict: when True, a missing module, missing statistic, or a
            TypeError/ValueError from smoothing is raised instead of the
            group being skipped.
        verbose: when True, print one line per smoothed group and one line
            per group skipped because smoothing raised.

    Returns:
        The per-group smoothing scales keyed by scale_name.

    Raises:
        KeyError: in strict mode, when a module or activation scale is missing.
    """
    modules = dict(model.named_modules())
    if mappings is None:
        mappings = find_smoothquant_mappings(model)

    applied: Dict[str, torch.Tensor] = {}
    for mapping in mappings:
        group_names = (mapping.norm_name, *mapping.linear_names, mapping.scale_name)

        # Names known neither as a module nor as an activation statistic.
        missing = [
            name for name in group_names
            if not (name in modules or name in act_scales)
        ]
        if missing:
            if strict:
                raise KeyError(f"Missing SmoothQuant modules/scales for {mapping}: {missing}")
            continue

        # The norm and every linear must be real modules of this model.
        module_names = (mapping.norm_name, *mapping.linear_names)
        if not all(name in modules for name in module_names):
            if strict:
                raise KeyError(f"Missing module for SmoothQuant mapping: {mapping}")
            continue

        if mapping.scale_name not in act_scales:
            if strict:
                raise KeyError(f"Missing activation scale for {mapping.scale_name}")
            continue

        try:
            group_scales = smooth_norm_linears(
                norm=modules[mapping.norm_name],
                linears=[modules[name] for name in mapping.linear_names],
                act_scales=act_scales[mapping.scale_name],
                alpha=alpha,
            )
        except (TypeError, ValueError) as exc:
            if strict:
                raise
            if verbose:
                print(f"[SmoothQuant] skipped {mapping.scale_name}: {exc}")
            continue

        applied[mapping.scale_name] = group_scales
        if verbose:
            print(
                f"[SmoothQuant] {mapping.reason}: {mapping.norm_name} -> "
                f"{', '.join(mapping.linear_names)}"
            )

    return applied
| 320 | + |
| 321 | + |
def save_act_scales(act_scales: Dict[str, torch.Tensor], path: str):
    """Write *act_scales* to *path* with ``torch.save``.

    The parent directory is created if needed, and every tensor is moved to
    CPU before it is written.
    """
    parent_dir = os.path.dirname(path) or "."
    os.makedirs(parent_dir, exist_ok=True)
    cpu_scales = {}
    for name, tensor in act_scales.items():
        cpu_scales[name] = tensor.cpu()
    torch.save(cpu_scales, path)
| 325 | + |
| 326 | + |
def load_act_scales(path: str) -> Dict[str, torch.Tensor]:
    """Load activation scales from *path*, mapping every tensor to CPU.

    The file is expected to hold a plain dict of tensors, which is what
    ``save_act_scales`` writes.
    """
    # weights_only=True limits unpickling to tensors and basic containers, so
    # a tampered or untrusted file cannot run arbitrary code while loading.
    # NOTE(review): needs a PyTorch release that accepts `weights_only`.
    return torch.load(path, map_location="cpu", weights_only=True)
0 commit comments