Skip to content

Commit daa2be6

Browse files
committed
add shortcut script for mico codegen
1 parent 2fabb42 commit daa2be6

2 files changed

Lines changed: 204 additions & 0 deletions

File tree

examples/mpq_gen.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
import argparse
2+
import os
3+
4+
import torch
5+
6+
from MiCoCodeGen import MiCoCodeGen
7+
from MiCoUtils import fuse_model, fuse_model_seq
8+
from models import model_zoo
9+
10+
11+
def parse_bits(bits_text: str, n_layers: int, arg_name: str):
12+
values = [int(v.strip()) for v in bits_text.split(",") if v.strip()]
13+
if len(values) == 0:
14+
raise ValueError(f"{arg_name} must contain at least one integer.")
15+
if len(values) == 1:
16+
return values * n_layers
17+
if len(values) != n_layers:
18+
raise ValueError(
19+
f"{arg_name} must provide 1 value or exactly {n_layers} values, got {len(values)}."
20+
)
21+
return values
22+
23+
24+
def parse_shape(shape_text: str):
25+
shape = tuple(int(v.strip()) for v in shape_text.split(",") if v.strip())
26+
if len(shape) == 0:
27+
raise ValueError("--example-shape must provide at least one dimension.")
28+
if any(d <= 0 for d in shape):
29+
raise ValueError("--example-shape dimensions must be positive.")
30+
return shape
31+
32+
33+
def get_batch_input(batch):
34+
if torch.is_tensor(batch):
35+
return batch
36+
37+
if isinstance(batch, (list, tuple)):
38+
if len(batch) == 0:
39+
return None
40+
if torch.is_tensor(batch[0]):
41+
return batch[0]
42+
return get_batch_input(batch[0])
43+
44+
if isinstance(batch, dict):
45+
preferred_keys = ["input_ids", "inputs", "x", "image", "images", "data"]
46+
for key in preferred_keys:
47+
value = batch.get(key)
48+
if torch.is_tensor(value):
49+
return value
50+
for value in batch.values():
51+
if torch.is_tensor(value):
52+
return value
53+
return None
54+
55+
return None
56+
57+
58+
def get_example_input(test_loader):
59+
batch = next(iter(test_loader))
60+
x = get_batch_input(batch)
61+
if x is None:
62+
raise TypeError(
63+
f"Unsupported batch type for codegen input extraction: {type(batch)}"
64+
)
65+
if x.dim() > 0:
66+
x = x[:1]
67+
return x.to("cpu")
68+
69+
70+
def load_model_ckpt(model, ckpt_path: str):
71+
ckpt = torch.load(ckpt_path, map_location="cpu")
72+
if not isinstance(ckpt, dict):
73+
raise TypeError(f"Unsupported checkpoint format: {type(ckpt)}")
74+
75+
if "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict):
76+
ckpt = ckpt["state_dict"]
77+
elif "model_state_dict" in ckpt and isinstance(ckpt["model_state_dict"], dict):
78+
ckpt = ckpt["model_state_dict"]
79+
80+
if all(key.startswith("module.") for key in ckpt.keys()):
81+
ckpt = {key[7:]: value for key, value in ckpt.items()}
82+
83+
model.load_state_dict(ckpt)
84+
85+
86+
def main():
87+
parser = argparse.ArgumentParser()
88+
parser.add_argument("model_name", type=str, nargs="?")
89+
parser.add_argument("--batch-size", type=int, default=1)
90+
parser.add_argument("--list-models", action="store_true")
91+
92+
parser.add_argument("--ckpt", type=str, default=None)
93+
parser.add_argument("--skip-ckpt", action="store_true")
94+
95+
parser.add_argument("--weight-q", type=str, default="8")
96+
parser.add_argument("--act-q", type=str, default=None)
97+
parser.add_argument("--group-size", type=int, default=1)
98+
parser.add_argument("--skip-qscheme", action="store_true")
99+
100+
parser.add_argument("--fuse", action="store_true")
101+
parser.add_argument("--fuse-seq", action="store_true")
102+
parser.add_argument("--align-to", type=int, default=32)
103+
parser.add_argument("--gemmini-mode", action="store_true")
104+
105+
parser.add_argument("--output-dir", type=str, default="project")
106+
parser.add_argument("--output-name", type=str, default="model")
107+
parser.add_argument(
108+
"--mem-pool",
109+
action=argparse.BooleanOptionalAction,
110+
default=True,
111+
)
112+
parser.add_argument("--verbose", action="store_true")
113+
114+
parser.add_argument("--example-shape", type=str, default=None)
115+
parser.add_argument(
116+
"--example-dtype",
117+
type=str,
118+
default="float32",
119+
choices=["float32", "int64"],
120+
)
121+
parser.add_argument("--print-graph", action="store_true")
122+
parser.add_argument("--dag-file", type=str, default=None)
123+
parser.add_argument("--dag-simplified", action="store_true")
124+
125+
args = parser.parse_args()
126+
127+
if args.list_models:
128+
for name in model_zoo.list_zoo_models():
129+
print(name)
130+
return
131+
132+
if args.model_name is None:
133+
parser.error("model_name is required unless --list-models is used.")
134+
135+
if args.fuse and args.fuse_seq:
136+
raise ValueError("Please use only one of --fuse or --fuse-seq.")
137+
138+
model, _, test_loader = model_zoo.from_zoo(
139+
args.model_name, shuffle=False, batch_size=args.batch_size
140+
)
141+
model = model.to("cpu")
142+
143+
if not args.skip_ckpt:
144+
ckpt_path = args.ckpt or f"output/ckpt/{args.model_name}.pth"
145+
if not os.path.exists(ckpt_path):
146+
raise FileNotFoundError(
147+
f"Checkpoint not found: {ckpt_path}. "
148+
"Use --ckpt to set a checkpoint path or --skip-ckpt to skip loading."
149+
)
150+
load_model_ckpt(model, ckpt_path)
151+
152+
if not args.skip_qscheme:
153+
n_layers = model.n_layers
154+
weight_q = parse_bits(args.weight_q, n_layers, "--weight-q")
155+
act_q_text = args.act_q if args.act_q is not None else args.weight_q
156+
act_q = parse_bits(act_q_text, n_layers, "--act-q")
157+
model.set_qscheme([weight_q, act_q], group_size=args.group_size)
158+
159+
if args.fuse:
160+
model = fuse_model(model)
161+
elif args.fuse_seq:
162+
model = fuse_model_seq(model)
163+
164+
model.eval()
165+
166+
if args.example_shape is not None:
167+
input_shape = parse_shape(args.example_shape)
168+
if args.example_dtype == "int64":
169+
example_input = torch.zeros(input_shape, dtype=torch.int64)
170+
else:
171+
example_input = torch.randn(input_shape, dtype=torch.float32)
172+
else:
173+
if test_loader is None:
174+
raise ValueError(
175+
"No test loader available from model_zoo. "
176+
"Please provide --example-shape and --example-dtype."
177+
)
178+
example_input = get_example_input(test_loader)
179+
180+
codegen = MiCoCodeGen(
181+
model,
182+
align_to=args.align_to,
183+
gemmini_mode=args.gemmini_mode,
184+
)
185+
if args.print_graph:
186+
codegen.print_graph()
187+
codegen.forward(example_input)
188+
189+
if args.dag_file:
190+
codegen.visualize_dag(args.dag_file, simplified=args.dag_simplified)
191+
192+
codegen.convert(
193+
output_directory=args.output_dir,
194+
model_name=args.output_name,
195+
verbose=args.verbose,
196+
mem_pool=args.mem_pool,
197+
)
198+
199+
200+
if __name__ == "__main__":
201+
main()

readme.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,14 @@ python examples/lenet_mnist_search.py # MPQ Search on trained LeNet
3636
# For General Script Usage
3737
python examples/mpq_train.py -h
3838
python examples/mpq_search.py -h
39+
python examples/mpq_gen.py -h
3940
```
4041

4142
**To use the CodeGen**, check the code to change the models/datasets/precisions:
4243
```shell
4344
python MiCoCodeGen.py
45+
# or use model_zoo directly
46+
python examples/mpq_gen.py lenet_mnist --output-dir project --output-name model
4447
```
4548

4649
**To compile the inference code** after generating the model header with the CodeGen:

0 commit comments

Comments
 (0)