@@ -19,7 +19,7 @@ def nnet_input_type() -> Type[FlatIn]:
1919 return FlatIn
2020
2121 def __init__ (self , nnet_input : FlatIn , out_dim : int , q_fix : bool , res_dim : int = 1000 , num_blocks : int = 4 ,
22- batch_norm : bool = False , weight_norm : bool = False , group_norm : int = - 1 , act_fn : str = "RELU" ):
22+ batch_norm : bool = False , weight_norm : bool = False , layer_norm : bool = False , act_fn : str = "RELU" ):
2323 super ().__init__ (nnet_input , out_dim , q_fix )
2424 # one hots
2525 self .one_hots : nn .ModuleList = nn .ModuleList ()
@@ -33,6 +33,10 @@ def __init__(self, nnet_input: FlatIn, out_dim: int, q_fix: bool, res_dim: int =
3333 # res net
3434 self .res_dim : int = res_dim
3535
36+ group_norm : int = - 1
37+ if layer_norm :
38+ group_norm = 1
39+
3640 def res_block_init () -> nn .Module :
3741 return FullyConnectedModel (res_dim , [res_dim ] * 2 , [act_fn , "LINEAR" ],
3842 batch_norms = [batch_norm ] * 2 , weight_norms = [weight_norm ] * 2 ,
@@ -57,7 +61,7 @@ def nnet_input_type() -> Type[FlatInPolicy]:
5761 return FlatInPolicy
5862
5963 def __init__ (self , nnet_input : FlatInPolicy , num_samp : int , kl_weight : float , enc_dim : int = 10 , res_dim : int = 1000 , num_blocks : int = 4 ,
60- batch_norm : bool = False , weight_norm : bool = False , group_norm : int = - 1 , act_fn : str = "RELU" ):
64+ batch_norm : bool = False , weight_norm : bool = False , layer_norm : bool = False , act_fn : str = "RELU" ):
6165 super ().__init__ (nnet_input , num_samp , kl_weight )
6266 # one hots
6367 input_dims , one_hot_depths = self .nnet_input .get_input_info ()
@@ -75,6 +79,10 @@ def __init__(self, nnet_input: FlatInPolicy, num_samp: int, kl_weight: float, en
7579 # res net
7680 self .res_dim : int = res_dim
7781
82+ group_norm : int = - 1
83+ if layer_norm :
84+ group_norm = 1
85+
7886 def res_block_init () -> nn .Module :
7987 return FullyConnectedModel (res_dim , [res_dim ] * 2 , [act_fn , "LINEAR" ],
8088 batch_norms = [batch_norm ] * 2 , weight_norms = [weight_norm ] * 2 ,
@@ -118,6 +126,7 @@ def parse(self, args_str: str) -> Dict[str, Any]:
118126 blocks_re = re .search (r"^(\S+)B$" , args_str_i )
119127 bn_re = re .search (r"^bn$" , args_str_i )
120128 wn_re = re .search (r"^wn$" , args_str_i )
129+ ln_re = re .search (r"^ln$" , args_str_i )
121130 if hidden_re is not None :
122131 kwargs ["res_dim" ] = int (hidden_re .group (1 ))
123132 elif blocks_re is not None :
@@ -126,13 +135,15 @@ def parse(self, args_str: str) -> Dict[str, Any]:
126135 kwargs ["batch_norm" ] = True
127136 elif wn_re is not None :
128137 kwargs ["weight_norm" ] = True
138+ elif ln_re is not None :
139+ kwargs ["layer_norm" ] = True
129140 else :
130141 raise ValueError (f"Unexpected argument { args_str_i !r} " )
131142 return kwargs
132143
133144 def help (self ) -> str :
134145 return ("Arguments are delimited by '_' and can be in any order.\n <num>H (number of hidden units), "
135- "<num>B (number of blocks), bn (batch_norm), wn (weight_norm).\n "
146+ "<num>B (number of blocks), bn (batch_norm), wn (weight_norm), ln (layer_norm) .\n "
136147 "E.g. resnet_fc.1000H_4B_bn" )
137148
138149
@@ -149,6 +160,7 @@ def parse(self, args_str: str) -> Dict[str, Any]:
149160 kl_re = re .search (r"^(\S+)KL$" , args_str_i )
150161 bn_re = re .search (r"^bn$" , args_str_i )
151162 wn_re = re .search (r"^wn$" , args_str_i )
163+ ln_re = re .search (r"^ln$" , args_str_i )
152164 if hidden_re is not None :
153165 kwargs ["res_dim" ] = int (hidden_re .group (1 ))
154166 elif blocks_re is not None :
@@ -157,6 +169,8 @@ def parse(self, args_str: str) -> Dict[str, Any]:
157169 kwargs ["batch_norm" ] = True
158170 elif wn_re is not None :
159171 kwargs ["weight_norm" ] = True
172+ elif ln_re is not None :
173+ kwargs ["layer_norm" ] = True
160174 elif enc_dim_re is not None :
161175 kwargs ["enc_dim" ] = int (enc_dim_re .group (1 ))
162176 elif kl_re is not None :
@@ -167,5 +181,5 @@ def parse(self, args_str: str) -> Dict[str, Any]:
167181
168182 def help (self ) -> str :
169183 return ("Arguments are delimited by '_' and can be in any order.\n <num>H (number of hidden units), "
170- "<num>B (number of blocks), <enc_dim>E (encoding dimensionality), bn (batch_norm), wn (weight_norm).\n "
184+ "<num>B (number of blocks), <enc_dim>E (encoding dimensionality), bn (batch_norm), wn (weight_norm), ln (layer_norm) .\n "
171185 "E.g. resnet_fc.1000H_4B_10E_bn" )
0 commit comments