-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrainer.py
More file actions
29 lines (27 loc) · 1.06 KB
/
Copy pathtrainer.py
File metadata and controls
29 lines (27 loc) · 1.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import graph
import os
import torch
import torch.nn as nn
class Trainer(nn.Module):
    """Build an ensemble of ``graph.Net`` models, loading or training each one.

    Calling the module (``forward``) produces one network per index in
    ``range(params["num_of_nets"])``. For each index, the checkpoint
    ``<model_dir>/checkpoint_<i>.pth`` is loaded if it exists; otherwise the
    network is trained with its ``train_model()`` method and its
    ``state_dict`` is saved to that path.

    Args:
        params: Configuration mapping. It must contain ``"num_of_nets"`` and
            is passed unchanged to ``graph.Net``.
        model_dir: Directory holding the per-network checkpoints. It is
            created, including missing parent directories, on first use.
        seed: Base seed. Network ``i`` is constructed with ``seed + i``.
    """

    def __init__(self, params, model_dir: str = "models", seed: int = 42):
        super().__init__()
        self.model_dir = model_dir
        self.seed = seed
        self.params = params
        # NOTE(review): this is a plain list, so the nets are not registered as
        # submodules. Calling .to(), .eval() or state_dict() on the Trainer
        # does not reach them. It stays a list to keep forward's return type.
        self.models = []

    def _checkpoint_path(self, index: int) -> str:
        """Return the checkpoint file path for network ``index``."""
        return os.path.join(self.model_dir, f"checkpoint_{index}.pth")

    def _build_net(self, index: int):
        """Construct network ``index`` and add a self-loop to every graph node."""
        net = graph.Net(self.params, self.seed + index)
        # Self-loops change the graph structure. That structure is presumably
        # not captured by state_dict (assumes net.graph is a DGL-style graph;
        # confirm). So the loops are added on both the train and the load
        # paths, and a restored net sees the same graph it was trained on.
        nodes = net.graph.nodes()
        net.graph.add_edges(nodes, nodes)
        return net

    def forward(self):
        """Load or train every network in the ensemble.

        A network whose checkpoint already exists is loaded from it. Any
        network without a checkpoint is trained and then saved, so a
        partially populated ``model_dir`` is completed rather than
        causing a crash.

        Returns:
            A new list of ``params["num_of_nets"]`` networks, also stored on
            ``self.models``. Repeated calls replace that list instead of
            appending to it.
        """
        os.makedirs(self.model_dir, exist_ok=True)
        models = []
        for i in range(self.params["num_of_nets"]):
            net = self._build_net(i)
            path = self._checkpoint_path(i)
            if os.path.exists(path):
                # map_location="cpu" lets GPU-saved checkpoints load on
                # CPU-only hosts. load_state_dict then copies the tensors
                # into the net's existing parameters on their own devices.
                # torch.load unpickles its input, so only load trusted files.
                net.load_state_dict(torch.load(path, map_location="cpu"))
            else:
                net.train_model()
                torch.save(net.state_dict(), path)
            models.append(net)
        self.models = models
        return self.models