@@ -28,106 +28,49 @@ The "flow" is
2828Turn this
2929
3030``` python
31- class DoubleCNN (nn .Module ):
32- def __init__ (self , scale ):
33- super ().__init__ ()
34- self .conv1 = torch.nn.Conv2d(1 , 16 * scale, 3 )
35- self .conv2_1 = torch.nn.Conv2d(16 * scale, 8 * scale, 1 )
36- self .conv2_2 = torch.nn.Conv2d(16 * scale, 8 * scale, 1 )
37- self .conv2_3 = torch.nn.Conv2d(16 * scale, 8 * scale, 1 )
38- self .conv3 = torch.nn.Conv2d(8 * scale, 16 * scale, 1 )
39- self .conv4 = torch.nn.Conv2d(16 * scale, 8 * scale, 3 )
40-
41- def forward (self , x ):
42- y = self .conv1(x)
43- z = self .conv2_1(y)
44- w = self .conv2_2(y)
45- u = self .conv2_3(y)
46- uuu = z + w + u
47- uu = self .conv3(uuu)
48- return uu.sum()
49- ```
50-
51- into this
52-
53- ``` mlir
54- #map = affine_map<(d0, d1) -> (d0 + d1)>
55- module attributes {torch.debug_module_name = "DoubleCNN"} {
56- memref.global "private" constant @__constant_16x1x3x3xf32 : memref<16x1x3x3xf32> = dense<"...">
57- memref.global "private" constant @__constant_16xf32_0 : memref<16xf32> = dense<[0.243066281, 0.331322402, ...]>
58- memref.global "private" constant @__constant_8x16x1x1xf32_1 : memref<8x16x1x1xf32> = dense<"...">
59- memref.global "private" constant @__constant_8xf32_1 : memref<8xf32> = dense<[0.0737214088, 0.0993697941, ...]>
60- memref.global "private" constant @__constant_8x16x1x1xf32_0 : memref<8x16x1x1xf32> = dense<"...">
61- memref.global "private" constant @__constant_8xf32_0 : memref<8xf32> = dense<[0.0834305584, -0.150565714, ...]>
62- memref.global "private" constant @__constant_8x16x1x1xf32 : memref<8x16x1x1xf32> = dense<"...">
63- memref.global "private" constant @__constant_8xf32 : memref<8xf32> = dense<[-0.0900013148, -0.189049691,...]>
64- memref.global "private" constant @__constant_16x8x1x1xf32 : memref<16x8x1x1xf32> = dense<"...">
65- memref.global "private" constant @__constant_16xf32 : memref<16xf32> = dense<[-0.133005634, -0.297289908, ...]>
66- func.func @forward(%arg0: memref<1x1x11x11xf32>) -> memref<f32> {
67- %11 = memref.alloca() : memref<1x16x9x9xf32>
68- memref.copy %10, %11 : memref<1x16x9x9xf32> to memref<1x16x9x9xf32>
69- scf.parallel (%arg1, %arg2, %arg3, %arg4) = (%c0, %c0, %c0, %c0) to (%c1, %c16, %c9, %c9) step (%c1, %c1, %c1, %c1) {
70- scf.for %arg5 = %c0 to %c1 step %c1 {
71- scf.for %arg6 = %c0 to %c3 step %c1 {
72- scf.for %arg7 = %c0 to %c3 step %c1 {
73- %24 = affine.apply #map(%arg3, %arg6)
74- %25 = affine.apply #map(%arg4, %arg7)
75- %26 = memref.load %arg0[%arg1, %arg5, %24, %25] : memref<1x1x11x11xf32>
76- %27 = memref.load %9[%arg2, %arg5, %arg6, %arg7] : memref<16x1x3x3xf32>
77- %28 = memref.load %11[%arg1, %arg2, %arg3, %arg4] : memref<1x16x9x9xf32>
78- %29 = arith.mulf %26, %27 : f32
79- %30 = arith.addf %28, %29 : f32
80- memref.store %30, %11[%arg1, %arg2, %arg3, %arg4] : memref<1x16x9x9xf32>
81- }
82- }
83- }
84- scf.yield
85- }
86-
87- ...
88-
89- }
90- %22 = memref.alloca() : memref<f32>
91- memref.store %cst, %22[] : memref<f32>
92- %23 = memref.alloc() {alignment = 128 : i64} : memref<f32>
93- memref.copy %22, %23 : memref<f32> to memref<f32>
94- scf.for %arg1 = %c0 to %c1 step %c1 {
95- scf.for %arg2 = %c0 to %c16 step %c1 {
96- scf.for %arg3 = %c0 to %c9 step %c1 {
97- scf.for %arg4 = %c0 to %c9 step %c1 {
98- %24 = memref.load %21[%arg1, %arg2, %arg3, %arg4] : memref<1x16x9x9xf32>
99- %25 = memref.load %23[] : memref<f32>
100- %26 = arith.addf %24, %25 : f32
101- memref.store %26, %23[] : memref<f32>
102- }
103- }
104- }
105- }
106- return %23 : memref<f32>
107- }
108- }
109-
31+ BraggNN(
32+ (cnn_layers_1): Conv2d(1 , 16 , kernel_size = (3 , 3 ), stride = (1 , 1 ))
33+ (nlb): NLB(
34+ (theta_layer): Conv2d(16 , 8 , kernel_size = (1 , 1 ), stride = (1 , 1 ))
35+ (phi_layer): Conv2d(16 , 8 , kernel_size = (1 , 1 ), stride = (1 , 1 ))
36+ (g_layer): Conv2d(16 , 8 , kernel_size = (1 , 1 ), stride = (1 , 1 ))
37+ (out_cnn): Conv2d(8 , 16 , kernel_size = (1 , 1 ), stride = (1 , 1 ))
38+ (soft): Softmax(
39+ (exp): Exp()
40+ )
41+ )
42+ (cnn_layers_2): Sequential(
43+ (0 ): ReLU()
44+ (1 ): Conv2d(16 , 8 , kernel_size = (3 , 3 ), stride = (1 , 1 ))
45+ (2 ): ReLU()
46+ (3 ): Conv2d(8 , 2 , kernel_size = (3 , 3 ), stride = (1 , 1 ))
47+ (4 ): ReLU()
48+ )
49+ (dense_layers): Sequential(
50+ (0 ): Linear(in_features = 50 , out_features = 16 , bias = True )
51+ (1 ): ReLU()
52+ (2 ): Linear(in_features = 16 , out_features = 8 , bias = True )
53+ (3 ): ReLU()
54+ (4 ): Linear(in_features = 8 , out_features = 4 , bias = True )
55+ (5 ): ReLU()
56+ (6 ): Linear(in_features = 4 , out_features = 2 , bias = True )
57+ (7 ): ReLU()
58+ )
59+ )
11060```
11161
11262into this
11363
11464<p align="center">
115- <img height="1000" src="docs/images/double_cnn.png" alt="">
65+ <img height="1000" src="docs/images/bragghls_done.png" alt="">
11666</p>
11767<p align="center">
118- 245 intervals at ~100 MHz on Xilinx Alveo U280
68+ 1200 intervals at ~100 MHz on Xilinx Alveo U280
11969</p>
12070<p align="center">
12171 (Red represents FMUL logic, green represents FADD logic)
12272</p>
12373
124- <!-- -
125-
126- [//]: # ()
127- [//]: # (3:#highlight_objects -color green -leaf_cells [get_cells _forward_inner/fadd*])
128- [//]: # (54:#highlight_objects -color red -leaf_cells [get_cells _forward_inner/fmul*])
129- --->
130-
13174# Repo structure
13275
13376This project has a lot of moving parts; the directory structure tells the tale:
0 commit comments