Skip to content

Commit 5468742

Browse files
authored
Merge pull request #28 from vinhdc10998/Vinhdev
Fix gru and add multi model
2 parents 9352658 + 2abd3c3 commit 5468742

3 files changed

Lines changed: 73 additions & 4 deletions

File tree

model/gru_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(self, model_config, device, type_model):
1717
self.type_model = type_model
1818
self.device = device
1919

20-
self._features = torch.tensor(np.load(f'model/features/region_{self.region}_model_features.npy')).to(self.device)
20+
self.linear = nn.Linear(self.input_dim, self.feature_size, bias=True)
2121

2222
self.gru = nn.ModuleList(self._create_gru_cell(
2323
self.feature_size,
@@ -53,7 +53,7 @@ def forward(self, x):
5353
'''
5454
batch_size = x.shape[0]
5555
_input = torch.swapaxes(x, 0, 1)
56-
gru_inputs = torch.matmul(_input, self._features)
56+
gru_inputs = self.linear(_input)
5757
outputs, _ = self._compute_gru(self.gru, gru_inputs, batch_size)
5858

5959
logit_list = []

model/multi_model.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import torch
2+
from torch import nn
3+
from .gru_model import GRUModel
4+
from torch.nn import functional as F
5+
6+
TYPE_MODEL = ['Hybrid']
7+
class MultiModel(nn.Module):
8+
def __init__(self,model_config, device, type_model=None):
9+
super(MultiModel,self).__init__()
10+
assert type_model in TYPE_MODEL
11+
self.num_classes = model_config['num_classes']
12+
self.num_outputs = model_config['num_outputs']
13+
self.type_model = type_model
14+
15+
self.lowerModel = GRUModel(model_config, device, type_model='Lower')
16+
self.higherModel = GRUModel(model_config, device, type_model='Higher')
17+
18+
self.lowerModel.load_state_dict(self.get_gru_layer(model_config['lower_path'], device))
19+
self.higherModel.load_state_dict(self.get_gru_layer(model_config['higher_path'], device))
20+
21+
for param in self.lowerModel.parameters():
22+
param.requires_grad = False
23+
for param in self.higherModel.parameters():
24+
param.requires_grad = False
25+
26+
self.linear = nn.ModuleList([nn.Linear(self.num_classes*2, self.num_classes) for _ in range(self.num_outputs)])
27+
28+
29+
@staticmethod
30+
def get_gru_layer(path, device):
31+
tmp = torch.load(path, map_location=torch.device(device))
32+
a = {}
33+
for i in tmp:
34+
if 'gru' in i:
35+
k = i[9:]
36+
a[k] = tmp[i]
37+
return a
38+
39+
40+
41+
def forward(self, input_):
42+
logits_1 = self.higherModel(input_)
43+
logits_2 = self.lowerModel(input_)
44+
logits = torch.cat((torch.stack(logits_1), torch.stack(logits_2)), dim=-1)
45+
logit_list = []
46+
for index, logit in enumerate(logits):
47+
logit_tmp = self.linear[index](logit)
48+
logit_list.append(logit_tmp)
49+
50+
logit = torch.cat(logit_list, dim=0)
51+
pred = F.softmax(torch.stack(logit_list), dim=-1)
52+
return logit, pred
53+
54+
55+

train.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from model.multi_model import MultiModel
12
import os
23
import json
34
import torch
@@ -16,14 +17,27 @@ def run(dataloader, model_config, args, region):
1617
type_model = args.model_type
1718
lr = args.learning_rate
1819
epochs = args.epochs
19-
gamma = args.gamma if type_model == 'Higher' else -args.gamma
20+
2021
output_model_dir = args.output_model_dir
2122
train_loader = dataloader['train']
2223
val_loader = dataloader['val']
2324
test_loader = dataloader['test']
2425

2526
#Init Model
26-
model = SingleModel(model_config, device, type_model=type_model).to(device)
27+
if type_model in ['Lower', 'Higher']:
28+
gamma = args.gamma if type_model == 'Higher' else -args.gamma
29+
model = SingleModel(model_config, device, type_model=type_model).to(device)
30+
31+
elif type_model in ['Hybrid']:
32+
gamma = 0
33+
if args.best_model:
34+
model_config['lower_path'] = os.path.join(args.model_dir, f'Best_Lower_region_{region}.pt')
35+
model_config['higher_path'] = os.path.join(args.model_dir, f'Best_Higher_region_{region}.pt')
36+
else:
37+
model_config['lower_path'] = os.path.join(args.model_dir, f'Lower_region_{region}.pt')
38+
model_config['higher_path'] = os.path.join(args.model_dir, f'Higher_region_{region}.pt')
39+
model = MultiModel(model_config, device, type_model=type_model).to(device)
40+
2741
def count_parameters(model):
2842
return sum(p.numel() for p in model.parameters() if p.requires_grad)
2943
print("Number of learnable parameters:",count_parameters(model))

0 commit comments

Comments
 (0)