Skip to content

Commit 4efc9f5

Browse files
Forest AgostinelliForest Agostinelli
authored andcommitted
Separate timings for HER and rb [no ci]
1 parent f90945a commit 4efc9f5

6 files changed

Lines changed: 29 additions & 29 deletions

File tree

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
* Add layer norm to resnet_fc
2828
* Vectorize expand
2929
* Make policy at update_num=0 sampled from Domain's sample_state_action
30+
* Separate timings for HER and rb
3031

3132
## 0.2.1
3233
* Consolidate search: Beam search -> special cases: greedy_policy, graph search -> special cases: batch weighted A* search, batch weighted Q* search

deepxube/_solve.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def parse_solve(parser: ArgumentParser) -> None:
6060
parser.add_argument('--pathfind', type=str, required=True, help="Pathfinding algorithm and arguments.")
6161
parser.add_argument('--file', type=str, required=True, help="File containing problem instances to solve")
6262

63-
parser.add_argument('--time_limit', type=float, default=-1.0, help="A time limit for search. Default is -1, which means infinite.")
63+
parser.add_argument('--time_limit', type=float, default=-1.0, help="A time limit (in seconds) for search. Default is -1, which means infinite.")
6464
parser.add_argument('--max_itrs', type=int, default=None, help="Maximum number of search iterations. None for infinite.")
6565

6666
parser.add_argument('--results', type=str, required=True, help="Directory to save results. Saves results after every instance.")
@@ -124,10 +124,7 @@ def solve_cli(args: argparse.Namespace) -> None:
124124
# heur and policy fn
125125
heur_fn: Optional[HeurFn] = get_heur_fn(domain, domain_name, args.heur, args.heur_file, args.heur_type, args.nnet_batch_size)
126126
policy_fn: Optional[PolicyFn] = get_policy_fn(domain, domain_name, args.policy, args.policy_file, args.policy_samp, args.nnet_batch_size)
127-
print(domain)
128127
pathfind_functions: Any = get_pathfind_functions(get_pathfind_name_kwargs(args.pathfind)[0], heur_fn, policy_fn)
129-
pathfind: PathFind = get_pathfind_from_arg(domain, pathfind_functions, args.pathfind)[0]
130-
print(pathfind)
131128

132129
# get data
133130
data: Dict = pickle.load(open(args.file, "rb"))
@@ -152,6 +149,11 @@ def solve_cli(args: argparse.Namespace) -> None:
152149
if not args.debug:
153150
sys.stdout = data_utils.Logger(output_file, "w")
154151

152+
# print info
153+
print(domain)
154+
pathfind: PathFind = get_pathfind_from_arg(domain, pathfind_functions, args.pathfind)[0]
155+
print(pathfind)
156+
155157
start_idx: int
156158
if args.start_idx is not None:
157159
start_idx = args.start_idx

deepxube/base/updater.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -511,12 +511,12 @@ def _get_her_goals(self, instances: List[Inst], times: Times) -> Tuple[List[Inst
511511
state_deepest = node.state
512512
states_deepest.append(state_deepest)
513513

514-
times.record_time("her_node_deepest", time.time() - start_time)
514+
times.record_time("node_deepest", time.time() - start_time, path=["HER"])
515515

516516
# relabel
517517
start_time = time.time()
518518
goals_relabel = self.domain.sample_goal_from_state(states_start, states_deepest)
519-
times.record_time("her_relabel", time.time() - start_time)
519+
times.record_time("relabel", time.time() - start_time, path=["HER"])
520520

521521
return instances_goalkeep + instances_relabel, goals_goalkeep + goals_relabel
522522

deepxube/updaters/updater_policy_rl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,13 @@ def _init_replay_buffer(self, max_size: int) -> None:
7070
def _rb_add(self, states: List[State], goals: List[Goal], actions: List[Action], times: Times) -> None:
7171
start_time = time.time()
7272
self.rb.add(list(zip(states, goals, actions, strict=True)))
73-
times.record_time("rb_add", time.time() - start_time)
73+
times.record_time("add", time.time() - start_time, path=["replay"])
7474

7575
def _sample_rb(self, num: int, times: Times) -> Tuple[List[State], List[Goal], List[Action]]:
7676
# sample from replay buffer
7777
start_time = time.time()
7878
states, goals, actions = self.rb.sample(num)
79-
times.record_time("rb_samp", time.time() - start_time)
79+
times.record_time("samp", time.time() - start_time, path=["replay"])
8080

8181
return states, goals, actions
8282

@@ -155,7 +155,7 @@ def _get_instance_data_rb(self, instances: List[Instance], times: Times) -> List
155155
goals_her.extend([goal_her] * len(states_inst))
156156
actions_her.extend([edge.action for edge in instance.get_edges_popped()])
157157

158-
times.record_time("data_her", time.time() - start_time)
158+
times.record_time("data", time.time() - start_time, path=["HER"])
159159

160160
# add to replay buffer
161161
self._rb_add(states_her, goals_her, actions_her, times)

deepxube/updaters/updater_q_rl.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ def __init__(self, domain: D, pathfind_arg: str, up_args: UpArgs):
5656
def _step(self, pathfind: PathFindSetHeurQ, times: Times) -> None:
5757
_pathfind_q_step(pathfind)
5858

59-
def _q_learning_target(self, goals: List[Goal], is_solved_l: List[bool], tcs: List[float], states_next: List[State], times: Times) -> List[float]:
60-
start_time = time.time()
59+
def _q_learning_target(self, goals: List[Goal], is_solved_l: List[bool], tcs: List[float], states_next: List[State]) -> List[float]:
6160
# min cost-to-go for next state
6261
actions_next: List[List[Action]] = self.get_pathfind().get_state_actions(states_next, goals)
6362
qvals_next_l: List[List[float]] = self._get_targ_heur_fn()(states_next, goals, actions_next)
@@ -67,8 +66,6 @@ def _q_learning_target(self, goals: List[Goal], is_solved_l: List[bool], tcs: Li
6766
ctg_backups: NDArray = np.array(tcs) + np.array(qvals_next_min)
6867
ctg_backups = ctg_backups * np.logical_not(np.array(is_solved_l))
6968

70-
times.record_time("qlearn_targ", time.time() - start_time)
71-
7269
return cast(List[float], ctg_backups.tolist())
7370

7471
def _inputs_ctgs_to_np(self, states: List[State], goals: List[Goal], actions: List[Action], ctgs_backup: List[float], times: Times) -> List[NDArray]:
@@ -85,16 +82,18 @@ def _rb_add(self, states: List[State], goals: List[Goal], is_solved_l: List[bool
8582
times: Times) -> None:
8683
start_time = time.time()
8784
self.rb.add(list(zip(states, goals, is_solved_l, actions, tcs, states_next, strict=True)))
88-
times.record_time("rb_add", time.time() - start_time)
85+
times.record_time("add", time.time() - start_time, path=["replay"])
8986

9087
def _sample_rb_qlearn_target(self, num: int, times: Times) -> Tuple[List[State], List[Goal], List[Action], List[float]]:
9188
# sample from replay buffer
9289
start_time = time.time()
9390
states, goals, is_solved_l, actions, tcs, states_next = self.rb.sample(num)
94-
times.record_time("rb_samp", time.time() - start_time)
91+
times.record_time("samp", time.time() - start_time, path=["replay"])
9592

9693
# value iteration update
97-
ctgs_backup: List[float] = self._q_learning_target(goals, is_solved_l, tcs, states_next, times)
94+
start_time = time.time()
95+
ctgs_backup: List[float] = self._q_learning_target(goals, is_solved_l, tcs, states_next)
96+
times.record_time("qlearn_targ", time.time() - start_time, path=["replay"])
9897

9998
return states, goals, actions, ctgs_backup
10099

@@ -201,12 +200,12 @@ def _get_instance_data_rb(self, instances: List[InstanceEdge], times: Times) ->
201200
tcs_her.append(tc)
202201
states_next_her.append(node_next.state)
203202

204-
times.record_time("data_her", time.time() - start_time)
203+
times.record_time("data", time.time() - start_time, path=["HER"])
205204

206205
# is solved
207206
start_time = time.time()
208207
is_solved_l_her: List[bool] = self.domain.is_solved(states_her, goals_her)
209-
times.record_time("is_solved_her", time.time() - start_time)
208+
times.record_time("is_solved", time.time() - start_time, path=["HER"])
210209

211210
# add to replay buffer
212211
self._rb_add(states_her, goals_her, is_solved_l_her, actions_her, tcs_her, states_next_her, times)

deepxube/updaters/updater_v_rl.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,7 @@ def __init__(self, domain: D, pathfind_arg: str, up_args: UpArgs):
5050
def _step(self, pathfind: PathFindSetHeurV, times: Times) -> None:
5151
_pathfind_v_step(pathfind)
5252

53-
def _value_iteration_target(self, goals: List[Goal], is_solved_l: List[bool], tcs_l: List[List[float]], states_exp: List[List[State]],
54-
times: Times) -> List[float]:
55-
start_time = time.time()
53+
def _value_iteration_target(self, goals: List[Goal], is_solved_l: List[bool], tcs_l: List[List[float]], states_exp: List[List[State]]) -> List[float]:
5654
# get cost-to-go of expanded states
5755
states_exp_flat, split_idxs = misc_utils.flatten(states_exp)
5856
goals_flat: List[Goal] = []
@@ -67,8 +65,6 @@ def _value_iteration_target(self, goals: List[Goal], is_solved_l: List[bool], tc
6765
ctgs_backup = np.array([np.min(x) for x in ctg_next_p_tc_l]) * np.logical_not(is_solved_l)
6866
ctgs_backup_l: List[float] = cast(List[float], ctgs_backup.tolist())
6967

70-
times.record_time("vi_targ", time.time() - start_time)
71-
7268
return ctgs_backup_l
7369

7470
def _inputs_ctgs_to_np(self, states: List[State], goals: List[Goal], ctgs_backup: List[float], times: Times) -> List[NDArray]:
@@ -85,21 +81,23 @@ def _init_replay_buffer(self, max_size: int) -> None:
8581
def _rb_add(self, states: List[State], goals: List[Goal], is_solved_l: List[bool], times: Times) -> None:
8682
start_time = time.time()
8783
self.rb.add(list(zip(states, goals, is_solved_l, strict=True)))
88-
times.record_time("rb_add", time.time() - start_time)
84+
times.record_time("add", time.time() - start_time, path=["replay"])
8985

9086
def _sample_rb_vi_target(self, num: int, times: Times) -> Tuple[List[State], List[Goal], List[float]]:
9187
# sample from replay buffer
9288
start_time = time.time()
9389
states, goals, is_solved_l = self.rb.sample(num)
94-
times.record_time("rb_samp", time.time() - start_time)
90+
times.record_time("samp", time.time() - start_time, path=["replay"])
9591

9692
# expand states
9793
start_time = time.time()
9894
states_exp, _, tcs_l = self.get_pathfind().expand_states(states, goals)
99-
times.record_time("vi_expand", time.time() - start_time)
95+
times.record_time("vi_expand", time.time() - start_time, path=["replay"])
10096

10197
# value iteration update
102-
ctgs_backup: List[float] = self._value_iteration_target(goals, is_solved_l, tcs_l, states_exp, times)
98+
start_time = time.time()
99+
ctgs_backup: List[float] = self._value_iteration_target(goals, is_solved_l, tcs_l, states_exp)
100+
times.record_time("vi_targ", time.time() - start_time, path=["replay"])
103101

104102
return states, goals, ctgs_backup
105103

@@ -191,12 +189,12 @@ def _get_instance_data_rb(self, instances: List[InstanceNode], times: Times) ->
191189
states_her.extend(states_inst)
192190
goals_her.extend([goal_her] * len(states_inst))
193191

194-
times.record_time("data_her", time.time() - start_time)
192+
times.record_time("data", time.time() - start_time, path=["HER"])
195193

196194
# is solved
197195
start_time = time.time()
198196
is_solved_l_her: List[bool] = self.domain.is_solved(states_her, goals_her)
199-
times.record_time("is_solved_her", time.time() - start_time)
197+
times.record_time("is_solved", time.time() - start_time, path=["HER"])
200198

201199
# add to replay buffer
202200
self._rb_add(states_her, goals_her, is_solved_l_her, times)

0 commit comments

Comments
 (0)