|
1 | | -# `a ∖ b` and `a ∩ b` as a `Vector`, preserving the order of `a`, via a linear |
2 | | -# scan. For the small collections here `Base.setdiff`/`intersect` are slower |
3 | | -# because they build a `Set` and hash. Both assume set-like (unique) inputs. |
| 1 | +import TupleTools |
| 2 | + |
| 3 | +# `a ∖ b` as a `Vector`, preserving the order of `a`, via a linear scan. For the small |
| 4 | +# collections here `Base.setdiff` is slower because it builds a `Set` and hashes; it |
| 5 | +# assumes set-like (unique) inputs. Used to assemble the destination labels in |
| 6 | +# `contract_labels`. |
4 | 7 | smallsetdiff(a, b) = [x for x in a if x ∉ b] |
5 | | -smallintersect(a, b) = [x for x in a if x ∈ b] |
6 | 8 |
|
7 | 9 | # Position of each element of `x` in `y`, as a tuple. Linear scan, no hashing |
8 | 10 | # (`Base.indexin` builds a `Dict`), for the small collections here. |
@@ -32,24 +34,41 @@ length_domain(t) = 0 |
32 | 34 |
|
33 | 35 | length_codomain(t) = length(t) - length_domain(t) |
34 | 36 |
|
| 37 | +# `findfirst` for a match the caller guarantees exists, so the result is an `Int` rather |
| 38 | +# than `Union{Int, Nothing}` (the `Nothing` would otherwise break inference downstream). |
| 39 | +checked_findfirst(pred, collection) = something(findfirst(pred, collection)) |
| 40 | + |
35 | 41 | # codomain <-- domain |
36 | | -function biperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2) |
37 | | - codomain = Tuple(smallsetdiff(dimnames1, dimnames2)) |
38 | | - contracted = Tuple(smallintersect(dimnames1, dimnames2)) |
39 | | - domain = Tuple(smallsetdiff(dimnames2, dimnames1)) |
40 | | - |
41 | | - # `codomain`/`contracted` and `contracted`/`domain` partition the operands by |
42 | | - # construction, so the only label consistency left to check is that the |
43 | | - # destination carries exactly the uncontracted labels. `biperm` below then |
44 | | - # checks each group lands in the destination. |
45 | | - length(codomain) + length(domain) == length(dimnames_dest) || |
| 42 | +function biperms(::typeof(contract), labels_dest, labels1, labels2) |
| 43 | + t1, t2 = Tuple(labels1), Tuple(labels2) |
| 44 | + contracted1 = map(in(t2), t1) |
| 45 | + return biperms(contract, Val(count(contracted1)), labels_dest, t1, t2, contracted1) |
| 46 | +end |
| 47 | +# `K` is the number of contracted labels. Passing it as a `Val` makes the group sizes |
| 48 | +# compile-time constants, so the permutations below come out as concretely-typed tuples and |
| 49 | +# the rest of the contraction stays type-stable. `contracted1` is the boolean mask of which |
| 50 | +# of `labels1`'s labels are contracted (its `count` is `K`), threaded in from the caller. |
| 51 | +function biperms( |
| 52 | + ::typeof(contract), ::Val{K}, labels_dest, labels1, labels2, contracted1 |
| 53 | + ) where {K} |
| 54 | + n1, n2 = length(labels1), length(labels2) |
| 55 | + # `sortperm` of the boolean mask is a stable partition: uncontracted (`false`) indices |
| 56 | + # first, contracted (`true`) indices last, each in their original order. |
| 57 | + perm1_codomain, perm1_domain = |
| 58 | + bipartition(TupleTools.sortperm(contracted1), Val(n1 - K)) |
| 59 | + perm2_domain, _ = |
| 60 | + bipartition(TupleTools.sortperm(map(in(labels1), labels2)), Val(n2 - K)) |
| 61 | + # Align the contracted groups: list operand 2's contracted labels in operand 1's order. |
| 62 | + perm2_codomain = map(p -> checked_findfirst(==(labels1[p]), labels2), perm1_domain) |
| 63 | + # The operands partition into (un)contracted groups by construction; the only label |
| 64 | + # consistency left to check is that the destination carries exactly the uncontracted |
| 65 | + # labels. Locating each below then checks they all land in the destination. |
| 66 | + length(labels_dest) == (n1 - K) + (n2 - K) || |
46 | 67 | throw(ArgumentError("Invalid contraction labels")) |
47 | | - |
48 | | - perm_codomain_dest, perm_domain_dest = biperm(dimnames_dest, codomain, domain) |
49 | | - invperm_dest = invperm((perm_codomain_dest..., perm_domain_dest...)) |
50 | | - biperm_dest = bipartition(invperm_dest, Val(length(codomain))) |
51 | | - |
52 | | - biperm1 = biperm(dimnames1, codomain, contracted) |
53 | | - biperm2 = biperm(dimnames2, contracted, domain) |
54 | | - return biperm_dest, biperm1, biperm2 |
| 68 | + pos_dest = ( |
| 69 | + map(p -> checked_findfirst(==(labels1[p]), labels_dest), perm1_codomain)..., |
| 70 | + map(p -> checked_findfirst(==(labels2[p]), labels_dest), perm2_domain)..., |
| 71 | + ) |
| 72 | + biperm_dest = bipartition(invperm(pos_dest), Val(n1 - K)) |
| 73 | + return biperm_dest, (perm1_codomain, perm1_domain), (perm2_codomain, perm2_domain) |
55 | 74 | end |
0 commit comments