Skip to content

Commit 3db6186

Browse files
committed
subtree reconfigure: add select="descend" mode
1 parent b730ff5 commit 3db6186

2 files changed

Lines changed: 246 additions & 83 deletions

File tree

cotengra/core.py

Lines changed: 193 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -2074,6 +2074,7 @@ def unslice_all(self, inplace=False):
20742074
unslice_all_ = functools.partialmethod(unslice_all, inplace=True)
20752075

20762076
def calc_subtree_candidates(self, pwr=2, what="flops"):
2077+
# get all intermediate nodes
20772078
candidates = list(self.children)
20782079

20792080
if what == "size":
@@ -2082,10 +2083,12 @@ def calc_subtree_candidates(self, pwr=2, what="flops"):
20822083
elif what == "flops":
20832084
weights = [self.get_flops(x) for x in candidates]
20842085

2085-
max_weight = max(weights)
2086-
2087-
# can be bigger than numpy int/float allows
2088-
weights = [float(w / max_weight) ** (1 / pwr) for w in weights]
2086+
if pwr == "log":
2087+
weights = [math.log2(max(2, w)) for w in weights]
2088+
else:
2089+
max_weight = max(weights)
2090+
# can be bigger than numpy int/float allows
2091+
weights = [float(w / max_weight) ** (1 / pwr) for w in weights]
20892092

20902093
# sort by descending score
20912094
candidates, weights = zip(
@@ -2094,6 +2097,164 @@ def calc_subtree_candidates(self, pwr=2, what="flops"):
20942097

20952098
return list(candidates), list(weights)
20962099

2100+
def _subtree_remove_and_optimize(
2101+
self,
2102+
sub_root,
2103+
sub_leaves,
2104+
sub_branches,
2105+
already_optimized,
2106+
node_cost,
2107+
minimize,
2108+
opt,
2109+
pbar,
2110+
):
2111+
current_cost = node_cost(self, sub_root)
2112+
for node in sub_branches:
2113+
# these are the intermediates *between* leaves and sub-root
2114+
if minimize == "size":
2115+
current_cost = max(current_cost, node_cost(self, node))
2116+
else:
2117+
current_cost += node_cost(self, node)
2118+
self._remove_node(node)
2119+
2120+
# make the optimizer more efficient by supplying accurate cap
2121+
opt.cost_cap = max(2, current_cost)
2122+
2123+
# and reoptimize the leaves
2124+
self.contract_nodes(sub_leaves, optimize=opt, grandparent=sub_root)
2125+
already_optimized.add(sub_leaves)
2126+
2127+
if pbar is not None:
2128+
pbar.update()
2129+
pbar.set_description(_describe_tree(self), refresh=False)
2130+
2131+
def _subtree_reconfigure_descend(
2132+
self,
2133+
subtree_size,
2134+
subtree_search,
2135+
maxiter,
2136+
seed,
2137+
minimize,
2138+
opt,
2139+
already_optimized,
2140+
node_cost,
2141+
pbar,
2142+
):
2143+
candidates = [self.root]
2144+
any_modified = False
2145+
2146+
def _possibly_add_children(sub_root, any_modified):
2147+
if self.get_extent(sub_root) > subtree_size:
2148+
# possibly extend with node children, if not close to bottom
2149+
lnode, rnode = self.children[sub_root]
2150+
if self.get_extent(lnode) >= 2:
2151+
candidates.append(lnode)
2152+
if self.get_extent(rnode) >= 2:
2153+
candidates.append(rnode)
2154+
2155+
if len(candidates) == 0:
2156+
# exhausted queue
2157+
if any_modified:
2158+
# but have made *any* changes -> go again from top
2159+
candidates.append(self.root)
2160+
any_modified = False
2161+
2162+
return any_modified
2163+
2164+
r = 0
2165+
while candidates and r < maxiter:
2166+
sub_root = candidates.pop(0)
2167+
2168+
# get a subtree to possibly reconfigure
2169+
sub_leaves, sub_branches = self.get_subtree(
2170+
sub_root, size=subtree_size, search=subtree_search, seed=seed
2171+
)
2172+
2173+
# check if its already been optimized
2174+
sub_leaves = frozenset(sub_leaves)
2175+
if sub_leaves in already_optimized:
2176+
any_modified = _possibly_add_children(sub_root, any_modified)
2177+
continue
2178+
2179+
# else remove the branches, keeping track of current cost
2180+
self._subtree_remove_and_optimize(
2181+
sub_root,
2182+
sub_leaves,
2183+
sub_branches,
2184+
already_optimized,
2185+
node_cost,
2186+
minimize,
2187+
opt,
2188+
pbar,
2189+
)
2190+
any_modified = _possibly_add_children(sub_root, True)
2191+
r += 1
2192+
2193+
def _subtree_reconfigure_rand_select(
2194+
self,
2195+
subtree_size,
2196+
subtree_search,
2197+
weight_what,
2198+
weight_pwr,
2199+
select,
2200+
maxiter,
2201+
seed,
2202+
minimize,
2203+
opt,
2204+
already_optimized,
2205+
node_cost,
2206+
pbar,
2207+
):
2208+
if select == "random":
2209+
rng = get_rng(seed)
2210+
else:
2211+
rng = None
2212+
if select == "max":
2213+
i = 0
2214+
elif select == "min":
2215+
i = -1
2216+
2217+
candidates, weights = self.calc_subtree_candidates(
2218+
pwr=weight_pwr, what=weight_what
2219+
)
2220+
2221+
r = 0
2222+
while candidates and r < maxiter:
2223+
if rng is not None:
2224+
(i,) = rng.choices(range(len(candidates)), weights=weights)
2225+
2226+
weights.pop(i)
2227+
sub_root = candidates.pop(i)
2228+
2229+
# get a subtree to possibly reconfigure
2230+
sub_leaves, sub_branches = self.get_subtree(
2231+
sub_root, size=subtree_size, search=subtree_search, seed=seed
2232+
)
2233+
2234+
# check if its already been optimized
2235+
sub_leaves = frozenset(sub_leaves)
2236+
if sub_leaves in already_optimized:
2237+
continue
2238+
2239+
# else remove the branches, keeping track of current cost
2240+
self._subtree_remove_and_optimize(
2241+
sub_root,
2242+
sub_leaves,
2243+
sub_branches,
2244+
already_optimized,
2245+
node_cost,
2246+
minimize,
2247+
opt,
2248+
pbar,
2249+
)
2250+
2251+
# if we have reconfigured simply re-add all candidates
2252+
candidates, weights = self.calc_subtree_candidates(
2253+
pwr=weight_pwr, what=weight_what
2254+
)
2255+
2256+
r += 1
2257+
20972258
def subtree_reconfigure(
20982259
self,
20992260
subtree_size=8,
@@ -2129,13 +2290,15 @@ def subtree_reconfigure(
21292290
scale their score into a probability: ``score**(1 / weight_pwr)``.
21302291
The larger this is the more explorative the algorithm is when
21312292
``select='random'``.
2132-
select : {'max', 'min', 'random'}, optional
2293+
select : {'descend', 'max', 'min', 'random'}, optional
21332294
What order to select node subtrees to optimize:
21342295
2296+
- 'descend': start from the root and then descend into children. In
2297+
this case the weights and weight_pwr are ignored since this is a
2298+
deterministic order.
21352299
- 'max': choose the highest score first
21362300
- 'min': choose the lowest score first
2137-
- 'random': choose randomly weighted on score -- see
2138-
``weight_pwr``.
2301+
- 'random': choose randomly weighted on score - see ``weight_pwr``.
21392302
21402303
maxiter : int, optional
21412304
How many subtree optimizations to perform, the algorithm can
@@ -2161,89 +2324,49 @@ def subtree_reconfigure(
21612324
if minimize is None:
21622325
minimize = self.get_default_objective()
21632326
scorer = get_score_fn(minimize)
2327+
node_cost = getattr(scorer, "cost_local_tree_node", lambda _: 2)
21642328

21652329
if optimize is None:
21662330
from .pathfinders.path_basic import OptimalOptimizer
21672331

2168-
opt = OptimalOptimizer(
2169-
minimize=scorer.get_dynamic_programming_minimize()
2170-
)
2332+
minimize = scorer.get_dynamic_programming_minimize()
2333+
opt = OptimalOptimizer(minimize=minimize)
21712334
else:
21722335
opt = optimize
21732336

2174-
node_cost = getattr(scorer, "cost_local_tree_node", lambda _: 2)
2175-
21762337
# different caches as we might want to reconfigure one before other
21772338
tree.already_optimized.setdefault(minimize, set())
21782339
already_optimized = tree.already_optimized[minimize]
21792340

2180-
if select == "random":
2181-
rng = get_rng(seed)
2182-
else:
2183-
if select == "max":
2184-
i = 0
2185-
elif select == "min":
2186-
i = -1
2187-
rng = None
2188-
2189-
candidates, weights = tree.calc_subtree_candidates(
2190-
pwr=weight_pwr, what=weight_what
2191-
)
2192-
21932341
if progbar:
21942342
import tqdm
21952343

21962344
pbar = tqdm.tqdm()
21972345
pbar.set_description(_describe_tree(tree), refresh=False)
2346+
else:
2347+
pbar = None
21982348

2199-
r = 0
22002349
try:
2201-
while candidates and r < maxiter:
2202-
if rng is not None:
2203-
(i,) = rng.choices(range(len(candidates)), weights=weights)
2204-
2205-
weights.pop(i)
2206-
sub_root = candidates.pop(i)
2207-
2208-
# get a subtree to possibly reconfigure
2209-
sub_leaves, sub_branches = tree.get_subtree(
2210-
sub_root, size=subtree_size, search=subtree_search
2211-
)
2212-
2213-
sub_leaves = frozenset(sub_leaves)
2214-
2215-
# check if its already been optimized
2216-
if sub_leaves in already_optimized:
2217-
continue
2218-
2219-
# else remove the branches, keeping track of current cost
2220-
current_cost = node_cost(tree, sub_root)
2221-
for node in sub_branches:
2222-
if minimize == "size":
2223-
current_cost = max(current_cost, node_cost(tree, node))
2224-
else:
2225-
current_cost += node_cost(tree, node)
2226-
tree._remove_node(node)
2227-
2228-
# make the optimizer more efficient by supplying accurate cap
2229-
opt.cost_cap = max(2, current_cost)
2230-
2231-
# and reoptimize the leaves
2232-
tree.contract_nodes(
2233-
sub_leaves, optimize=opt, grandparent=sub_root
2234-
)
2235-
already_optimized.add(sub_leaves)
2236-
2237-
r += 1
2350+
reconf_kwargs = {
2351+
"subtree_size": subtree_size,
2352+
"subtree_search": subtree_search,
2353+
"maxiter": maxiter,
2354+
"seed": seed,
2355+
"minimize": minimize,
2356+
"opt": opt,
2357+
"already_optimized": already_optimized,
2358+
"node_cost": node_cost,
2359+
"pbar": pbar,
2360+
}
22382361

2239-
if progbar:
2240-
pbar.update()
2241-
pbar.set_description(_describe_tree(tree), refresh=False)
2362+
if select == "descend":
2363+
tree._subtree_reconfigure_descend(**reconf_kwargs)
2364+
else:
2365+
reconf_kwargs["weight_what"] = weight_what
2366+
reconf_kwargs["weight_pwr"] = weight_pwr
2367+
reconf_kwargs["select"] = select
2368+
tree._subtree_reconfigure_rand_select(**reconf_kwargs)
22422369

2243-
# if we have reconfigured simply re-add all candidates
2244-
candidates, weights = tree.calc_subtree_candidates(
2245-
pwr=weight_pwr, what=weight_what
2246-
)
22472370
finally:
22482371
if progbar:
22492372
pbar.close()

0 commit comments

Comments
 (0)