Skip to content

Commit 8a27bef

Browse files
committed
renew vit 1m arch, update bert
1 parent 6024df1 commit 8a27bef

7 files changed

Lines changed: 313 additions & 171 deletions

File tree

examples/mpq_train_bitnet.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33

44
import sys
5+
import os
56
import argparse
67

78
from models import model_zoo
@@ -10,7 +11,7 @@
1011
argsparse.add_argument("model_name", type=str)
1112
argsparse.add_argument("epoches", type=int, default=10000)
1213
argsparse.add_argument("--batch-size", type=int, default=32)
13-
argsparse.add_argument("--lr", type=float, default=0.005)
14+
argsparse.add_argument("--lr", type=float, default=0.001)
1415
argsparse.add_argument("-q", "--weight_quant", type=float, choices=[1,1.5,2], default=1)
1516
argsparse.add_argument("-aq", "--act_quant", type=int, choices=[4,8], default=8)
1617
argsparse.add_argument("--keep-last", action="store_true", default=False)
@@ -47,7 +48,13 @@
4748
qscheme[0][-1] = 8
4849
qscheme[1][-1] = 8
4950

50-
model.set_qscheme(qscheme, qat=True)
51+
model.set_qscheme(qscheme, qat=True, use_norm=True)
52+
print("Model Param Size:", sum(p.numel() for p in model.parameters()))
53+
# Detect if there is a full precision checkpoint
54+
# if os.path.exists(f"output/ckpt/{model_name}.pth"):
55+
# print("Full Precision Checkpoint Loaded.")
56+
# ckpt = torch.load(f"output/ckpt/{model_name}.pth")
57+
# model.load_state_dict(ckpt)
5158

5259
if "llama" in model_name:
5360
res = model.train_loop(n_iter=int(epoches),
@@ -68,7 +75,7 @@
6875
torch.save(model.state_dict(), f"output/ckpt/{model_name}_bitnet.pth")
6976
print("Model Train Results: ", res)
7077

71-
model.set_qscheme(qscheme)
78+
model.set_qscheme(qscheme, qat=True, use_norm=True)
7279

7380
res = model.test(test_loader)
7481

examples/vgg_cifar10.py

Lines changed: 0 additions & 78 deletions
This file was deleted.

examples/vit_cifar10.py

Lines changed: 0 additions & 69 deletions
This file was deleted.

0 commit comments

Comments (0)