|
2 | 2 | import numpy as np |
3 | 3 |
|
4 | 4 | import sys |
| 5 | +import os |
5 | 6 | import argparse |
6 | 7 |
|
7 | 8 | from models import model_zoo |
|
10 | 11 | argsparse.add_argument("model_name", type=str) |
11 | 12 | argsparse.add_argument("epoches", type=int, default=10000) |
12 | 13 | argsparse.add_argument("--batch-size", type=int, default=32) |
13 | | -argsparse.add_argument("--lr", type=float, default=0.005) |
| 14 | +argsparse.add_argument("--lr", type=float, default=0.001) |
14 | 15 | argsparse.add_argument("-q", "--weight_quant", type=float, choices=[1,1.5,2], default=1) |
15 | 16 | argsparse.add_argument("-aq", "--act_quant", type=int, choices=[4,8], default=8) |
16 | 17 | argsparse.add_argument("--keep-last", action="store_true", default=False) |
|
47 | 48 | qscheme[0][-1] = 8 |
48 | 49 | qscheme[1][-1] = 8 |
49 | 50 |
|
50 | | - model.set_qscheme(qscheme, qat=True) |
| 51 | + model.set_qscheme(qscheme, qat=True, use_norm=True) |
| 52 | + print("Model Param Size:", sum(p.numel() for p in model.parameters())) |
| 53 | + # Detect if there is a full precision checkpoint |
| 54 | + # if os.path.exists(f"output/ckpt/{model_name}.pth"): |
| 55 | + # print("Full Precision Checkpoint Loaded.") |
| 56 | + # ckpt = torch.load(f"output/ckpt/{model_name}.pth") |
| 57 | + # model.load_state_dict(ckpt) |
51 | 58 |
|
52 | 59 | if "llama" in model_name: |
53 | 60 | res = model.train_loop(n_iter=int(epoches), |
|
68 | 75 | torch.save(model.state_dict(), f"output/ckpt/{model_name}_bitnet.pth") |
69 | 76 | print("Model Train Results: ", res) |
70 | 77 |
|
71 | | - model.set_qscheme(qscheme) |
| 78 | + model.set_qscheme(qscheme, qat=True, use_norm=True) |
72 | 79 |
|
73 | 80 | res = model.test(test_loader) |
74 | 81 |
|
|
0 commit comments