@@ -50,9 +50,7 @@ def __init__(self, domain: D, pathfind_arg: str, up_args: UpArgs):
def _step(self, pathfind: PathFindSetHeurV, times: Times) -> None:
    """Advance the given pathfinding object by one step.

    Delegates to ``_pathfind_v_step``; ``times`` is accepted as part of the
    method's signature but is not used by this implementation.

    Args:
        pathfind: Pathfinding object forwarded to ``_pathfind_v_step``.
        times: Timing recorder (unused here).
    """
    _pathfind_v_step(pathfind)
5252
53- def _value_iteration_target (self , goals : List [Goal ], is_solved_l : List [bool ], tcs_l : List [List [float ]], states_exp : List [List [State ]],
54- times : Times ) -> List [float ]:
55- start_time = time .time ()
53+ def _value_iteration_target (self , goals : List [Goal ], is_solved_l : List [bool ], tcs_l : List [List [float ]], states_exp : List [List [State ]]) -> List [float ]:
5654 # get cost-to-go of expanded states
5755 states_exp_flat , split_idxs = misc_utils .flatten (states_exp )
5856 goals_flat : List [Goal ] = []
@@ -67,8 +65,6 @@ def _value_iteration_target(self, goals: List[Goal], is_solved_l: List[bool], tc
6765 ctgs_backup = np .array ([np .min (x ) for x in ctg_next_p_tc_l ]) * np .logical_not (is_solved_l )
6866 ctgs_backup_l : List [float ] = cast (List [float ], ctgs_backup .tolist ())
6967
70- times .record_time ("vi_targ" , time .time () - start_time )
71-
7268 return ctgs_backup_l
7369
7470 def _inputs_ctgs_to_np (self , states : List [State ], goals : List [Goal ], ctgs_backup : List [float ], times : Times ) -> List [NDArray ]:
@@ -85,21 +81,23 @@ def _init_replay_buffer(self, max_size: int) -> None:
8581 def _rb_add (self , states : List [State ], goals : List [Goal ], is_solved_l : List [bool ], times : Times ) -> None :
8682 start_time = time .time ()
8783 self .rb .add (list (zip (states , goals , is_solved_l , strict = True )))
88- times .record_time ("rb_add " , time .time () - start_time )
84+ times .record_time ("add " , time .time () - start_time , path = [ "replay" ] )
8985
9086 def _sample_rb_vi_target (self , num : int , times : Times ) -> Tuple [List [State ], List [Goal ], List [float ]]:
9187 # sample from replay buffer
9288 start_time = time .time ()
9389 states , goals , is_solved_l = self .rb .sample (num )
94- times .record_time ("rb_samp " , time .time () - start_time )
90+ times .record_time ("samp " , time .time () - start_time , path = [ "replay" ] )
9591
9692 # expand states
9793 start_time = time .time ()
9894 states_exp , _ , tcs_l = self .get_pathfind ().expand_states (states , goals )
99- times .record_time ("vi_expand" , time .time () - start_time )
95+ times .record_time ("vi_expand" , time .time () - start_time , path = [ "replay" ] )
10096
10197 # value iteration update
102- ctgs_backup : List [float ] = self ._value_iteration_target (goals , is_solved_l , tcs_l , states_exp , times )
98+ start_time = time .time ()
99+ ctgs_backup : List [float ] = self ._value_iteration_target (goals , is_solved_l , tcs_l , states_exp )
100+ times .record_time ("vi_targ" , time .time () - start_time , path = ["replay" ])
103101
104102 return states , goals , ctgs_backup
105103
@@ -191,12 +189,12 @@ def _get_instance_data_rb(self, instances: List[InstanceNode], times: Times) ->
191189 states_her .extend (states_inst )
192190 goals_her .extend ([goal_her ] * len (states_inst ))
193191
194- times .record_time ("data_her " , time .time () - start_time )
192+ times .record_time ("data " , time .time () - start_time , path = [ "HER" ] )
195193
196194 # is solved
197195 start_time = time .time ()
198196 is_solved_l_her : List [bool ] = self .domain .is_solved (states_her , goals_her )
199- times .record_time ("is_solved_her " , time .time () - start_time )
197+ times .record_time ("is_solved " , time .time () - start_time , path = [ "HER" ] )
200198
201199 # add to replay buffer
202200 self ._rb_add (states_her , goals_her , is_solved_l_her , times )
0 commit comments