Skip to content
This repository was archived by the owner on May 5, 2024. It is now read-only.

Commit 230a72e

Browse files
committed
refactor braggnn and add make_top for partitioned braggnn
1 parent b345871 commit 230a72e

40 files changed

Lines changed: 13456 additions & 225 deletions

README.md

Lines changed: 31 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -28,106 +28,49 @@ The "flow" is
2828
Turn this
2929

3030
```python
31-
class DoubleCNN(nn.Module):
32-
def __init__(self, scale):
33-
super().__init__()
34-
self.conv1 = torch.nn.Conv2d(1, 16 * scale, 3)
35-
self.conv2_1 = torch.nn.Conv2d(16 * scale, 8 * scale, 1)
36-
self.conv2_2 = torch.nn.Conv2d(16 * scale, 8 * scale, 1)
37-
self.conv2_3 = torch.nn.Conv2d(16 * scale, 8 * scale, 1)
38-
self.conv3 = torch.nn.Conv2d(8 * scale, 16 * scale, 1)
39-
self.conv4 = torch.nn.Conv2d(16 * scale, 8 * scale, 3)
40-
41-
def forward(self, x):
42-
y = self.conv1(x)
43-
z = self.conv2_1(y)
44-
w = self.conv2_2(y)
45-
u = self.conv2_3(y)
46-
uuu = z + w + u
47-
uu = self.conv3(uuu)
48-
return uu.sum()
49-
```
50-
51-
into this
52-
53-
```mlir
54-
#map = affine_map<(d0, d1) -> (d0 + d1)>
55-
module attributes {torch.debug_module_name = "DoubleCNN"} {
56-
memref.global "private" constant @__constant_16x1x3x3xf32 : memref<16x1x3x3xf32> = dense<"...">
57-
memref.global "private" constant @__constant_16xf32_0 : memref<16xf32> = dense<[0.243066281, 0.331322402, ...]>
58-
memref.global "private" constant @__constant_8x16x1x1xf32_1 : memref<8x16x1x1xf32> = dense<"...">
59-
memref.global "private" constant @__constant_8xf32_1 : memref<8xf32> = dense<[0.0737214088, 0.0993697941, ...]>
60-
memref.global "private" constant @__constant_8x16x1x1xf32_0 : memref<8x16x1x1xf32> = dense<"...">
61-
memref.global "private" constant @__constant_8xf32_0 : memref<8xf32> = dense<[0.0834305584, -0.150565714, ...]>
62-
memref.global "private" constant @__constant_8x16x1x1xf32 : memref<8x16x1x1xf32> = dense<"...">
63-
memref.global "private" constant @__constant_8xf32 : memref<8xf32> = dense<[-0.0900013148, -0.189049691,...]>
64-
memref.global "private" constant @__constant_16x8x1x1xf32 : memref<16x8x1x1xf32> = dense<"...">
65-
memref.global "private" constant @__constant_16xf32 : memref<16xf32> = dense<[-0.133005634, -0.297289908, ...]>
66-
func.func @forward(%arg0: memref<1x1x11x11xf32>) -> memref<f32> {
67-
%11 = memref.alloca() : memref<1x16x9x9xf32>
68-
memref.copy %10, %11 : memref<1x16x9x9xf32> to memref<1x16x9x9xf32>
69-
scf.parallel (%arg1, %arg2, %arg3, %arg4) = (%c0, %c0, %c0, %c0) to (%c1, %c16, %c9, %c9) step (%c1, %c1, %c1, %c1) {
70-
scf.for %arg5 = %c0 to %c1 step %c1 {
71-
scf.for %arg6 = %c0 to %c3 step %c1 {
72-
scf.for %arg7 = %c0 to %c3 step %c1 {
73-
%24 = affine.apply #map(%arg3, %arg6)
74-
%25 = affine.apply #map(%arg4, %arg7)
75-
%26 = memref.load %arg0[%arg1, %arg5, %24, %25] : memref<1x1x11x11xf32>
76-
%27 = memref.load %9[%arg2, %arg5, %arg6, %arg7] : memref<16x1x3x3xf32>
77-
%28 = memref.load %11[%arg1, %arg2, %arg3, %arg4] : memref<1x16x9x9xf32>
78-
%29 = arith.mulf %26, %27 : f32
79-
%30 = arith.addf %28, %29 : f32
80-
memref.store %30, %11[%arg1, %arg2, %arg3, %arg4] : memref<1x16x9x9xf32>
81-
}
82-
}
83-
}
84-
scf.yield
85-
}
86-
87-
...
88-
89-
}
90-
%22 = memref.alloca() : memref<f32>
91-
memref.store %cst, %22[] : memref<f32>
92-
%23 = memref.alloc() {alignment = 128 : i64} : memref<f32>
93-
memref.copy %22, %23 : memref<f32> to memref<f32>
94-
scf.for %arg1 = %c0 to %c1 step %c1 {
95-
scf.for %arg2 = %c0 to %c16 step %c1 {
96-
scf.for %arg3 = %c0 to %c9 step %c1 {
97-
scf.for %arg4 = %c0 to %c9 step %c1 {
98-
%24 = memref.load %21[%arg1, %arg2, %arg3, %arg4] : memref<1x16x9x9xf32>
99-
%25 = memref.load %23[] : memref<f32>
100-
%26 = arith.addf %24, %25 : f32
101-
memref.store %26, %23[] : memref<f32>
102-
}
103-
}
104-
}
105-
}
106-
return %23 : memref<f32>
107-
}
108-
}
109-
31+
BraggNN(
32+
(cnn_layers_1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1))
33+
(nlb): NLB(
34+
(theta_layer): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
35+
(phi_layer): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
36+
(g_layer): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
37+
(out_cnn): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
38+
(soft): Softmax(
39+
(exp): Exp()
40+
)
41+
)
42+
(cnn_layers_2): Sequential(
43+
(0): ReLU()
44+
(1): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1))
45+
(2): ReLU()
46+
(3): Conv2d(8, 2, kernel_size=(3, 3), stride=(1, 1))
47+
(4): ReLU()
48+
)
49+
(dense_layers): Sequential(
50+
(0): Linear(in_features=50, out_features=16, bias=True)
51+
(1): ReLU()
52+
(2): Linear(in_features=16, out_features=8, bias=True)
53+
(3): ReLU()
54+
(4): Linear(in_features=8, out_features=4, bias=True)
55+
(5): ReLU()
56+
(6): Linear(in_features=4, out_features=2, bias=True)
57+
(7): ReLU()
58+
)
59+
)
11060
```
11161

11262
into this
11363

11464
<p align="center">
115-
<img height="1000" src="docs/images/double_cnn.png" alt="">
65+
<img height="1000" src="docs/images/bragghls_done.png" alt="BraggNN design on a Xilinx Alveo U280, with FMUL logic highlighted in red and FADD logic in green">
11666
</p>
11767
<p align="center">
118-
245 intervals at ~100 MHz on Xilinx Alveo U280
68+
1200 intervals at ~100 MHz on Xilinx Alveo U280
11969
</p>
12070
<p align="center">
12171
(Red represents FMUL logic, green represents FADD logic)
12272
</p>
12373

124-
<!---
125-
126-
[//]: # (![alt text]&#40;docs/images/double_cnn.png&#41;)
127-
[//]: # (3:#highlight_objects -color green -leaf_cells [get_cells _forward_inner/fadd*]
128-
[//]: # (54:#highlight_objects -color red -leaf_cells [get_cells _forward_inner/fmul*])
129-
--->
130-
13174
# Repo structure
13275

13376
This project has a lot of moving parts; the directory structure tells the tale:

bragghls/compiler/compile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def compile(
203203
vals,
204204
csts,
205205
pe_idxs,
206-
include_outer_module=not do_testbench,
206+
for_testbench=do_testbench
207207
)
208208
verilog_file = verilog_file.replace("%", "p_")
209209
with open(f"{artifacts_dir}/{name}.sv", "w") as f:

bragghls/ip_cores/flopoco_fadd_3_3.sv

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
(* use_dsp = "yes" *) module intadder_8_f300_uid133
1+
module intadder_8_f300_uid133
22
(input wire clk,
33
input wire [7:0] x,
44
input wire [7:0] y,
@@ -33,7 +33,7 @@
3333
n438_q <= y_d1;
3434
endmodule
3535

36-
(* use_dsp = "yes" *) module intadder_7_f300_uid10
36+
module intadder_7_f300_uid10
3737
(input wire clk,
3838
input wire [6:0] x,
3939
input wire [6:0] y,

0 commit comments

Comments
 (0)