Skip to content

Commit 54e6ae6

Browse files
committed
Overall update around LLaMA
1 parent 0590893 commit 54e6ae6

10 files changed

Lines changed: 544 additions & 76 deletions

File tree

MiCoCodeGen.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class MiCoCodeGen(torch.fx.Interpreter):
5454
extern size_t model_weight_end[];
5555
5656
// Profiler Timer
57-
extern long QMATMUL_TIMER, QUANT_TIMER, IM2COL_TIMER;
57+
extern unsigned long long QMATMUL_TIMER, QUANT_TIMER, IM2COL_TIMER;
5858
5959
typedef struct {{
6060
{model_struct}
@@ -352,20 +352,20 @@ def _format_benchmark_forward(self, indent: str):
352352
groups = self.get_benchmark_call_groups()
353353
total_occurrences = sum(group["count"] for group in groups)
354354
lines = [
355-
f"{indent}long benchmark_total_time = 0;",
355+
f"{indent}unsigned long long benchmark_total_time = 0;",
356356
f"{indent}printf(\"Benchmark Mode: %d unique kernels, %d total occurrences\\n\", {len(groups)}, {total_occurrences});",
357357
]
358358
for idx, group in enumerate(groups):
359359
escaped_name = group["function_name"].replace("\\", "\\\\").replace('"', '\\"')
360360
lines += [
361-
f"{indent}long benchmark_kernel_time_{idx} = MiCo_time();",
361+
f"{indent}unsigned long long benchmark_kernel_time_{idx} = MiCo_time();",
362362
f"{indent}{group['call']}",
363363
f"{indent}benchmark_kernel_time_{idx} = MiCo_time() - benchmark_kernel_time_{idx};",
364-
f"{indent}long benchmark_kernel_estimate_{idx} = benchmark_kernel_time_{idx} * {group['count']};",
364+
f"{indent}unsigned long long benchmark_kernel_estimate_{idx} = benchmark_kernel_time_{idx} * {group['count']};",
365365
f"{indent}benchmark_total_time += benchmark_kernel_estimate_{idx};",
366-
f"{indent}printf(\"Benchmark Kernel {idx}: {escaped_name} occurrences={group['count']} time=%ld estimated=%ld\\n\", benchmark_kernel_time_{idx}, benchmark_kernel_estimate_{idx});",
366+
f"{indent}printf(\"Benchmark Kernel {idx}: {escaped_name} occurrences={group['count']} time=%llu estimated=%llu\\n\", benchmark_kernel_time_{idx}, benchmark_kernel_estimate_{idx});",
367367
]
368-
lines.append(f"{indent}printf(\"Estimated Execution Time: %ld\\n\", benchmark_total_time);")
368+
lines.append(f"{indent}printf(\"Estimated Execution Time: %llu\\n\", benchmark_total_time);")
369369
return lines
370370

371371
def _extract_input_names(self, n: torch.fx.node.Node) -> List[str]:

MiCoEval.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(self, model: MiCoModel | str,
4242
lr=0.0001, model_name = "",
4343
objective='ptq_acc',
4444
constraint='bops',
45-
linear_group_size = 1,
45+
linear_group_size = 0,
4646
output_json='output/json/mico_eval.json') -> None:
4747

4848
self.model = model
@@ -68,8 +68,10 @@ def __init__(self, model: MiCoModel | str,
6868

6969
self.set_eval(objective)
7070
self.set_constraint(constraint)
71-
72-
self.fp_acc = self.eval_fp()
71+
72+
self.fp_res = self.eval_fp()
73+
self.fp_acc = self.eval_fp()['TestAcc']
74+
self.fp_loss = self.eval_fp()['TestLoss']
7375
# Initial Conversion and Test
7476
res = self.eval_f([8]*self.n_layers*2)
7577
self.baseline_acc = res
@@ -96,6 +98,7 @@ def __init__(self, model: MiCoModel | str,
9698
print("Total Params: ", np.sum(self.layer_params))
9799
print("INT8 Model Accuracy: ", res)
98100
print("FP Model Accuracy: ", self.fp_acc)
101+
print("FP Model Loss:", self.fp_loss)
99102
return
100103

101104
def get_layer_info(self):
@@ -325,7 +328,7 @@ def set_mico_target(self, mico_type: str):
325328

326329
def eval_fp(self):
327330
self.model.unset_qscheme()
328-
return self.model.test(self.test_loader)['TestAcc']
331+
return self.model.test(self.test_loader)
329332

330333
def eval_ptq_loss(self, scheme: list):
331334
wq = scheme[:self.n_layers]

MiCoModel.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@ def get_attn_qlayers(self):
3737
def n_attn_layers(self):
3838
return len(self.get_attn_qlayers())
3939

40-
def set_qscheme(self, qscheme, qat=False, device=device, group_size = 1, use_bias = True, use_norm = False):
40+
def set_qscheme(self, qscheme, qat=False, device=None, group_size = 1, use_bias = True, use_norm = False):
41+
if device is None:
42+
device = self.current_device()
43+
self.to(device)
4144
replace_quantize_layers(self, qscheme[0], qscheme[1],
4245
quant_aware=qat, group_size=group_size,
4346
device=device, use_bias=use_bias, use_norm=use_norm)
@@ -55,22 +58,31 @@ def set_attn_qscheme(self, attn_qscheme, qat=False, **kwargs):
5558
set_to_qforward(self)
5659
return
5760

58-
def set_qscheme_torchao(self, qscheme,device=device):
61+
def set_qscheme_torchao(self, qscheme, device=None):
62+
if device is None:
63+
device = self.current_device()
64+
self.to(device)
5965
unset_qforward(self)
6066
replace_quantize_layers_torchao(self, qscheme[0], qscheme[1], device=device)
6167
return
6268

6369
def torchao_autoquant(self, example_input: torch.Tensor):
6470
raise NotImplementedError("autoquant was removed in torchao 0.17.0. Use set_qscheme_torchao() instead.")
6571

72+
def current_device(self):
73+
for tensor in list(self.parameters()) + list(self.buffers()):
74+
return tensor.device
75+
return device
76+
6677
def test(self, test_loader):
6778
self.eval()
79+
run_device = self.current_device()
6880
criterion = torch.nn.CrossEntropyLoss()
6981
test_loss = []
7082
test_total, test_correct = 0, 0
7183
with torch.no_grad():
7284
for i, (images, labels) in enumerate(test_loader):
73-
x, y = images.to(device), labels.to(device)
85+
x, y = images.to(run_device), labels.to(run_device)
7486
output = self.forward(x)
7587
_, predicted = torch.max(output.data, 1)
7688
loss = criterion(output, y)
@@ -86,6 +98,7 @@ def train_loop(self, n_epoch, train_loader, test_loader, verbose = False,
8698
warmup_epochs = 0, warmup_lr = 1e-6):
8799
optimizer = torch.optim.Adam(self.parameters(), lr=lr)
88100
criterion = torch.nn.CrossEntropyLoss()
101+
run_device = self.current_device()
89102
warmup_epochs = max(0, int(warmup_epochs))
90103
warmup_epochs = min(warmup_epochs, max(0, n_epoch - 1))
91104
use_warmup = scheduler == "cosine" and warmup_epochs > 0
@@ -118,14 +131,14 @@ def train_loop(self, n_epoch, train_loader, test_loader, verbose = False,
118131

119132
train_loss = []
120133
train_total, train_correct = 0, 0
121-
loss = torch.tensor(np.inf)
134+
loss = torch.tensor(np.inf, device=run_device)
122135
# Training
123136
self.train()
124137
loop = tqdm(enumerate(train_loader), total=len(train_loader),
125138
disable=not verbose)
126139

127140
for i, (images, labels) in loop:
128-
x, y = images.to(device), labels.to(device)
141+
x, y = images.to(run_device), labels.to(run_device)
129142
optimizer.zero_grad()
130143
output = self(x)
131144
_, predicted = torch.max(output.data, 1)
@@ -147,7 +160,7 @@ def train_loop(self, n_epoch, train_loader, test_loader, verbose = False,
147160
test_total, test_correct = 0, 0
148161
with torch.no_grad():
149162
for i, (images, labels) in enumerate(test_loader):
150-
x, y = images.to(device), labels.to(device)
163+
x, y = images.to(run_device), labels.to(run_device)
151164
output = self.forward(x)
152165
_, predicted = torch.max(output.data, 1)
153166
loss = criterion(output, y)

MiCoQLayers.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def activation_nquant_2d(x: torch.Tensor, qbit = 8):
4141
if qbit == 1:
4242
x_absmean = torch.mean(x.abs(), dim=(-2,-1), keepdim=True)
4343
y = x.sign() * x_absmean
44+
y = torch.where(y == 0.0, -x_absmean, y)
4445
elif qbit < 2: # Ternary quantization
4546
x_absmean = torch.mean(x.abs(), dim=(-2,-1), keepdim=True)
4647
scale = 1.0 / x_absmean.clamp_(min=1e-5)
@@ -140,7 +141,7 @@ def weight_quantnb(w: torch.Tensor, qbit = 8, mode = "max"):
140141
def weight_quantnb_group(w: torch.Tensor, qbit: int = 8, mode: str = "max",
141142
dim: int = -1, group_size: int = 32, return_expanded: bool = True):
142143
"""
143-
Group-wise symmetric quantization for qbit >= 2.
144+
Group-wise weight quantization for qbit >= 1.
144145
- dim: dimension to group over
145146
- group_size: number of contiguous elements per group along 'dim'
146147
- mode: "max" (per-group max) for qbit > 2, otherwise "mean"
@@ -149,7 +150,7 @@ def weight_quantnb_group(w: torch.Tensor, qbit: int = 8, mode: str = "max",
149150
u: integer-quantized weights (same shape as w)
150151
inv_scale: inverse scale tensor (broadcastable to w if return_expanded=True)
151152
"""
152-
assert qbit > 1, "qbit should be larger than 1"
153+
assert qbit >= 1, "qbit should be larger than or equal to 1"
153154
assert isinstance(group_size, int) and group_size > 0, "group_size must be a positive int"
154155

155156
# Normalize dim to positive index
@@ -166,16 +167,25 @@ def weight_quantnb_group(w: torch.Tensor, qbit: int = 8, mode: str = "max",
166167
x = w.reshape(new_shape)
167168

168169
reduce_dim = dim + 1 # the 'group_size' axis
169-
if (mode == "max") and (qbit > 2):
170+
if qbit == 1:
171+
denom = x.abs().mean(dim=reduce_dim, keepdim=True)
172+
scale = 1.0 / denom.clamp(min=1e-5)
173+
u_group = x.sign()
174+
elif 1 < qbit < 2:
175+
denom = x.abs().mean(dim=reduce_dim, keepdim=True)
176+
scale = 1.0 / denom.clamp(min=1e-5)
177+
u_group = (x * scale).round().clamp_(-1, 1)
178+
elif (mode == "max") and (qbit > 2):
170179
denom = x.abs().amax(dim=reduce_dim, keepdim=True)
180+
scale = (2**(qbit - 1) - 1) / denom.clamp(min=1e-5)
181+
u_group = (x * scale).round().clamp_(-(2**(qbit - 1)), 2**(qbit - 1) - 1)
171182
elif (mode == "mean") or (qbit <= 2):
172183
denom = x.abs().mean(dim=reduce_dim, keepdim=True)
184+
scale = (2**(qbit - 1) - 1) / denom.clamp(min=1e-5)
185+
u_group = (x * scale).round().clamp_(-(2**(qbit - 1)), 2**(qbit - 1) - 1)
173186
else:
174187
raise ValueError("Invalid mode")
175188

176-
scale = (2**(qbit - 1) - 1) / denom.clamp(min=1e-5)
177-
178-
u_group = (x * scale).round().clamp_(-(2**(qbit - 1)), 2**(qbit - 1) - 1)
179189
u = u_group.reshape_as(w)
180190

181191
inv_scale_group = 1.0 / scale
@@ -264,6 +274,19 @@ def weight_quant(self, w: torch.Tensor):
264274

265275
def save_qweight(self):
266276
self.qw, self.qw_scale = self._weight_quant_impl(self.weight.data)
277+
278+
def qweight(self):
279+
if self.qw is None or self.qw_scale is None:
280+
self.save_qweight()
281+
if self.qw.device != self.weight.device:
282+
self.qw = self.qw.to(self.weight.device)
283+
if self.qw_scale.device != self.weight.device:
284+
self.qw_scale = self.qw_scale.to(self.weight.device)
285+
return self.qw * self.qw_scale
286+
287+
def ste_weight_quant(self):
288+
w = self.weight
289+
return w + (self.weight_quant(w) - w).detach()
267290

268291
def export_qweight(self):
269292
return {
@@ -315,14 +338,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
315338
x_norm = SimpleRMSNorm(self.in_features)(x) if self.use_norm else x
316339
x_norm = HadamardTransform()(x) if self.haramard else x_norm
317340
x_quant = x_norm + (self.act_quant(x_norm) - x_norm).detach()
318-
w_quant = w + (self.weight_quant(w) - w).detach()
341+
w_quant = self.ste_weight_quant()
319342
y = F.linear(x_quant, w_quant, bias=self.bias)
320343
return y
321344
elif self.qforward is True:
322345
# Forward with Post Training Quantization (PTQ)
323346
# Only for inference
324-
qx = self.act_quant(x)
325-
y = F.linear(qx, self.qw * self.qw_scale, bias=self.bias)
347+
x_norm = SimpleRMSNorm(self.in_features)(x) if self.use_norm else x
348+
x_norm = HadamardTransform()(x) if self.haramard else x_norm
349+
qx = self.act_quant(x_norm)
350+
y = F.linear(qx, self.qweight(), bias=self.bias)
326351
return y
327352
else:
328353
return F.linear(x, w, bias=self.bias)
@@ -387,15 +412,16 @@ def forward(self, x: torch.Tensor):
387412
# Using Straight-Through-Estimator (STE)
388413
x_norm = self.rmsnorm(x) if self.use_norm else x
389414
x_quant = x_norm + (activation_nquant_2d(x_norm, self.act_q) - x_norm).detach()
390-
w_quant = w + (self.weight_quant(w) - w).detach()
415+
w_quant = self.ste_weight_quant()
391416
y = F.conv2d(x_quant, w_quant, self.bias, self.stride, self.padding,
392417
self.dilation, self.groups)
393418
return y
394419
elif self.qforward:
395420
# Forward with Post Training Quantization (PTQ)
396421
# Only for inference
397-
qx = activation_nquant_2d(x, self.act_q)
398-
y = F.conv2d(qx, self.qw * self.qw_scale, self.bias, self.stride, self.padding,
422+
x_norm = self.rmsnorm(x) if self.use_norm else x
423+
qx = activation_nquant_2d(x_norm, self.act_q)
424+
y = F.conv2d(qx, self.qweight(), self.bias, self.stride, self.padding,
399425
self.dilation, self.groups)
400426
return y
401427
else:
@@ -449,12 +475,13 @@ def forward(self, x: torch.Tensor):
449475
if self.qat:
450476
x_norm = self.rmsnorm(x) if self.use_norm else x
451477
x_quant = x_norm + (activation_nquant(x_norm, self.act_q) - x_norm).detach()
452-
w_quant = w + (self.weight_quant(w) - w).detach()
478+
w_quant = self.ste_weight_quant()
453479
return F.conv1d(x_quant, w_quant, self.bias, self.stride,
454480
self.padding, self.dilation, self.groups)
455481
elif self.qforward is True:
456-
qx = activation_nquant(x, self.act_q)
457-
return F.conv1d(qx, self.qw * self.qw_scale, self.bias,
482+
x_norm = self.rmsnorm(x) if self.use_norm else x
483+
qx = activation_nquant(x_norm, self.act_q)
484+
return F.conv1d(qx, self.qweight(), self.bias,
458485
self.stride, self.padding, self.dilation, self.groups)
459486
else:
460487
return F.conv1d(x, w, self.bias, self.stride,

examples/attention_quant_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ def parse_args():
2424
description="Evaluate one model with multiple attention quantization modes."
2525
)
2626
parser.add_argument("model_name", type=str)
27-
parser.add_argument("--weight-q", type=int, default=8)
28-
parser.add_argument("--act-q", type=int, default=8)
27+
parser.add_argument("--weight-q", type=float, default=8)
28+
parser.add_argument("--act-q", type=float, default=8)
2929
parser.add_argument(
3030
"--quant",
3131
nargs="+",

examples/mpq_gen.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,12 @@ def main():
8787
parser = argparse.ArgumentParser()
8888
parser.add_argument("model_name", type=str, nargs="?")
8989
parser.add_argument("--batch-size", type=int, default=1)
90+
parser.add_argument(
91+
"--num-workers",
92+
type=int,
93+
default=None,
94+
help="Override model_zoo DataLoader worker count. Use 0 in restricted sandboxes.",
95+
)
9096
parser.add_argument("--list-models", action="store_true")
9197

9298
parser.add_argument("--ckpt", type=str, default=None)
@@ -140,6 +146,9 @@ def main():
140146
if args.fuse and args.fuse_seq:
141147
raise ValueError("Please use only one of --fuse or --fuse-seq.")
142148

149+
if args.num_workers is not None:
150+
model_zoo.NUM_WORKERS = args.num_workers
151+
143152
model, _, test_loader = model_zoo.from_zoo(
144153
args.model_name, shuffle=False, batch_size=args.batch_size
145154
)

0 commit comments

Comments
 (0)