1+ import torch
2+ from torch import nn
3+ from .gru_model import GRUModel
4+ from torch .nn import functional as F
5+
6+ TYPE_MODEL = ['Hybrid' ]
7+ class MultiModel (nn .Module ):
8+ def __init__ (self ,model_config , device , type_model = None ):
9+ super (MultiModel ,self ).__init__ ()
10+ assert type_model in TYPE_MODEL
11+ self .num_classes = model_config ['num_classes' ]
12+ self .num_outputs = model_config ['num_outputs' ]
13+ self .type_model = type_model
14+
15+ self .lowerModel = GRUModel (model_config , device , type_model = 'Lower' )
16+ self .higherModel = GRUModel (model_config , device , type_model = 'Higher' )
17+
18+ self .lowerModel .load_state_dict (self .get_gru_layer (model_config ['lower_path' ], device ))
19+ self .higherModel .load_state_dict (self .get_gru_layer (model_config ['higher_path' ], device ))
20+
21+ for param in self .lowerModel .parameters ():
22+ param .requires_grad = False
23+ for param in self .higherModel .parameters ():
24+ param .requires_grad = False
25+
26+ self .linear = nn .ModuleList ([nn .Linear (self .num_classes * 2 , self .num_classes ) for _ in range (self .num_outputs )])
27+
28+
29+ @staticmethod
30+ def get_gru_layer (path , device ):
31+ tmp = torch .load (path , map_location = torch .device (device ))
32+ a = {}
33+ for i in tmp :
34+ if 'gru' in i :
35+ k = i [9 :]
36+ a [k ] = tmp [i ]
37+ return a
38+
39+
40+
41+ def forward (self , input_ ):
42+ logits_1 = self .higherModel (input_ )
43+ logits_2 = self .lowerModel (input_ )
44+ logits = torch .cat ((torch .stack (logits_1 ), torch .stack (logits_2 )), dim = - 1 )
45+ logit_list = []
46+ for index , logit in enumerate (logits ):
47+ logit_tmp = self .linear [index ](logit )
48+ logit_list .append (logit_tmp )
49+
50+ logit = torch .cat (logit_list , dim = 0 )
51+ pred = F .softmax (torch .stack (logit_list ), dim = - 1 )
52+ return logit , pred
53+
54+
55+
0 commit comments