Skip to content

Commit 8c48b01

Browse files
authored
Make contraction label derivation type-stable (#192)
## Summary Derives the bipartitioned permutations for `contract` as concretely-typed tuples behind a single `Val` function-barrier on the contracted count, rather than building them by converting runtime-length `Vector`s into abstract-typed tuples. The abstract tuples defeated the type stability of the downstream `matricize` pipeline, which is built around statically-sized tuples and `Val` barriers, so the per-call bookkeeping for a small contraction boxed heavily. This is a second pass on the label-derivation layer, after #190 removed the `Set` and `Dict` hashing from the same path. For a 4x4 matrix multiply the per-call cost drops from 51 allocations to 17 and is several times faster, landing just above the floor set by the lower-level entry point that takes the permutations directly. The remaining allocations are the output array temporaries and the destination-label `Vector`. The public `biperms` and `biperm` signatures are unchanged. The type-stable core is a new `Val`-parameterized method of `biperms`, and the two label entry points cross into `Val`-specialized helpers so the contraction itself runs type-stably.
1 parent e9318c9 commit 8c48b01

5 files changed

Lines changed: 93 additions & 40 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
3-
version = "0.13.2"
3+
version = "0.13.3"
44
authors = ["ITensor developers <support@itensor.org> and contributors"]
55

66
[workspace]

ext/TensorAlgebraMooncakeExt/TensorAlgebraMooncakeExt.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@ Mooncake.tangent_type(::Type{<:ContractAlgorithm}) = Mooncake.NoTangent
1313
}
1414
@zero_derivative DefaultCtx Tuple{typeof(biperm), Any, Any, Any}
1515
@zero_derivative DefaultCtx Tuple{typeof(biperms), typeof(contract), Any, Any, Any}
16+
@zero_derivative DefaultCtx Tuple{
17+
typeof(biperms),
18+
typeof(contract),
19+
Val,
20+
Any,
21+
Any,
22+
Any,
23+
Any,
24+
}
1625
@zero_derivative DefaultCtx Tuple{
1726
typeof(check_input), typeof(contract), Any, Any, Any, Any, Any, Any,
1827
}

src/contract/biperms.jl

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
# `a ∖ b` and `a ∩ b` as a `Vector`, preserving the order of `a`, via a linear
2-
# scan. For the small collections here `Base.setdiff`/`intersect` are slower
3-
# because they build a `Set` and hash. Both assume set-like (unique) inputs.
1+
import TupleTools
2+
3+
# `a ∖ b` as a `Vector`, preserving the order of `a`, via a linear scan. For the small
4+
# collections here `Base.setdiff` is slower because it builds a `Set` and hashes; it
5+
# assumes set-like (unique) inputs. Used to assemble the destination labels in
6+
# `contract_labels`.
47
smallsetdiff(a, b) = [x for x in a if x b]
5-
smallintersect(a, b) = [x for x in a if x b]
68

79
# Position of each element of `x` in `y`, as a tuple. Linear scan, no hashing
810
# (`Base.indexin` builds a `Dict`), for the small collections here.
@@ -32,24 +34,41 @@ length_domain(t) = 0
3234

3335
length_codomain(t) = length(t) - length_domain(t)
3436

37+
# `findfirst` for a match the caller guarantees exists, so the result is an `Int` rather
38+
# than `Union{Int, Nothing}` (the `Nothing` would otherwise break inference downstream).
39+
checked_findfirst(pred, collection) = something(findfirst(pred, collection))
40+
3541
# codomain <-- domain
36-
function biperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2)
37-
codomain = Tuple(smallsetdiff(dimnames1, dimnames2))
38-
contracted = Tuple(smallintersect(dimnames1, dimnames2))
39-
domain = Tuple(smallsetdiff(dimnames2, dimnames1))
40-
41-
# `codomain`/`contracted` and `contracted`/`domain` partition the operands by
42-
# construction, so the only label consistency left to check is that the
43-
# destination carries exactly the uncontracted labels. `biperm` below then
44-
# checks each group lands in the destination.
45-
length(codomain) + length(domain) == length(dimnames_dest) ||
42+
function biperms(::typeof(contract), labels_dest, labels1, labels2)
43+
t1, t2 = Tuple(labels1), Tuple(labels2)
44+
contracted1 = map(in(t2), t1)
45+
return biperms(contract, Val(count(contracted1)), labels_dest, t1, t2, contracted1)
46+
end
47+
# `K` is the number of contracted labels. Passing it as a `Val` makes the group sizes
48+
# compile-time constants, so the permutations below come out as concretely-typed tuples and
49+
# the rest of the contraction stays type-stable. `contracted1` is the boolean mask of which
50+
# of `labels1`'s labels are contracted (its `count` is `K`), threaded in from the caller.
51+
function biperms(
52+
::typeof(contract), ::Val{K}, labels_dest, labels1, labels2, contracted1
53+
) where {K}
54+
n1, n2 = length(labels1), length(labels2)
55+
# `sortperm` of the boolean mask is a stable partition: uncontracted (`false`) indices
56+
# first, contracted (`true`) indices last, each in their original order.
57+
perm1_codomain, perm1_domain =
58+
bipartition(TupleTools.sortperm(contracted1), Val(n1 - K))
59+
perm2_domain, _ =
60+
bipartition(TupleTools.sortperm(map(in(labels1), labels2)), Val(n2 - K))
61+
# Align the contracted groups: list operand 2's contracted labels in operand 1's order.
62+
perm2_codomain = map(p -> checked_findfirst(==(labels1[p]), labels2), perm1_domain)
63+
# The operands partition into (un)contracted groups by construction; the only label
64+
# consistency left to check is that the destination carries exactly the uncontracted
65+
# labels. Locating each below then checks they all land in the destination.
66+
length(labels_dest) == (n1 - K) + (n2 - K) ||
4667
throw(ArgumentError("Invalid contraction labels"))
47-
48-
perm_codomain_dest, perm_domain_dest = biperm(dimnames_dest, codomain, domain)
49-
invperm_dest = invperm((perm_codomain_dest..., perm_domain_dest...))
50-
biperm_dest = bipartition(invperm_dest, Val(length(codomain)))
51-
52-
biperm1 = biperm(dimnames1, codomain, contracted)
53-
biperm2 = biperm(dimnames2, contracted, domain)
54-
return biperm_dest, biperm1, biperm2
68+
pos_dest = (
69+
map(p -> checked_findfirst(==(labels1[p]), labels_dest), perm1_codomain)...,
70+
map(p -> checked_findfirst(==(labels2[p]), labels_dest), perm2_domain)...,
71+
)
72+
biperm_dest = bipartition(invperm(pos_dest), Val(n1 - K))
73+
return biperm_dest, (perm1_codomain, perm1_domain), (perm2_codomain, perm2_domain)
5574
end

src/contract/contract.jl

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,30 @@ end
99
function contract(
1010
labels_dest, a1::AbstractArray, labels1, a2::AbstractArray, labels2; kwargs...
1111
)
12-
(perm_dest_codomain, perm_dest_domain), (perm1_codomain, perm1_domain),
13-
(perm2_codomain, perm2_domain) = biperms(contract, labels_dest, labels1, labels2)
14-
return contract(
15-
perm_dest_codomain, perm_dest_domain,
16-
a1, perm1_codomain, perm1_domain,
17-
a2, perm2_codomain, perm2_domain;
12+
t1 = ntuple(i -> labels1[i], Val(ndims(a1)))
13+
t2 = ntuple(i -> labels2[i], Val(ndims(a2)))
14+
contracted1 = map(in(t2), t1)
15+
# Cross into a `Val(K)` method (a function-barrier on the contracted count) so the
16+
# bipartitioned permutations and the contraction below them are type-stable.
17+
return _contract(
18+
Val(count(contracted1)),
19+
labels_dest,
20+
a1,
21+
t1,
22+
a2,
23+
t2,
24+
contracted1;
1825
kwargs...
1926
)
2027
end
28+
function _contract(
29+
::Val{K}, labels_dest, a1::AbstractArray, labels1, a2::AbstractArray, labels2,
30+
contracted1; kwargs...
31+
) where {K}
32+
biperm_dest, biperm1, biperm2 =
33+
biperms(contract, Val(K), labels_dest, labels1, labels2, contracted1)
34+
return contract(biperm_dest..., a1, biperm1..., a2, biperm2...; kwargs...)
35+
end
2136

2237
# contract (bipartitioned permutations)
2338
function contract(
@@ -117,13 +132,25 @@ function contractopadd!(
117132
α::Number, β::Number;
118133
kwargs...
119134
)
120-
(perm_dest_codomain, perm_dest_domain), (perm1_codomain, perm1_domain),
121-
(perm2_codomain, perm2_domain) = biperms(contract, labels_dest, labels1, labels2)
135+
t1 = ntuple(i -> labels1[i], Val(ndims(a1)))
136+
t2 = ntuple(i -> labels2[i], Val(ndims(a2)))
137+
contracted1 = map(in(t2), t1)
138+
# Cross into a `Val(K)` method (a function-barrier on the contracted count) so the
139+
# bipartitioned permutations and the contraction below them are type-stable.
140+
return _contractopadd!(
141+
Val(count(contracted1)), a_dest, labels_dest,
142+
op1, a1, t1, op2, a2, t2, α, β, contracted1; kwargs...
143+
)
144+
end
145+
function _contractopadd!(
146+
::Val{K}, a_dest::AbstractArray, labels_dest,
147+
op1, a1::AbstractArray, labels1, op2, a2::AbstractArray, labels2,
148+
α::Number, β::Number, contracted1; kwargs...
149+
) where {K}
150+
biperm_dest, biperm1, biperm2 =
151+
biperms(contract, Val(K), labels_dest, labels1, labels2, contracted1)
122152
return contractopadd!(
123-
a_dest, perm_dest_codomain, perm_dest_domain,
124-
op1, a1, perm1_codomain, perm1_domain,
125-
op2, a2, perm2_codomain, perm2_domain,
126-
α, β; kwargs...
153+
a_dest, biperm_dest..., op1, a1, biperm1..., op2, a2, biperm2..., α, β; kwargs...
127154
)
128155
end
129156
# contractopadd! (bipartitioned permutations, algorithm selection)

test/test_setoperations.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1-
using TensorAlgebra: biperms, contract, smallintersect, smallsetdiff, tuple_indexin
1+
using TensorAlgebra: biperms, contract, smallsetdiff, tuple_indexin
22
using Test: @test, @test_throws, @testset
33

4-
@testset "smallsetdiff/smallintersect" begin
4+
@testset "smallsetdiff" begin
55
# Order-preserving, returning a `Vector`.
66
@test smallsetdiff((:i, :j, :k), (:k, :i)) == [:j]
7-
@test smallintersect((:i, :j, :k), (:k, :i)) == [:i, :k]
87
@test smallsetdiff([:i, :j, :k], [:k, :i]) == [:j]
9-
@test smallintersect([:i, :j, :k], [:k, :i]) == [:i, :k]
108
# Disjoint and empty cases.
119
@test smallsetdiff((:i, :j), ()) == [:i, :j]
12-
@test smallintersect((:i, :j), (:k,)) == []
10+
@test smallsetdiff((:i, :j), (:i, :j)) == []
1311
end
1412

1513
@testset "tuple_indexin" begin

0 commit comments

Comments
 (0)