Skip to content

Commit 8145c68

Browse files
mtfishmanclaude
andauthored
Add contractopadd! and matricizeop (#158)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 4339081 commit 8145c68

12 files changed

Lines changed: 343 additions & 88 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
3-
version = "0.8.0"
3+
version = "0.9.0"
44
authors = ["ITensor developers <support@itensor.org> and contributors"]
55

66
[workspace]

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@ path = ".."
1111
Documenter = "1.8.1"
1212
ITensorFormatter = "0.2.27"
1313
Literate = "2.20.1"
14-
TensorAlgebra = "0.8"
14+
TensorAlgebra = "0.9"

examples/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
55
path = ".."
66

77
[compat]
8-
TensorAlgebra = "0.8"
8+
TensorAlgebra = "0.9"

ext/TensorAlgebraTensorOperationsExt/TensorAlgebraTensorOperationsExt.jl

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -52,35 +52,22 @@ function TA.contract(
5252
end
5353

5454
# in-place
55-
function TA.contractadd!(
55+
function TA.contractopadd!(
5656
algorithm::TensorOperationsAlgorithm,
5757
a_dest::AbstractArray, perm_dest_codomain, perm_dest_domain,
58-
a1::AbstractArray, perm1_codomain, perm1_domain,
59-
a2::AbstractArray, perm2_codomain, perm2_domain,
58+
op1, a1::AbstractArray, perm1_codomain, perm1_domain,
59+
op2, a2::AbstractArray, perm2_codomain, perm2_domain,
6060
α::Number, β::Number
6161
)
6262
permblocks1 = Tuple.((perm1_codomain, perm1_domain))
6363
permblocks2 = Tuple.((perm2_codomain, perm2_domain))
6464
permblocks_dest = Tuple.((perm_dest_codomain, perm_dest_domain))
65-
conj1, conj2 = false, false
66-
return TO.tensorcontract!(
67-
a_dest, a1, permblocks1, conj1, a2, permblocks2, conj2,
68-
permblocks_dest, α, β, algorithm.backend
69-
)
70-
end
71-
72-
function TA.contractadd!(
73-
algorithm::TensorOperationsAlgorithm,
74-
a_dest::AbstractArray, labels_dest,
75-
a1::AbstractArray, labels1,
76-
a2::AbstractArray, labels2,
77-
α::Number, β::Number
78-
)
79-
permblocks1, permblocks2, permblocks_dest =
80-
TO.contract_indices(labels1, labels2, labels_dest)
81-
conj1, conj2 = false, false
65+
conj1 = op1 === conj
66+
conj2 = op2 === conj
67+
a1′ = (op1 === identity || op1 === conj) ? a1 : op1.(a1)
68+
a2′ = (op2 === identity || op2 === conj) ? a2 : op2.(a2)
8269
return TO.tensorcontract!(
83-
a_dest, a1, permblocks1, conj1, a2, permblocks2, conj2,
70+
a_dest, a1, permblocks1, conj1, a2, permblocks2, conj2,
8471
permblocks_dest, α, β, algorithm.backend
8572
)
8673
end
@@ -96,14 +83,13 @@ function TO.tensorcontract!(
9683
backend::TA.ContractAlgorithm,
9784
allocator
9885
)
99-
# TODO: FIXME: Use `conjed` to do the conjugation lazily.
100-
a1′ = conj1 ? conj(a1) : a1
101-
a2′ = conj2 ? conj(a2) : a2
102-
return TA.contractadd!(
86+
op1 = conj1 ? conj : identity
87+
op2 = conj2 ? conj : identity
88+
return TA.contractopadd!(
10389
backend,
10490
a_dest, permblocks_dest...,
105-
a1′, permblocks1...,
106-
a2′, permblocks2...,
91+
op1, a1, permblocks1...,
92+
op2, a2, permblocks2...,
10793
α, β
10894
)
10995
end

src/TensorAlgebra.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ module TensorAlgebra
33
export contract, contract!, eigen, eigvals, factorize, left_null, left_orth, left_polar,
44
lq, qr, right_null, right_orth, right_polar, orth, polar, svd, svdvals
55

6+
if VERSION >= v"1.11.0-DEV.469"
7+
eval(Meta.parse("public contractopadd!, matricizeop"))
8+
end
9+
610
include("MatrixAlgebra.jl")
711
include("blockedtuple.jl")
812
include("blockedpermutation.jl")

src/contract/contract.jl

Lines changed: 54 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -87,19 +87,48 @@ function contractadd!(
8787
α::Number, β::Number;
8888
kwargs...
8989
)
90-
biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2)
91-
return contractadd!(
92-
a_dest, blocks(biperm_dest)...,
93-
a1, blocks(biperm1)...,
94-
a2, blocks(biperm2)...,
95-
α, β; kwargs...
90+
return contractopadd!(
91+
a_dest, labels_dest, identity, a1, labels1, identity, a2, labels2, α, β; kwargs...
9692
)
9793
end
94+
# contractadd! (bipartitioned permutations)
9895
function contractadd!(
9996
a_dest::AbstractArray, perm_dest_codomain, perm_dest_domain,
10097
a1::AbstractArray, perm1_codomain, perm1_domain,
10198
a2::AbstractArray, perm2_codomain, perm2_domain,
10299
α::Number, β::Number;
100+
kwargs...
101+
)
102+
return contractopadd!(
103+
a_dest, perm_dest_codomain, perm_dest_domain,
104+
identity, a1, perm1_codomain, perm1_domain,
105+
identity, a2, perm2_codomain, perm2_domain,
106+
α, β; kwargs...
107+
)
108+
end
109+
110+
# contractopadd! (labels)
111+
function contractopadd!(
112+
a_dest::AbstractArray, labels_dest,
113+
op1, a1::AbstractArray, labels1,
114+
op2, a2::AbstractArray, labels2,
115+
α::Number, β::Number;
116+
kwargs...
117+
)
118+
biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2)
119+
return contractopadd!(
120+
a_dest, blocks(biperm_dest)...,
121+
op1, a1, blocks(biperm1)...,
122+
op2, a2, blocks(biperm2)...,
123+
α, β; kwargs...
124+
)
125+
end
126+
# contractopadd! (bipartitioned permutations, algorithm selection)
127+
function contractopadd!(
128+
a_dest::AbstractArray, perm_dest_codomain, perm_dest_domain,
129+
op1, a1::AbstractArray, perm1_codomain, perm1_domain,
130+
op2, a2::AbstractArray, perm2_codomain, perm2_domain,
131+
α::Number, β::Number;
103132
alg = DefaultContractAlgorithm(), kwargs...
104133
)
105134
check_input(
@@ -109,38 +138,38 @@ function contractadd!(
109138
a2, perm2_codomain, perm2_domain
110139
)
111140
algorithm = select_contract_algorithm(alg, a1, a2; kwargs...)
112-
return contractadd!(
141+
return contractopadd!(
113142
algorithm,
114143
a_dest, perm_dest_codomain, perm_dest_domain,
115-
a1, perm1_codomain, perm1_domain,
116-
a2, perm2_codomain, perm2_domain,
144+
op1, a1, perm1_codomain, perm1_domain,
145+
op2, a2, perm2_codomain, perm2_domain,
117146
α, β
118147
)
119148
end
120-
# contractadd! (dispatched on the algorithm, bipartitioned permutations)
149+
# contractopadd! (dispatched on the algorithm, bipartitioned permutations)
121150
# Required interface if not using matricized contraction
122-
function contractadd!(
151+
function contractopadd!(
123152
algorithm::ContractAlgorithm,
124153
a_dest::AbstractArray, perm_dest_codomain, perm_dest_domain,
125-
a1::AbstractArray, perm1_codomain, perm1_domain,
126-
a2::AbstractArray, perm2_codomain, perm2_domain,
154+
op1, a1::AbstractArray, perm1_codomain, perm1_domain,
155+
op2, a2::AbstractArray, perm2_codomain, perm2_domain,
127156
α::Number, β::Number
128157
)
129158
return throw(
130159
MethodError(
131-
contractadd!,
160+
contractopadd!,
132161
(
133162
algorithm,
134163
a_dest, perm_dest_codomain, perm_dest_domain,
135-
a1, perm1_codomain, perm1_domain,
136-
a2, perm2_codomain, perm2_domain,
164+
op1, a1, perm1_codomain, perm1_domain,
165+
op2, a2, perm2_codomain, perm2_domain,
137166
α, β,
138167
)
139168
)
140169
)
141170
end
142171

143-
# BlockPermutation versions of contract[add][!]
172+
# BlockPermutation versions of contract[opadd][!]
144173
function contract(
145174
a1::AbstractArray, biperm1::AbstractBlockPermutation{2},
146175
a2::AbstractArray, biperm2::AbstractBlockPermutation{2};
@@ -187,18 +216,16 @@ function contractadd!(
187216
α, β; kwargs...
188217
)
189218
end
190-
function contractadd!(
191-
algorithm::ContractAlgorithm,
219+
function contractopadd!(
192220
a_dest::AbstractArray, biperm_dest::AbstractBlockPermutation{2},
193-
a1::AbstractArray, biperm1::AbstractBlockPermutation{2},
194-
a2::AbstractArray, biperm2::AbstractBlockPermutation{2},
195-
α::Number, β::Number
221+
op1, a1::AbstractArray, biperm1::AbstractBlockPermutation{2},
222+
op2, a2::AbstractArray, biperm2::AbstractBlockPermutation{2},
223+
α::Number, β::Number; kwargs...
196224
)
197-
return contractadd!(
198-
algorithm,
225+
return contractopadd!(
199226
a_dest, blocks(biperm_dest)...,
200-
a1, blocks(biperm1)...,
201-
a2, blocks(biperm2)...,
202-
α, β
227+
op1, a1, blocks(biperm1)...,
228+
op2, a2, blocks(biperm2)...,
229+
α, β; kwargs...
203230
)
204231
end

src/contract/contract_matricize.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
11
using LinearAlgebra: mul!
22

3-
function contractadd!(
3+
function contractopadd!(
44
algorithm::Matricize,
55
a_dest::AbstractArray, biperm_dest_codomain, biperm_dest_domain,
6-
a1::AbstractArray, biperm1_codomain, biperm1_domain,
7-
a2::AbstractArray, biperm2_codomain, biperm2_domain,
6+
op1, a1::AbstractArray, biperm1_codomain, biperm1_domain,
7+
op2, a2::AbstractArray, biperm2_codomain, biperm2_domain,
88
α::Number, β::Number
99
)
10-
return contractadd!_matricize(
10+
return contractopadd!_matricize(
1111
algorithm,
1212
a_dest, biperm_dest_codomain, biperm_dest_domain,
13-
a1, biperm1_codomain, biperm1_domain,
14-
a2, biperm2_codomain, biperm2_domain,
13+
op1, a1, biperm1_codomain, biperm1_domain,
14+
op2, a2, biperm2_codomain, biperm2_domain,
1515
α, β
1616
)
1717
end
1818

19-
function contractadd!_matricize(
19+
function contractopadd!_matricize(
2020
algorithm::Matricize,
2121
a_dest::AbstractArray, perm_dest_codomain, perm_dest_domain,
22-
a1::AbstractArray, perm1_codomain, perm1_domain,
23-
a2::AbstractArray, perm2_codomain, perm2_domain,
22+
op1, a1::AbstractArray, perm1_codomain, perm1_domain,
23+
op2, a2::AbstractArray, perm2_codomain, perm2_domain,
2424
α::Number, β::Number
2525
)
2626
perm_dest = (perm_dest_codomain..., perm_dest_domain...)
@@ -32,8 +32,8 @@ function contractadd!_matricize(
3232
a1, perm1_codomain, perm1_domain,
3333
a2, perm2_codomain, perm2_domain
3434
)
35-
a1_mat = matricize(algorithm.fusion_style, a1, perm1_codomain, perm1_domain)
36-
a2_mat = matricize(algorithm.fusion_style, a2, perm2_codomain, perm2_domain)
35+
a1_mat = matricizeop(algorithm.fusion_style, op1, a1, perm1_codomain, perm1_domain)
36+
a2_mat = matricizeop(algorithm.fusion_style, op2, a2, perm2_codomain, perm2_domain)
3737
a_dest_mat = a1_mat * a2_mat
3838
unmatricizeadd!(
3939
algorithm.fusion_style, a_dest, a_dest_mat, invperm_codomain, invperm_domain, α, β

src/matricize.jl

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,16 +123,40 @@ function matricize_axes(a::AbstractArray, ndims_codomain::Val)
123123
return matricize_axes(FusionStyle(a), a, ndims_codomain)
124124
end
125125

126+
# Default similar with bipartitioned axes: flatten to a plain tuple of axes.
127+
# Downstream types (e.g., FusionTensor) can override to preserve bipartition.
128+
function Base.similar(a::AbstractArray, T::Type, axes::BlockedTuple{2})
129+
return similar(a, T, Tuple(axes))
130+
end
131+
132+
"""
133+
permutedimsop(op, src, perm_codomain, perm_domain)
134+
135+
Non-mutating version of `bipermutedimsopadd!`: returns
136+
`op.(permutedims(src, (perm_codomain..., perm_domain...)))`.
137+
"""
138+
function permutedimsop(op, src::AbstractArray, perm_codomain, perm_domain)
139+
dest = allocate_output(permutedimsop, op, src, perm_codomain, perm_domain)
140+
return bipermutedimsopadd!(dest, op, src, perm_codomain, perm_domain, true, false)
141+
end
142+
143+
function allocate_output(::typeof(permutedimsop), op, src::AbstractArray, perm_co, perm_do)
144+
T = Base.promote_op(op, eltype(src))
145+
axes_co = map(i -> axes(src, i), perm_co)
146+
axes_do = map(i -> axes(src, i), perm_do)
147+
return similar(src, T, tuplemortar((axes_co, axes_do)))
148+
end
149+
126150
# Inner version takes a list of sub-permutations, overload this one if needed.
127151
# TODO: Remove _permutedims once support for Julia 1.10 is dropped
128152
# define permutedims with a BlockedPermuation. Default is to flatten it.
129153
# TODO: Deprecate `permuteblockeddims` in favor of `bipermutedims`.
130154
# Keeping it here for backwards compatibility.
131155
function bipermutedims(a::AbstractArray, perm1, perm2)
132-
return _permutedims(a, (perm1..., perm2...))
156+
return permutedimsop(identity, a, perm1, perm2)
133157
end
134158
function bipermutedims!(a_dest::AbstractArray, a_src::AbstractArray, perm1, perm2)
135-
return _permutedims!(a_dest, a_src, (perm1..., perm2...))
159+
return bipermutedimsopadd!(a_dest, identity, a_src, perm1, perm2, true, false)
136160
end
137161
function bipermutedims(a::AbstractArray, biperm::AbstractBlockPermutation{2})
138162
return bipermutedims(a, blocks(biperm)...)
@@ -165,15 +189,14 @@ function matricize(
165189
)
166190
return matricize(FusionStyle(a), a, perm_codomain, perm_domain)
167191
end
168-
# This is a more advanced version to overload where the permutation is actually performed.
192+
# Thin wrapper around `matricizeop` with identity op — the actual matricization logic
193+
# (and the fusion-style overload point for folding ops into matricization) lives in
194+
# `matricizeop`.
169195
function matricize(
170196
style::FusionStyle, a::AbstractArray,
171197
perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}
172198
)
173-
ndims(a) == length(perm_codomain) + length(perm_domain) ||
174-
throw(ArgumentError("Invalid bipermutation"))
175-
a_perm = bipermutedims(a, perm_codomain, perm_domain)
176-
return matricize(style, a_perm, Val(length(perm_codomain)))
199+
return matricizeop(style, identity, a, perm_codomain, perm_domain)
177200
end
178201

179202
# Process inputs such as `EllipsisNotation.Ellipsis`.
@@ -218,6 +241,39 @@ function matricize(
218241
return matricize(style, a, blocks(biperm_dest)...)
219242
end
220243

244+
# ==================================== matricizeop =======================================
245+
246+
"""
247+
matricizeop(op, a, perm_codomain, perm_domain)
248+
249+
Matricize `a` with element-wise operation `op` folded in. Returns a matrix representing
250+
`op.(matricize(a, perm_codomain, perm_domain))`.
251+
252+
Has "maybe alias" semantics: the result may be a view/wrapper aliasing `a` or a fresh
253+
copy, depending on the fusion style and array type. The caller should treat the result
254+
as read-only.
255+
"""
256+
function matricizeop(op, a::AbstractArray, perm_codomain, perm_domain)
257+
return matricizeop(FusionStyle(a), op, a, perm_codomain, perm_domain)
258+
end
259+
function matricizeop(
260+
style::FusionStyle, op, a::AbstractArray, perm_codomain, perm_domain
261+
)
262+
return matricizeop(style, op, a, to_permblocks(a, (perm_codomain, perm_domain))...)
263+
end
264+
# This is the primary function that should be overloaded for new fusion styles to fold
265+
# ops into matricization (e.g., fuse `conj` into the permutation copy, or use lazy
266+
# wrappers like StridedView with op metadata for zero-copy).
267+
function matricizeop(
268+
style::FusionStyle, op, a::AbstractArray,
269+
perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}
270+
)
271+
ndims(a) == length(perm_codomain) + length(perm_domain) ||
272+
throw(ArgumentError("Invalid bipermutation"))
273+
a_perm_op = permutedimsop(op, a, perm_codomain, perm_domain)
274+
return matricize(style, a_perm_op, Val(length(perm_codomain)))
275+
end
276+
221277
# ==================================== unmatricize =======================================
222278
# This is the primary function that should be overloaded for new fusion styles.
223279
function unmatricize(

0 commit comments

Comments
 (0)