This repository was archived by the owner on May 5, 2024. It is now read-only.

Commit f31299e

update readme
update readme update readme update readme update readme update readme
1 parent e1c02e6 commit f31299e

9 files changed

Lines changed: 224 additions & 85 deletions

README.md

Lines changed: 168 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,170 @@
1-
# bragg_hls
2-
Low-latency Bragg peak detection through high-level synthesis
1+
# BraggHLS
32

4-
# Requirements
3+
- [BraggHLS](#bragghls)
4+
- [Current status](#current-status)
5+
- [Building](#building)
6+
* [Requirements](#requirements)
7+
* [Build steps](#build-steps)
8+
- [Running](#running)
59

6-
`sudo apt-get install libgmp3-dev`
7-
`sudo apt-get install libmpfr-dev libmpfi-dev`
10+
This is a framework for lowering PyTorch models to RTL using high-level synthesis (HLS) techniques.
11+
Crucially, we do **not** use any existing HLS tools (such as Xilinx's Vitis).
12+
The particular driving use case is low-latency [Bragg peak detection](https://arxiv.org/abs/2008.08198) for high-energy diffraction microscopy (HEDM).
13+
14+
The "flow" is PyTorch -> MLIR -> python -> MLIR -> RTL.
15+
16+
This project has a lot of moving pieces; the directory structure tells the tale:
17+
18+
- [bragghls/](bragghls) - the core python library
19+
- [compiler/compile.py](bragghls/compiler/compile.py) - python script that drives the entire flow
20+
- [flopoco/](bragghls/flopoco) - functionality related to converting to and from [FloPoCo's](http://flopoco.org/) nonstandard floating point representation (for purposes of RTL generation *and* simulation)
21+
- [ir/](bragghls/ir) - functionality related to parsing, transforming, and interpreting MLIR representations of PyTorch models.
22+
- [rtl/](bragghls/rtl) - functionality related to emitting RTL (SystemVerilog)
23+
- [testbench/](bragghls/testbench) - testbench runners via [cocotb](https://www.cocotb.org/) and [iverilog](http://iverilog.icarus.com/)
24+
- [bragghls_translate/](bragghls_translate) - MLIR translation library for translating MLIR to python
25+
- [examples/](examples) - obviously...
26+
- [ip_cores/](ip_cores) - FloPoCo cores for 4,4 and 5,5 floating point addition and multiplication along with testbench generation script
27+
- [flopoco_convert_ext/](ip_cores/flopoco_convert_ext) - pybind-ed extension for converting between IEEE754 and FloPoCo's floating point representation
28+
- [scripts/](scripts) - helper scripts for things like generating new FloPoCo IPs and building the entire project
29+
- [tests/](tests) - obviously...
30+
31+
# Current status
32+
33+
[linear](examples/linear.py) and [cnn](examples/cnn.py) examples work (including tiling) but [braggnn](examples/braggnn.py) still needs adjustment (compiles but doesn't pass tests).
34+
35+
# Building
36+
37+
The build steps are many and tortuous.
38+
39+
## Requirements
40+
41+
1. A compiler (GCC or Clang)
42+
2. Python (>= 3.10) (recommended to use conda)
43+
3. [GNU MP Bignum Library](https://gmplib.org/) (`sudo apt-get install libgmp3-dev`)
44+
4. [GNU Multiple Precision Floating-Point Reliable Library](https://www.mpfr.org/) (`sudo apt-get install libmpfr-dev libmpfi-dev`)
45+
5. [Multiple Precision Floating-point Interval library](http://perso.ens-lyon.fr/nathalie.revol/software.html) (`sudo apt-get install libmpfi-dev`)
46+
6. [Icarus Verilog](http://iverilog.icarus.com/) (`sudo apt-get install iverilog`)
47+
7. Patience
48+
49+
Everything else should be taken care of by the build script (if I didn't miss anything...).
50+
51+
## Build steps
52+
53+
1. First make sure you have all the submodules checked out
54+
```shell
55+
git submodule sync --recursive
56+
git submodule update --init --recursive --jobs 0
57+
```
58+
This will take a while due to our dependency on LLVM.
59+
2. `pip install -r requirements.txt` to get `cmake`, `pybind11`, `ninja`, and the necessary Python packages
60+
3. Run the build script [scripts/build.sh](scripts/build.sh) which will:
61+
1. Build all of LLVM
62+
2. Build Torch-MLIR against LLVM
63+
3. Build CIRCT against LLVM
64+
4. Build `bragghls_translate` and `flopoco_converter`
65+
5. Download GHDL and unpack it (this step is optional if you don't want to generate new IP)
66+
67+
You will need all the relevant executables (`circt-opt`, `torch-mlir-opt`, etc.) in your path **and in an env variable BRAGGHLS_PATH**. See [.envrc](.envrc) for a way to add all of them (or just use [direnv](https://direnv.net/)).
68+
You will also need the following environment variables exported:
69+
70+
```shell
71+
export ADD_PIPELINE_DEPTH=2
72+
export MUL_PIPELINE_DEPTH=1
73+
export WE=4
74+
export WF=4
75+
export PYTHONPATH=<BraggHLS source directory>
76+
```
77+
78+
The above are the correct numbers for the 4,4 FloPoCo IP cores.
79+
80+
# Running
81+
82+
Assuming everything built successfully and you have all of the correct paths and environment variables, run any of the scripts in [examples](examples) to generate MLIR IR.
83+
Then the main [compiler driver](bragghls/compiler/compile.py) can be run with the following arguments
84+
85+
```shell
86+
usage: BraggHLS compiler driver [-h] [-t] [-r] [-s] [-v] [-b] [--wE WE] [--wF WF] fp
87+
88+
positional arguments:
89+
fp Filepath of top-level MLIR file
90+
91+
options:
92+
-h, --help show this help message and exit
93+
-t, --translate Translate MLIR to python
94+
-r, --rewrite Transform/rewrite python
95+
-s, --schedule Schedule the model using CIRCT
96+
-v, --verilog Emit verilog
97+
-b, --testbench Run autogenerated testbench
98+
--wE WE Bit width of exponent
99+
--wF WF Bit width of fraction
100+
```
101+
102+
For example, running [examples/linear.py](examples/linear.py) produces an artifacts folder at [examples/linear_bragghls_artifacts](examples/linear_bragghls_artifacts) which contains a `linear.mlir` file that looks like
103+
104+
```mlir
105+
module attributes {torch.debug_module_name = "Linear"} {
106+
memref.global "private" constant @__constant_8x8xf32 : memref<8x8xf32> = dense<[...]>
107+
memref.global "private" constant @__constant_8xf32 : memref<8xf32> = dense<[...]>
108+
func.func @forward(%arg0: memref<1x8xf32>) -> memref<f32> {
109+
110+
...
111+
112+
scf.for %arg1 = %c0 to %c1 step %c1 {
113+
scf.for %arg2 = %c0 to %c8 step %c1 {
114+
%7 = memref.load %4[%arg1, %arg2] : memref<1x8xf32>
115+
%8 = memref.load %6[] : memref<f32>
116+
%9 = arith.addf %7, %8 : f32
117+
memref.store %9, %6[] : memref<f32>
118+
}
119+
}
120+
return %6 : memref<f32>
121+
}
122+
}
123+
```
124+
125+
Then running (from top-level in the source directory)
126+
```shell
127+
python bragghls/compiler/compile.py examples/linear_bragghls_artifacts/linear.mlir -t -r -s -v -b --wE 4 --wF 4
128+
```
129+
will generate `linear.sv` and run the automatically generated (no artifact) testbench, and produce the following output:
130+
131+
```
132+
INFO: Running command: iverilog "-o "examples/linear_bragghls_artifacts/sim.vvp "-D "COCOTB_SIM=1 "-g2012 "examples/linear_bragghls_artifacts/linear.sv "ip_cores/flopoco_fmul_4_4.sv "ip_cores/flopoco_fadd_4_4.sv "ip_cores/flopoco_relu.sv "ip_cores/flopoco_neg.sv" in directory:"examples/linear_bragghls_artifacts"
133+
0.00ns INFO Running on Icarus Verilog version 11.0 (stable)
134+
0.00ns INFO Running tests with cocotb v1.6.2 from /Users/mlevental/miniforge3/envs/bragghls/lib/python3.10/site-packages/cocotb
135+
0.00ns INFO Seeding Python random module with 1659448436
136+
0.00ns WARNING Pytest not found, assertion rewriting will not occur
137+
0.00ns INFO Found test tb_runner.test_tb
138+
0.00ns INFO running test_tb (1/1)
139+
140+
outputs {'_6': [<IEEE -4.6549486522000025> <FPNumber -4.50e0:01110010010>]}
141+
passed 43
142+
outputs {'_6': [<IEEE -1.2715176573999998> <FPNumber -1.31e0:01101110101>]}
143+
passed 87
144+
outputs {'_6': [<IEEE -7.192521898300005> <FPNumber -6.75e0:01110011011>]}
145+
passed 131
146+
outputs {'_6': [<IEEE -0.42565990870000003> <FPNumber -5.00e-1:01101100000>]}
147+
passed 175
148+
149+
...
150+
151+
passed 703
152+
outputs {'_6': [<IEEE 5.495344332200002> <FPNumber 5.00e0:01010010100>]}
153+
passed 747
154+
outputs {'_6': [<IEEE 4.6494865835> <FPNumber 5.25e0:01010010101>]}
155+
passed 791
156+
outputs {'_6': [<IEEE -2.963233154800001> <FPNumber -3.12e0:01110001001>]}
157+
passed 835
158+
outputs {'_6': [<IEEE 3.8036288347999996> <FPNumber 4.00e0:01010010000>]}
159+
passed 879
160+
161+
162+
1761.00ns INFO test_tb passed
163+
1761.00ns INFO **************************************************************************************
164+
** TEST STATUS SIM TIME (ns) REAL TIME (s) RATIO (ns/s) **
165+
**************************************************************************************
166+
** tb_runner.test_tb PASS 1761.00 1.08 1636.30 **
167+
**************************************************************************************
168+
** TESTS=1 PASS=1 FAIL=0 SKIP=0 1761.00 1.12 1571.26 **
169+
**************************************************************************************
170+
```
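
Note on the testbench log above: the `FPNumber` bit strings can be cross-checked by hand. The sketch below decodes them, assuming the usual FloPoCo layout (two exception bits, one sign bit, `wE` exponent bits with bias `2**(wE-1) - 1`, then `wF` fraction bits with an implicit leading one). `decode_flopoco` is a hypothetical helper written for this note; it is not part of `bragghls/flopoco`.

```python
import math


def decode_flopoco(bits: str, wE: int = 4, wF: int = 4) -> float:
    """Decode a FloPoCo-style bit string (assumed layout: exn|sign|exp|frac)."""
    if len(bits) != 3 + wE + wF:
        raise ValueError(f"expected {3 + wE + wF} bits, got {len(bits)}")
    exn, sign_bit = bits[:2], bits[2]
    exponent = int(bits[3 : 3 + wE], 2)
    fraction = int(bits[3 + wE :], 2)
    sign = -1.0 if sign_bit == "1" else 1.0

    if exn == "00":  # exception field: zero
        return sign * 0.0
    if exn == "10":  # exception field: infinity
        return sign * math.inf
    if exn == "11":  # exception field: NaN
        return math.nan

    # exn == "01": normal number with an implicit leading one
    bias = 2 ** (wE - 1) - 1
    return sign * (1 + fraction / 2**wF) * 2.0 ** (exponent - bias)


print(decode_flopoco("01110010010"))  # -4.5
print(decode_flopoco("01010010101"))  # 5.25
```

Both results agree with the `FPNumber` values printed in the log (`-4.50e0` and `5.25e0`). The gap between those and the `IEEE` column presumably reflects the limited precision of a 4-bit fraction.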

bragghls/compiler/compile.py

Lines changed: 22 additions & 20 deletions
@@ -129,17 +129,17 @@ def main(args):
129129
with open(f"{artifacts_dir}/{name}.rewritten.sched.mlir", "r") as f:
130130
sched_and_rewritten_mlir = f.read()
131131

132-
if args.verilog:
133-
(
134-
op_id_data,
135-
func_args,
136-
returns,
137-
return_time,
138-
vals,
139-
csts,
140-
pe_idxs,
141-
) = parse_mlir_module(sched_and_rewritten_mlir)
132+
(
133+
op_id_data,
134+
func_args,
135+
returns,
136+
return_time,
137+
vals,
138+
csts,
139+
pe_idxs,
140+
) = parse_mlir_module(sched_and_rewritten_mlir)
142141

142+
if args.verilog:
143143
verilog_file, input_wires, output_wires, max_fsm_stage = emit_verilog(
144144
name,
145145
args.wE,
@@ -158,6 +158,8 @@ def main(args):
158158
f.write(verilog_file)
159159

160160
print(f"{max_fsm_stage=}")
161+
else:
162+
max_fsm_stage = return_time + 1
161163

162164
if args.testbench:
163165
testbench_runner(
@@ -175,16 +177,16 @@ def main(args):
175177

176178

177179
if __name__ == "__main__":
178-
DEBUG = bool(int(os.getenv("DEBUG", "0")))
179-
parser = argparse.ArgumentParser()
180-
parser.add_argument("fp")
181-
parser.add_argument("-t", "--translate", default=False, action="store_true")
182-
parser.add_argument("-r", "--rewrite", default=False, action="store_true")
183-
parser.add_argument("-s", "--schedule", default=False, action="store_true")
184-
parser.add_argument("-v", "--verilog", default=False, action="store_true")
185-
parser.add_argument("-b", "--testbench", default=False, action="store_true")
186-
parser.add_argument("--wE", default=4)
187-
parser.add_argument("--wF", default=4)
180+
DEBUG = bool(int(os.getenv("DEBUG", "1")))
181+
parser = argparse.ArgumentParser("BraggHLS compiler driver")
182+
parser.add_argument("fp", help="Filepath of top-level MLIR file")
183+
parser.add_argument("-t", "--translate", default=False, action="store_true", help="Translate MLIR to python")
184+
parser.add_argument("-r", "--rewrite", default=False, action="store_true", help="Transform/rewrite python")
185+
parser.add_argument("-s", "--schedule", default=False, action="store_true", help="Schedule the model using CIRCT")
186+
parser.add_argument("-v", "--verilog", default=False, action="store_true", help="Emit verilog")
187+
parser.add_argument("-b", "--testbench", default=False, action="store_true", help="Run autogenerated testbench")
188+
parser.add_argument("--wE", default=4, help="Bit width of exponent")
189+
parser.add_argument("--wF", default=4, help="Bit width of fraction")
188190
args = parser.parse_args()
189191
args.wE = int(args.wE)
190192
args.wF = int(args.wF)
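
Note on the driver's argument parsing: the sketch below rebuilds the same flags in a standalone parser so the README's example invocation can be checked without the rest of the toolchain. `build_parser` is a hypothetical helper, and using `type=int` for `--wE`/`--wF` is an assumption of this sketch; the commit itself converts the values with `int(...)` after parsing.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Standalone copy of the driver's flags (hypothetical helper)."""
    parser = argparse.ArgumentParser("BraggHLS compiler driver")
    parser.add_argument("fp", help="Filepath of top-level MLIR file")
    for short, name, text in [
        ("-t", "--translate", "Translate MLIR to python"),
        ("-r", "--rewrite", "Transform/rewrite python"),
        ("-s", "--schedule", "Schedule the model using CIRCT"),
        ("-v", "--verilog", "Emit verilog"),
        ("-b", "--testbench", "Run autogenerated testbench"),
    ]:
        parser.add_argument(short, name, action="store_true", help=text)
    # type=int replaces the post-hoc int(args.wE) conversion (sketch only).
    parser.add_argument("--wE", type=int, default=4, help="Bit width of exponent")
    parser.add_argument("--wF", type=int, default=4, help="Bit width of fraction")
    return parser


args = build_parser().parse_args(
    ["examples/linear_bragghls_artifacts/linear.mlir",
     "-t", "-r", "-s", "-v", "-b", "--wE", "4", "--wF", "4"]
)
print(args.translate, args.testbench, args.wE, args.wF)  # True True 4 4
```

Because `--translate` and `--testbench` share the prefix `--t`, argparse's prefix matching rejects `--t` as ambiguous, so the short form `-t` is the one to pass.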

bragghls/flopoco/ops.py

Lines changed: 16 additions & 3 deletions
@@ -33,6 +33,13 @@ def ReduceAdd(vals):
3333
return pairs[0][0] + pairs[0][1]
3434

3535

36+
def check_make_val(v, wE, wF):
37+
if not isinstance(v, Val):
38+
assert isinstance(v, (float, int)), v
39+
v = Val(v, wE, wF)
40+
return v
41+
42+
3643
@dataclass(frozen=True)
3744
class Val:
3845
ieee: float
@@ -48,21 +55,27 @@ def __post_init__(self):
4855
)
4956
object.__setattr__(self, "name", str(self))
5057

51-
def __mul__(self, other: "Val"):
58+
def __mul__(self, other):
59+
other = check_make_val(other, self.wE, self.wF)
5260
v = mul(self, other)
5361
return v
5462

5563
def __eq__(self, other):
5664
return self.fp == other.fp
5765

58-
def __add__(self, other: "Val"):
66+
def __add__(self, other):
67+
other = check_make_val(other, self.wE, self.wF)
5968
v = add(self, other)
6069
return v
6170

62-
def __sub__(self, other: "Val"):
71+
def __sub__(self, other):
72+
other = check_make_val(other, self.wE, self.wF)
6373
v = sub(self, other)
6474
return v
6575

76+
def __neg__(self):
77+
return Val(-self.ieee, self.wE, self.wF)
78+
6679
def copy(self):
6780
return self
6881
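
Note on the operator changes in this file: the coercion pattern introduced by `check_make_val` can be tried in isolation. The sketch below uses a hypothetical `ToyVal` that does plain float arithmetic in place of the FloPoCo `add`/`mul` routines, so it only illustrates the scalar-wrapping behaviour, not the reduced-precision results. The reflected operators (`__radd__`, `__rmul__`) are an addition of this sketch and do not appear in the commit.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyVal:
    """Stand-in for Val: plain float arithmetic instead of FloPoCo ops."""

    ieee: float
    wE: int = 4
    wF: int = 4

    def __add__(self, other):
        other = coerce(other, self.wE, self.wF)
        return ToyVal(self.ieee + other.ieee, self.wE, self.wF)

    def __mul__(self, other):
        other = coerce(other, self.wE, self.wF)
        return ToyVal(self.ieee * other.ieee, self.wE, self.wF)

    def __neg__(self):
        return ToyVal(-self.ieee, self.wE, self.wF)

    # Reflected operators (sketch only): they let a bare scalar appear on
    # the left-hand side, e.g. `2 * val`.
    __radd__ = __add__
    __rmul__ = __mul__


def coerce(v, wE, wF):
    """Wrap bare Python numbers, mirroring what check_make_val does."""
    if isinstance(v, ToyVal):
        return v
    if not isinstance(v, (float, int)):
        raise TypeError(f"cannot convert {v!r} to ToyVal")
    return ToyVal(float(v), wE, wF)


x = ToyVal(1.5)
print((x + 2).ieee, (2 * x).ieee, (-x).ieee)  # 3.5 3.0 -1.5
```

Raising `TypeError` instead of using `assert` is a design choice of the sketch: the check then survives running Python with `-O`.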

bragghls/ir/parse.py

Lines changed: 3 additions & 5 deletions
@@ -140,11 +140,9 @@ def parse_mlir_module(module_str):
140140
start_time = int(start_time[0])
141141
else:
142142
start_time = None
143-
returns, return_time = idents[0][0], start_time
144-
if not isinstance(returns, list):
145-
returns = [returns]
146-
for r in returns:
147-
vals.add(r)
143+
returns, return_time = [idn[0] for idn in idents], start_time
144+
for r in returns:
145+
vals.add(r)
148146
else:
149147
continue
150148
assert func_args and returns
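
Note on the `returns` change: the list comprehension makes the parser collect every returned value rather than only the first. The sketch below shows the same idea on a raw MLIR `return` line. `parse_returns` and the identifier regex are assumptions made for illustration; the structure of `idents` inside the real `parse_mlir_module` is not shown in this diff.

```python
import re

# Assumed identifier syntax: '%' followed by word characters or dots.
SSA_IDENT = re.compile(r"%[\w.]+")


def parse_returns(line: str) -> list[str]:
    """Return every SSA value named in an MLIR `return` line."""
    # Keep only the operand list, i.e. the text before the type annotation.
    operands = line.strip().removeprefix("return").split(":", 1)[0]
    return SSA_IDENT.findall(operands)


print(parse_returns("return %6 : memref<f32>"))
print(parse_returns("return %6, %7 : memref<f32>, memref<f32>"))
```

A single return value still comes back as a one-element list, so downstream code such as `for r in returns` needs no special case.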

bragghls/rtl/emit_verilog.py

Lines changed: 3 additions & 1 deletion
@@ -31,7 +31,8 @@ def build_ip_res_val_map(pe, op_datas: list[Op], vals):
3131
else:
3232
warnings.warn(f"not mapping {res_val} to {op} in ip_res_val_map")
3333
elif op.type in {OpType.NEG, OpType.RELU}:
34-
ip_res_val_map[res_val] = pe.frelu.res
34+
ip = getattr(pe, op.type.value, None)
35+
ip_res_val_map[res_val] = ip.res
3536
elif op.type in {OpType.COPY}:
3637
pass
3738
elif op.type == OpType.FMAC:
@@ -178,6 +179,7 @@ def emit(*args):
178179
frelu = ReLU(pe_idx, signal_width)
179180
emit(frelu.instantiate())
180181
fneg = Neg(pe_idx, signal_width)
182+
emit(fneg.instantiate())
181183
pes[pe_idx] = PE(fadd, fmul, frelu, fneg, pe_idx)
182184

183185
pe_to_ops = cluster_pes(pes, op_id_data)
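
Note on the `getattr` dispatch in `build_ip_res_val_map`: looking the IP core up by `op.type.value` only works if each enum value matches an attribute name on the PE. The sketch below spells that assumption out with hypothetical `OpType`, `IP`, and `PE` stand-ins (the real definitions are not part of this diff) and adds an explicit check for the `None` fallback.

```python
from dataclasses import dataclass
from enum import Enum


class OpType(Enum):
    # Assumed: each value equals the matching attribute name on PE.
    NEG = "fneg"
    RELU = "frelu"


@dataclass
class IP:
    res: str  # name of the result wire


@dataclass
class PE:
    frelu: IP
    fneg: IP


def result_wire(pe: PE, op_type: OpType) -> str:
    """Pick the IP core by op type and return its result wire."""
    ip = getattr(pe, op_type.value, None)
    if ip is None:
        # Fail with a clear message rather than an AttributeError on None.res.
        raise LookupError(f"PE has no IP core named {op_type.value!r}")
    return ip.res


pe = PE(frelu=IP("frelu_res_0"), fneg=IP("fneg_res_0"))
print(result_wire(pe, OpType.NEG))  # fneg_res_0
```

With this dispatch, a `NEG` op maps to the negation core's result wire rather than the ReLU one, which is the behaviour the hunk above changes.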

examples/braggnn.py

Lines changed: 9 additions & 3 deletions
@@ -1,4 +1,5 @@
11
import argparse
2+
import os
23
from pathlib import Path
34

45
import torch
@@ -306,9 +307,14 @@ def make_braggn(scale, img_size=11, simplify_weights=True):
306307

307308
if __name__ == "__main__":
308309
parser = argparse.ArgumentParser(description="make stuff")
309-
parser.add_argument("--out_dir", type=Path, default=Path("."))
310-
parser.add_argument("--scale", type=int, default=4)
310+
parser.add_argument(
311+
"--out_dir",
312+
type=Path,
313+
default=Path(__file__).parent / "braggnn_bragghls_artifacts",
314+
)
315+
parser.add_argument("--scale", type=int, default=1)
311316
args = parser.parse_args()
312317
args.out_dir = args.out_dir.resolve()
313318
dot_str = make_braggn(args.scale)
314-
open(f"{args.out_dir}/braggnn_{args.scale}.mlir", "w").write(dot_str)
319+
os.makedirs(f"{args.out_dir}", exist_ok=True)
320+
open(f"{args.out_dir}/braggnn.mlir", "w").write(dot_str)
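
Note on the output handling: the same "create the artifacts folder, then write the MLIR" step can be expressed with `pathlib` alone, which also closes the file handle deterministically. `write_artifact` is a hypothetical helper for illustration, and the temporary directory stands in for the example's real `braggnn_bragghls_artifacts` folder.

```python
import tempfile
from pathlib import Path


def write_artifact(out_dir: Path, name: str, text: str) -> Path:
    """Create out_dir if needed and write <name>.mlir into it."""
    out_dir = out_dir.resolve()
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / f"{name}.mlir"
    path.write_text(text)  # opens, writes, and closes the file
    return path


with tempfile.TemporaryDirectory() as tmp:
    written = write_artifact(
        Path(tmp) / "braggnn_bragghls_artifacts", "braggnn", "module {}"
    )
    print(written.name)  # braggnn.mlir
```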

examples/dot_product/dot_product.py

Lines changed: 0 additions & 45 deletions
This file was deleted.

requirements.txt

Lines changed: 3 additions & 3 deletions
@@ -15,6 +15,6 @@ pybind11
1515
cmake
1616
ninja
1717

18-
-f https://github.com/llvm/torch-mlir/releases
19-
--pre
20-
torch-mlir
18+
# -f https://github.com/llvm/torch-mlir/releases
19+
# --pre
20+
# torch-mlir

tests/run_tests.sh

100644 → 100755
File mode changed.
