@@ -2074,6 +2074,7 @@ def unslice_all(self, inplace=False):
20742074 unslice_all_ = functools .partialmethod (unslice_all , inplace = True )
20752075
20762076 def calc_subtree_candidates (self , pwr = 2 , what = "flops" ):
2077+ # get all intermediate nodes
20772078 candidates = list (self .children )
20782079
20792080 if what == "size" :
@@ -2082,10 +2083,12 @@ def calc_subtree_candidates(self, pwr=2, what="flops"):
20822083 elif what == "flops" :
20832084 weights = [self .get_flops (x ) for x in candidates ]
20842085
2085- max_weight = max (weights )
2086-
2087- # can be bigger than numpy int/float allows
2088- weights = [float (w / max_weight ) ** (1 / pwr ) for w in weights ]
2086+ if pwr == "log" :
2087+ weights = [math .log2 (max (2 , w )) for w in weights ]
2088+ else :
2089+ max_weight = max (weights )
2090+ # can be bigger than numpy int/float allows
2091+ weights = [float (w / max_weight ) ** (1 / pwr ) for w in weights ]
20892092
20902093 # sort by descending score
20912094 candidates , weights = zip (
@@ -2094,6 +2097,164 @@ def calc_subtree_candidates(self, pwr=2, what="flops"):
20942097
20952098 return list (candidates ), list (weights )
20962099
def _subtree_remove_and_optimize(
    self,
    sub_root,
    sub_leaves,
    sub_branches,
    already_optimized,
    node_cost,
    minimize,
    opt,
    pbar,
):
    """Remove the intermediates of a single subtree and recontract it.

    The cost of the existing subtree is accumulated first and handed to
    ``opt`` as ``cost_cap``, then the leaves are contracted again beneath
    ``sub_root`` using ``opt``.

    Parameters
    ----------
    sub_root : node
        The node at the top of the subtree, it is costed but not removed.
    sub_leaves : hashable collection of nodes
        The leaves of the subtree. This is added as-is to
        ``already_optimized`` so it needs to be hashable.
    sub_branches : iterable of nodes
        The intermediates *between* the leaves and ``sub_root``, each of
        these is costed and then removed from the tree.
    already_optimized : set
        Record of leaf sets that have been reconfigured, updated inplace.
    node_cost : callable
        Called as ``node_cost(tree, node)`` to cost an individual node.
    minimize : object
        If equal to ``"size"`` the subtree cost is the maximum node cost,
        otherwise it is the sum of the node costs.
    opt : optimizer
        Passed as ``optimize`` to ``contract_nodes``, its ``cost_cap``
        attribute is overwritten.
    pbar : progress bar or None
        If not ``None``, it is stepped and its description refreshed.
    """
    track_peak = minimize == "size"

    # cost the subtree as it currently stands, dropping each intermediate
    # only once its own cost has been taken into account
    running_cost = node_cost(self, sub_root)
    for branch in sub_branches:
        branch_cost = node_cost(self, branch)
        if track_peak:
            running_cost = max(running_cost, branch_cost)
        else:
            running_cost += branch_cost
        self._remove_node(branch)

    # an accurate cap makes the optimizer more efficient
    opt.cost_cap = max(2, running_cost)

    # rebuild the intermediates between the leaves and the sub-root
    self.contract_nodes(sub_leaves, optimize=opt, grandparent=sub_root)
    already_optimized.add(sub_leaves)

    if pbar is None:
        return

    pbar.update()
    pbar.set_description(_describe_tree(self), refresh=False)
2130+
def _subtree_reconfigure_descend(
    self,
    subtree_size,
    subtree_search,
    maxiter,
    seed,
    minimize,
    opt,
    already_optimized,
    node_cost,
    pbar,
):
    """Reconfigure subtrees in a deterministic, top-down order.

    Starting from the root, a subtree is taken at each visited node and
    reoptimized unless its set of leaves has already been handled. The
    children of a visited node are queued whenever the node's extent
    exceeds ``subtree_size``. Once the queue runs dry, a new pass is
    started from the root if anything was reconfigured since the last
    pass began, otherwise the sweep stops.

    Parameters
    ----------
    subtree_size : int
        Passed as ``size`` to ``get_subtree``, and also the extent a node
        must exceed for its children to be queued.
    subtree_search : object
        Passed as ``search`` to ``get_subtree``.
    maxiter : int
        Maximum number of subtree reconfigurations to perform. Visiting a
        subtree that is already optimized does not count towards this.
    seed : object
        Passed as ``seed`` to ``get_subtree``.
    minimize, opt, already_optimized, node_cost, pbar
        Forwarded unchanged to ``_subtree_remove_and_optimize``, with
        ``already_optimized`` also used to skip handled subtrees.
    """
    # function-scope import keeps this block self-contained
    from collections import deque

    # FIFO queue of subtree roots still to visit, a deque gives O(1)
    # removal from the front, unlike ``list.pop(0)``
    queue = deque([self.root])
    any_modified = False
    num_reconfigured = 0

    while queue and num_reconfigured < maxiter:
        sub_root = queue.popleft()

        # get a subtree to possibly reconfigure
        sub_leaves, sub_branches = self.get_subtree(
            sub_root, size=subtree_size, search=subtree_search, seed=seed
        )
        sub_leaves = frozenset(sub_leaves)

        if sub_leaves not in already_optimized:
            # remove the branches and reoptimize the leaves
            self._subtree_remove_and_optimize(
                sub_root,
                sub_leaves,
                sub_branches,
                already_optimized,
                node_cost,
                minimize,
                opt,
                pbar,
            )
            any_modified = True
            num_reconfigured += 1

        if self.get_extent(sub_root) > subtree_size:
            # possibly extend with node children, if not close to bottom
            left, right = self.children[sub_root]
            for child in (left, right):
                if self.get_extent(child) >= 2:
                    queue.append(child)

        if not queue and any_modified:
            # exhausted queue but made changes -> go again from the top
            queue.append(self.root)
            any_modified = False
2192+
def _subtree_reconfigure_rand_select(
    self,
    subtree_size,
    subtree_search,
    weight_what,
    weight_pwr,
    select,
    maxiter,
    seed,
    minimize,
    opt,
    already_optimized,
    node_cost,
    pbar,
):
    """Reconfigure subtrees chosen from a scored list of candidates.

    Candidates and their scores come from ``calc_subtree_candidates``,
    which returns them sorted by descending score. One candidate is taken
    per iteration according to ``select``. If its subtree has not been
    optimized yet it is reconfigured and the candidate list is rebuilt
    from scratch, otherwise it is simply dropped from the current list.

    Parameters
    ----------
    subtree_size : int
        Passed as ``size`` to ``get_subtree``.
    subtree_search : object
        Passed as ``search`` to ``get_subtree``.
    weight_what : str
        Passed as ``what`` to ``calc_subtree_candidates``.
    weight_pwr : object
        Passed as ``pwr`` to ``calc_subtree_candidates``.
    select : {'max', 'min', 'random'}
        How to pick the next candidate:

        - 'max': take the first candidate, i.e. the highest score
        - 'min': take the last candidate, i.e. the lowest score
        - 'random': sample an index weighted by the scores

    maxiter : int
        Maximum number of subtree reconfigurations to perform. Visiting a
        subtree that is already optimized does not count towards this.
    seed : object
        Passed to ``get_rng`` when ``select='random'`` and as ``seed`` to
        ``get_subtree``.
    minimize, opt, already_optimized, node_cost, pbar
        Forwarded unchanged to ``_subtree_remove_and_optimize``, with
        ``already_optimized`` also used to skip handled subtrees.

    Raises
    ------
    ValueError
        If ``select`` is not one of the supported options.
    """
    # validate up front, otherwise an unknown option only surfaces later
    # as an unbound index once there are candidates to pick from
    if select == "random":
        # NOTE(review): assumes ``get_rng`` returns an object with a
        # ``random.Random``-style ``choices`` method -- confirm
        rng = get_rng(seed)
        i = None
    elif select == "max":
        rng = None
        i = 0
    elif select == "min":
        rng = None
        i = -1
    else:
        raise ValueError(
            f"Unknown `select` option {select!r}, "
            "should be one of 'max', 'min' or 'random'."
        )

    candidates, weights = self.calc_subtree_candidates(
        pwr=weight_pwr, what=weight_what
    )

    r = 0
    while candidates and r < maxiter:
        if rng is not None:
            (i,) = rng.choices(range(len(candidates)), weights=weights)

        # keep the scores aligned with the remaining candidates
        weights.pop(i)
        sub_root = candidates.pop(i)

        # get a subtree to possibly reconfigure
        sub_leaves, sub_branches = self.get_subtree(
            sub_root, size=subtree_size, search=subtree_search, seed=seed
        )

        # check if its already been optimized
        sub_leaves = frozenset(sub_leaves)
        if sub_leaves in already_optimized:
            continue

        # else remove the branches and reoptimize the leaves
        self._subtree_remove_and_optimize(
            sub_root,
            sub_leaves,
            sub_branches,
            already_optimized,
            node_cost,
            minimize,
            opt,
            pbar,
        )

        # the tree has changed, so simply re-add all candidates
        candidates, weights = self.calc_subtree_candidates(
            pwr=weight_pwr, what=weight_what
        )

        r += 1
2257+
20972258 def subtree_reconfigure (
20982259 self ,
20992260 subtree_size = 8 ,
@@ -2129,13 +2290,15 @@ def subtree_reconfigure(
21292290 scale their score into a probability: ``score**(1 / weight_pwr)``.
21302291 The larger this is the more explorative the algorithm is when
21312292 ``select='random'``.
2132- select : {'max', 'min', 'random'}, optional
2293+ select : {'descend', 'max', 'min', 'random'}, optional
21332294 What order to select node subtrees to optimize:
21342295
2296+ - 'descend': start from the root and then descend into children. In
2297+ this case the weights and weight_pwr are ignored since this is a
2298+ deterministic order.
21352299 - 'max': choose the highest score first
21362300 - 'min': choose the lowest score first
2137- - 'random': choose randomly weighted on score -- see
2138- ``weight_pwr``.
2301+ - 'random': choose randomly weighted on score - see ``weight_pwr``.
21392302
21402303 maxiter : int, optional
21412304 How many subtree optimizations to perform, the algorithm can
@@ -2161,89 +2324,49 @@ def subtree_reconfigure(
21612324 if minimize is None :
21622325 minimize = self .get_default_objective ()
21632326 scorer = get_score_fn (minimize )
2327+ node_cost = getattr (scorer , "cost_local_tree_node" , lambda _ : 2 )
21642328
21652329 if optimize is None :
21662330 from .pathfinders .path_basic import OptimalOptimizer
21672331
2168- opt = OptimalOptimizer (
2169- minimize = scorer .get_dynamic_programming_minimize ()
2170- )
2332+ minimize = scorer .get_dynamic_programming_minimize ()
2333+ opt = OptimalOptimizer (minimize = minimize )
21712334 else :
21722335 opt = optimize
21732336
2174- node_cost = getattr (scorer , "cost_local_tree_node" , lambda _ : 2 )
2175-
21762337 # different caches as we might want to reconfigure one before other
21772338 tree .already_optimized .setdefault (minimize , set ())
21782339 already_optimized = tree .already_optimized [minimize ]
21792340
2180- if select == "random" :
2181- rng = get_rng (seed )
2182- else :
2183- if select == "max" :
2184- i = 0
2185- elif select == "min" :
2186- i = - 1
2187- rng = None
2188-
2189- candidates , weights = tree .calc_subtree_candidates (
2190- pwr = weight_pwr , what = weight_what
2191- )
2192-
21932341 if progbar :
21942342 import tqdm
21952343
21962344 pbar = tqdm .tqdm ()
21972345 pbar .set_description (_describe_tree (tree ), refresh = False )
2346+ else :
2347+ pbar = None
21982348
2199- r = 0
22002349 try :
2201- while candidates and r < maxiter :
2202- if rng is not None :
2203- (i ,) = rng .choices (range (len (candidates )), weights = weights )
2204-
2205- weights .pop (i )
2206- sub_root = candidates .pop (i )
2207-
2208- # get a subtree to possibly reconfigure
2209- sub_leaves , sub_branches = tree .get_subtree (
2210- sub_root , size = subtree_size , search = subtree_search
2211- )
2212-
2213- sub_leaves = frozenset (sub_leaves )
2214-
2215- # check if its already been optimized
2216- if sub_leaves in already_optimized :
2217- continue
2218-
2219- # else remove the branches, keeping track of current cost
2220- current_cost = node_cost (tree , sub_root )
2221- for node in sub_branches :
2222- if minimize == "size" :
2223- current_cost = max (current_cost , node_cost (tree , node ))
2224- else :
2225- current_cost += node_cost (tree , node )
2226- tree ._remove_node (node )
2227-
2228- # make the optimizer more efficient by supplying accurate cap
2229- opt .cost_cap = max (2 , current_cost )
2230-
2231- # and reoptimize the leaves
2232- tree .contract_nodes (
2233- sub_leaves , optimize = opt , grandparent = sub_root
2234- )
2235- already_optimized .add (sub_leaves )
2236-
2237- r += 1
2350+ reconf_kwargs = {
2351+ "subtree_size" : subtree_size ,
2352+ "subtree_search" : subtree_search ,
2353+ "maxiter" : maxiter ,
2354+ "seed" : seed ,
2355+ "minimize" : minimize ,
2356+ "opt" : opt ,
2357+ "already_optimized" : already_optimized ,
2358+ "node_cost" : node_cost ,
2359+ "pbar" : pbar ,
2360+ }
22382361
2239- if progbar :
2240- pbar .update ()
2241- pbar .set_description (_describe_tree (tree ), refresh = False )
2362+ if select == "descend" :
2363+ tree ._subtree_reconfigure_descend (** reconf_kwargs )
2364+ else :
2365+ reconf_kwargs ["weight_what" ] = weight_what
2366+ reconf_kwargs ["weight_pwr" ] = weight_pwr
2367+ reconf_kwargs ["select" ] = select
2368+ tree ._subtree_reconfigure_rand_select (** reconf_kwargs )
22422369
2243- # if we have reconfigured simply re-add all candidates
2244- candidates , weights = tree .calc_subtree_candidates (
2245- pwr = weight_pwr , what = weight_what
2246- )
22472370 finally :
22482371 if progbar :
22492372 pbar .close ()
0 commit comments