Skip to content

Commit 8bd97d4

Browse files
Forest AgostinelliForest Agostinelli
authored andcommitted
layer_norm [no ci]
1 parent 9a667b3 commit 8bd97d4

2 files changed

Lines changed: 19 additions & 4 deletions

File tree

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
* Add special case of training heuristic with random policy when using pathfinding algorithm that uses a policy and not training a policy
2525
* Add timing of getting supervised data to supervised pathfinding
2626
* Add abstract method for loss info for training policy
27+
* Add layer norm to resnet_fc
2728

2829
## 0.2.1
2930
* Consolidate search: Beam search -> special cases: greedy_policy, graph search -> special cases: batch weighted A* search, batch weighted Q* search

deepxube/heuristics/resnet_fc.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def nnet_input_type() -> Type[FlatIn]:
1919
return FlatIn
2020

2121
def __init__(self, nnet_input: FlatIn, out_dim: int, q_fix: bool, res_dim: int = 1000, num_blocks: int = 4,
22-
batch_norm: bool = False, weight_norm: bool = False, group_norm: int = -1, act_fn: str = "RELU"):
22+
batch_norm: bool = False, weight_norm: bool = False, layer_norm: bool = False, act_fn: str = "RELU"):
2323
super().__init__(nnet_input, out_dim, q_fix)
2424
# one hots
2525
self.one_hots: nn.ModuleList = nn.ModuleList()
@@ -33,6 +33,10 @@ def __init__(self, nnet_input: FlatIn, out_dim: int, q_fix: bool, res_dim: int =
3333
# res net
3434
self.res_dim: int = res_dim
3535

36+
group_norm: int = -1
37+
if layer_norm:
38+
group_norm = 1
39+
3640
def res_block_init() -> nn.Module:
3741
return FullyConnectedModel(res_dim, [res_dim] * 2, [act_fn, "LINEAR"],
3842
batch_norms=[batch_norm] * 2, weight_norms=[weight_norm] * 2,
@@ -57,7 +61,7 @@ def nnet_input_type() -> Type[FlatInPolicy]:
5761
return FlatInPolicy
5862

5963
def __init__(self, nnet_input: FlatInPolicy, num_samp: int, kl_weight: float, enc_dim: int = 10, res_dim: int = 1000, num_blocks: int = 4,
60-
batch_norm: bool = False, weight_norm: bool = False, group_norm: int = -1, act_fn: str = "RELU"):
64+
batch_norm: bool = False, weight_norm: bool = False, layer_norm: bool = False, act_fn: str = "RELU"):
6165
super().__init__(nnet_input, num_samp, kl_weight)
6266
# one hots
6367
input_dims, one_hot_depths = self.nnet_input.get_input_info()
@@ -75,6 +79,10 @@ def __init__(self, nnet_input: FlatInPolicy, num_samp: int, kl_weight: float, en
7579
# res net
7680
self.res_dim: int = res_dim
7781

82+
group_norm: int = -1
83+
if layer_norm:
84+
group_norm = 1
85+
7886
def res_block_init() -> nn.Module:
7987
return FullyConnectedModel(res_dim, [res_dim] * 2, [act_fn, "LINEAR"],
8088
batch_norms=[batch_norm] * 2, weight_norms=[weight_norm] * 2,
@@ -118,6 +126,7 @@ def parse(self, args_str: str) -> Dict[str, Any]:
118126
blocks_re = re.search(r"^(\S+)B$", args_str_i)
119127
bn_re = re.search(r"^bn$", args_str_i)
120128
wn_re = re.search(r"^wn$", args_str_i)
129+
ln_re = re.search(r"^ln$", args_str_i)
121130
if hidden_re is not None:
122131
kwargs["res_dim"] = int(hidden_re.group(1))
123132
elif blocks_re is not None:
@@ -126,13 +135,15 @@ def parse(self, args_str: str) -> Dict[str, Any]:
126135
kwargs["batch_norm"] = True
127136
elif wn_re is not None:
128137
kwargs["weight_norm"] = True
138+
elif ln_re is not None:
139+
kwargs["layer_norm"] = True
129140
else:
130141
raise ValueError(f"Unexpected argument {args_str_i!r}")
131142
return kwargs
132143

133144
def help(self) -> str:
134145
return ("Arguments are delimited by '_' and can be in any order.\n<num>H (number of hidden units), "
135-
"<num>B (number of blocks), bn (batch_norm), wn (weight_norm).\n"
146+
"<num>B (number of blocks), bn (batch_norm), wn (weight_norm), ln (layer_norm).\n"
136147
"E.g. resnet_fc.1000H_4B_bn")
137148

138149

@@ -149,6 +160,7 @@ def parse(self, args_str: str) -> Dict[str, Any]:
149160
kl_re = re.search(r"^(\S+)KL$", args_str_i)
150161
bn_re = re.search(r"^bn$", args_str_i)
151162
wn_re = re.search(r"^wn$", args_str_i)
163+
ln_re = re.search(r"^ln$", args_str_i)
152164
if hidden_re is not None:
153165
kwargs["res_dim"] = int(hidden_re.group(1))
154166
elif blocks_re is not None:
@@ -157,6 +169,8 @@ def parse(self, args_str: str) -> Dict[str, Any]:
157169
kwargs["batch_norm"] = True
158170
elif wn_re is not None:
159171
kwargs["weight_norm"] = True
172+
elif ln_re is not None:
173+
kwargs["layer_norm"] = True
160174
elif enc_dim_re is not None:
161175
kwargs["enc_dim"] = int(enc_dim_re.group(1))
162176
elif kl_re is not None:
@@ -167,5 +181,5 @@ def parse(self, args_str: str) -> Dict[str, Any]:
167181

168182
def help(self) -> str:
169183
return ("Arguments are delimited by '_' and can be in any order.\n<num>H (number of hidden units), "
170-
"<num>B (number of blocks), <enc_dim>E (encoding dimensionality), bn (batch_norm), wn (weight_norm).\n"
184+
"<num>B (number of blocks), <enc_dim>E (encoding dimensionality), bn (batch_norm), wn (weight_norm), ln (layer_norm).\n"
171185
"E.g. resnet_fc.1000H_4B_10E_bn")

0 commit comments

Comments
 (0)