-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathutils.py
More file actions
59 lines (43 loc) · 1.31 KB
/
Copy pathutils.py
File metadata and controls
59 lines (43 loc) · 1.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
import random
import numpy as np
import pandas as pd
import torch
from config import CFG
def seed_everything(seed=1234):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # NOTE: setting PYTHONHASHSEED at runtime only affects child processes
    # started afterwards; this interpreter's str-hash seed was fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current device. Safe to call when
    # CUDA is unavailable (PyTorch silently ignores it in that case).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may choose different algorithms from run to run;
    # disable it so algorithm selection is reproducible as well.
    torch.backends.cudnn.benchmark = False
def generate_square_subsequent_mask(sz, device=None):
    """Return a causal (look-ahead) float mask of shape (sz, sz).

    Entry [i, j] is 0.0 when j <= i and -inf when j > i. This is the additive
    float-mask form accepted by torch.nn.Transformer's mask arguments, so
    position i cannot attend to any later position j.

    Args:
        sz: side length of the square mask (the sequence length).
        device: device to allocate the mask on. Defaults to ``CFG.device``
            (the previously hard-coded behavior).

    Returns:
        A float32 tensor of shape (sz, sz).
    """
    if device is None:
        device = CFG.device
    # Fill with -inf, then keep only the strict upper triangle (diagonal=1);
    # triu zeroes the diagonal and everything below it.
    return torch.triu(
        torch.full((sz, sz), float("-inf"), dtype=torch.float, device=device),
        diagonal=1,
    )
def create_mask(tgt):
    """Build the causal mask and the padding mask for a batch of targets.

    Args:
        tgt: target token tensor of shape (N, L); dim 1 is taken as the
            sequence length (N presumably the batch size — confirm at callers).

    Returns:
        A ``(tgt_mask, tgt_padding_mask)`` tuple:
        ``tgt_mask`` is the (L, L) mask from ``generate_square_subsequent_mask``;
        ``tgt_padding_mask`` has ``tgt``'s shape and is True where a token
        equals ``CFG.pad_idx``.
    """
    causal_mask = generate_square_subsequent_mask(tgt.shape[1])
    return causal_mask, tgt == CFG.pad_idx
class AvgMeter:
    """Running count-weighted average of a scalar metric.

    Attributes:
        name: label shown by ``repr``.
        avg: current weighted mean ``sum / count``; 0 until weight is added.
        sum: accumulated ``val * count`` over all updates.
        count: total weight accumulated so far.
    """

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Reset avg, sum and count to zero."""
        self.avg, self.sum, self.count = [0] * 3

    def update(self, val, count=1):
        """Accumulate ``val`` with weight ``count`` and refresh ``avg``.

        Args:
            val: value to fold into the average.
            count: weight of ``val`` (e.g. how many samples it averages over).
        """
        self.count += count
        self.sum += val * count
        # Avoid ZeroDivisionError when no weight has been accumulated yet
        # (e.g. update(..., count=0) on a fresh meter); avg is left unchanged.
        if self.count:
            self.avg = self.sum / self.count

    def __repr__(self):
        text = f"{self.name}: {self.avg:.4f}"
        return text
def get_lr(optimizer):
    """Return the ``"lr"`` of the optimizer's first parameter group.

    Args:
        optimizer: object exposing ``param_groups`` (e.g. a
            ``torch.optim.Optimizer``) whose groups are mappings with ``"lr"``.

    Returns:
        The first group's learning rate, or ``None`` if there are no groups.
    """
    group_lrs = (group["lr"] for group in optimizer.param_groups)
    return next(group_lrs, None)