-
Notifications
You must be signed in to change notification settings - Fork 28
Expand file tree
/
Copy pathbenchmarking_utils.py
More file actions
69 lines (60 loc) · 2.25 KB
/
Copy pathbenchmarking_utils.py
File metadata and controls
69 lines (60 loc) · 2.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import argparse
import time
import pandas
import torch
from tools.utils import as_col_major
# Turn autograd off at import time: this module only calls forward kernels,
# so no tensor it creates or computes needs to record a graph.
# NOTE(review): this is a module-level side effect -- grad mode is also left
# disabled (for the importing thread) in whatever code imports this module.
torch.set_grad_enabled(False)
@torch.no_grad()
def run_benchmark(
    *,
    perf_func,
    a: torch.Tensor,
    b: torch.Tensor,
    b_col_major: torch.Tensor,
    out: torch.Tensor,
    warmup: int = 0,
    iters: int = 1,
):
    """Time ``perf_func`` on pre-allocated operands and return its output buffer.

    ``perf_func`` is called with one of two conventions, chosen by its
    ``__name__``:

    * ``"matmul"`` (e.g. ``torch.matmul``): ``perf_func(a, b, out=out)``
    * anything else: ``perf_func(a, b, b_col_major, out)``; the callee is
      expected to write its result into ``out`` in place.

    Args:
        perf_func: Callable to benchmark.
        a: Left operand.
        b: Right operand.
        b_col_major: Column-major copy of ``b`` for kernels that consume it.
        out: Output buffer; zero-filled right before the timed call(s).
        warmup: Untimed calls made before timing (default 0, i.e. a cold run).
        iters: Timed calls; the mean time per call is reported (default 1).

    Returns:
        ``(out, elapsed_time_ms)``, where ``elapsed_time_ms`` is the mean
        wall-clock time per timed call, in milliseconds.

    Raises:
        ValueError: If ``warmup`` is negative or ``iters`` is less than 1.
    """
    if warmup < 0:
        raise ValueError(f"warmup must be >= 0, got {warmup}")
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    if perf_func.__name__ == "matmul":
        def call():
            perf_func(a, b, out=out)
    else:
        def call():
            perf_func(a, b, b_col_major, out)

    # CUDA launches are asynchronous: the host clock only covers the kernels
    # if the device is drained before and after the timed region. Query
    # availability once so the check itself stays out of the timed region,
    # and so CPU-only hosts can still run the function.
    use_cuda = torch.cuda.is_available()

    for _ in range(warmup):
        call()

    out.fill_(0)
    if use_cuda:
        torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution; time.time() can be too
    # coarse (reporting 0 ms and breaking TFLOPS math) or jump with clock changes.
    t0 = time.perf_counter()
    for _ in range(iters):
        call()
    if use_cuda:
        torch.cuda.synchronize()
    t1 = time.perf_counter()

    elapsed_time_ms = (t1 - t0) * 1000 / iters
    return out, elapsed_time_ms
def run_all_perf_funcs_once(*, perf_func_list, m, n, k, acc_precise, device_type, padding_m, padding_k, padding_n):
    """Benchmark each function in ``perf_func_list`` once on shared random fp16 inputs.

    A single ``(m, k) x (k, n)`` half-precision problem is generated on CUDA,
    and every function receives its own copy of it. The function whose
    ``__name__`` is ``f"cuda_l2_{device_type}_{acc_precise}"`` instead gets
    copies zero-padded by ``padding_m`` / ``padding_k`` / ``padding_n``
    (presumably to satisfy that kernel's tiling -- TODO confirm); the others
    get unpadded copies. All inputs are built before any timing starts, so
    allocation never falls inside a timed region.

    Returns:
        dict mapping each function's ``__name__`` to its TFLOPS (computed from
        the unpadded ``2 * m * n * k`` FLOP count), and ``__name__ + "_ms"``
        to its elapsed time in milliseconds. TFLOPS is ``nan`` when the
        measured time is zero (timer could not resolve the call) instead of
        raising ``ZeroDivisionError``.
    """
    a = torch.randn((m, k), dtype=torch.half, device="cuda")
    b = torch.randn((k, n), dtype=torch.half, device="cuda")
    padded_name = f"cuda_l2_{device_type}_{acc_precise}"

    def _padded_copy(src, rows, cols):
        # Zero padding leaves the [:m, :n] block of the product unchanged.
        dst = torch.zeros((rows, cols), dtype=torch.half, device="cuda")
        dst[: src.shape[0], : src.shape[1]] = src  # slice assignment copies src
        return dst

    # Per-function (a, b, b_col_major, out) operands, built up front.
    inputs = []
    for perf_func in perf_func_list:
        if perf_func.__name__ == padded_name:
            a_use = _padded_copy(a, m + padding_m, k + padding_k)
            b_use = _padded_copy(b, k + padding_k, n + padding_n)
            c_shape = (m + padding_m, n + padding_n)
        else:
            # Separate clones so no function can observe another's mutations.
            a_use = a.clone()
            b_use = b.clone()
            c_shape = (m, n)
        b_col_major_use = as_col_major(b_use)
        c_use = torch.randn(c_shape, dtype=torch.half, device="cuda")
        inputs.append((a_use, b_use, b_col_major_use, c_use))
    torch.cuda.synchronize()

    flop = 2 * m * n * k
    record = {}
    for perf_func, (a_use, b_use, b_col_major_use, c_use) in zip(perf_func_list, inputs):
        _, elapsed_time_ms = run_benchmark(
            perf_func=perf_func, a=a_use, b=b_use, b_col_major=b_col_major_use, out=c_use,
        )
        func_name = perf_func.__name__
        if elapsed_time_ms > 0:
            record[func_name] = flop * 1e-12 * 1000 / elapsed_time_ms
        else:
            record[func_name] = float("nan")
        record[func_name + "_ms"] = elapsed_time_ms
    return record