|
| 1 | +import operator |
1 | 2 | from collections import namedtuple |
2 | 3 | from dataclasses import dataclass |
3 | 4 | from functools import reduce |
4 | 5 |
|
5 | 6 | import numpy as np |
6 | 7 |
|
7 | | -from bragghls.compiler import state |
8 | | -from bragghls.config import WIDTH_EXPONENT, WIDTH_FRACTION |
9 | | -from bragghls.util import idx_to_str, chunks |
10 | | - |
11 | 8 | try: |
12 | 9 | from . import flopoco_converter |
13 | 10 | except: |
14 | 11 | import flopoco_converter |
15 | 12 |
|
| 13 | +from bragghls.compiler import state |
| 14 | +from bragghls.config import WIDTH_EXPONENT, WIDTH_FRACTION |
| 15 | +from bragghls.util import idx_to_str, chunks |
| 16 | + |
16 | 17 | FPNUMBER = namedtuple("FPNUMBER", "pe_idx")(None) |
17 | 18 |
|
18 | 19 |
|
19 | | -def reducer(accum, val): |
| 20 | +def reducer(accum, val, reduce_op): |
20 | 21 | if len(val) > 1: |
21 | | - return accum + [val[0] + val[1]] |
| 22 | + res = reduce_op(val[0], val[1]) |
| 23 | + return accum + [res] |
22 | 24 | else: |
23 | 25 | return accum + val |
24 | 26 |
|
25 | 27 |
|
26 | 28 | def ReduceAdd(vals): |
27 | 29 | pairs = list(chunks(list(vals), 2)) |
28 | 30 | while len(pairs) > 1: |
29 | | - pairs = list(chunks(reduce(reducer, pairs, []), 2)) |
| 31 | + pairs = list( |
| 32 | + chunks(reduce(lambda x, y: reducer(x, y, operator.add), pairs, []), 2) |
| 33 | + ) |
30 | 34 | return pairs[0][0] + pairs[0][1] |
31 | 35 |
|
32 | 36 |
|
| 37 | +def ReduceMax(vals): |
| 38 | + pairs = list(chunks(list(vals), 2)) |
| 39 | + while len(pairs) > 1: |
| 40 | + pairs = list(chunks(reduce(lambda x, y: reducer(x, y, max), pairs, []), 2)) |
| 41 | + return max(pairs[0][0], pairs[0][1]) |
| 42 | + |
| 43 | + |
33 | 44 | def check_make_val(v, width_exponent, width_fraction): |
34 | 45 | if not isinstance(v, Val): |
35 | 46 | assert isinstance(v, (float, int)), v |
@@ -70,6 +81,12 @@ def __eq__(self, other): |
70 | 81 | other = check_make_val(other, self.width_exponent, self.width_fraction) |
71 | 82 | return self.fp == other.fp |
72 | 83 |
|
| 84 | + def __lt__(self, other): |
| 85 | + other = check_make_val(other, self.width_exponent, self.width_fraction) |
| 86 | + # print(self.fp, other.fp, self.fp - other.fp, (self.fp - other.fp).sign()) |
| 87 | + # print((self.fp - other.fp).sign()) |
| 88 | + return (self.fp - other.fp).sign() == 1 |
| 89 | + |
73 | 90 | def __add__(self, other): |
74 | 91 | other = check_make_val(other, self.width_exponent, self.width_fraction) |
75 | 92 | v = add(self, other) |
@@ -97,6 +114,10 @@ def __repr__(self): |
97 | 114 | f"<IEEE {self.ieee:.5e}> {self.fp} {self.width_exponent} {self.width_fraction}" |
98 | 115 | ) |
99 | 116 |
|
| 117 | + @property |
| 118 | + def fp_float(self): |
| 119 | + return float(f"{str(self.fp).split(':')[0].split(' ')[1]}") |
| 120 | + |
100 | 121 |
|
101 | 122 | def mul(x: Val, y: Val): |
102 | 123 | assert x.width_exponent == y.width_exponent |
@@ -171,6 +192,9 @@ def numel(self): |
171 | 192 | def reduce_add(self): |
172 | 193 | return ReduceAdd(self.registers.flatten()) |
173 | 194 |
|
| 195 | + def reduce_max(self): |
| 196 | + return ReduceMax(list(self.registers.flatten())) |
| 197 | + |
174 | 198 | @property |
175 | 199 | def val_names_map(self): |
176 | 200 | assert self.input or self.output |
@@ -250,6 +274,9 @@ def __getitem__(self, index): |
250 | 274 | def numel(self): |
251 | 275 | return np.prod(self.shape) |
252 | 276 |
|
| 277 | + def reduce_max(self): |
| 278 | + return ReduceMax(list(self.vals.flatten())) |
| 279 | + |
253 | 280 | @staticmethod |
254 | 281 | def from_global_memref(memref, width_exponent, width_fraction): |
255 | 282 | return GlobalMemRef( |
@@ -302,8 +329,19 @@ def Div(cst, val): |
302 | 329 |
|
303 | 330 |
|
304 | 331 | def main(): |
305 | | - five = Val(4.0, 4, 4) |
306 | | - print(Div(1.0, five)) |
| 332 | + a = Val(2, 4, 4) |
| 333 | + b = Val(1, 4, 4) |
| 334 | + print(a - a) |
| 335 | + print(a, b) |
| 336 | + print(a - b) |
| 337 | + a = flopoco_converter.FPNumber(2, 4, 4) |
| 338 | + b = flopoco_converter.FPNumber(1, 4, 4) |
| 339 | + print(a + a) |
| 340 | + print(a - a) |
| 341 | + print(a - b) |
| 342 | + print(b - b) |
| 343 | + print(a - b) |
| 344 | + print(a + b) |
307 | 345 |
|
308 | 346 |
|
309 | 347 | if __name__ == "__main__": |
|
0 commit comments