From a6a5eee3ba8577a6ddea75a57091c2e111de7bc4 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 14 Apr 2026 19:12:28 -0400 Subject: [PATCH 01/26] Bump version to v0.8.1 Co-Authored-By: Claude Sonnet 4.6 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 3c65eb0..a7089bb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -version = "0.8.0" +version = "0.8.1" authors = ["ITensor developers and contributors"] [workspace] From 352d05f90302f139c19aa1b119c18c5a145e7f98 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 14 Apr 2026 19:57:45 -0400 Subject: [PATCH 02/26] Add matricizeop function --- src/matricize.jl | 31 +++++++++++++++++++++++++++++++ test/test_basics.jl | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/src/matricize.jl b/src/matricize.jl index dc2c6de..cdb00b6 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -218,6 +218,37 @@ function matricize( return matricize(style, a, blocks(biperm_dest)...) end +# ==================================== matricizeop ======================================= + +""" + matricizeop(op, a, perm_codomain, perm_domain) + +Matricize `a` with element-wise operation `op` folded in. Returns a matrix representing +`op.(matricize(a, perm_codomain, perm_domain))`. + +Has "maybe alias" semantics: the result may be a view/wrapper aliasing `a` or a fresh +copy, depending on the fusion style and array type. The caller should treat the result +as read-only. +""" +function matricizeop(op, a::AbstractArray, perm_codomain, perm_domain) + return matricizeop(FusionStyle(a), op, a, perm_codomain, perm_domain) +end +function matricizeop( + style::FusionStyle, op, a::AbstractArray, perm_codomain, perm_domain + ) + return matricizeop(style, op, a, to_permblocks(a, (perm_codomain, perm_domain))...) +end +function matricizeop( + style::FusionStyle, op, a::AbstractArray, + perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}} + ) + m = matricize(style, a, perm_codomain, perm_domain) + return _apply_op(op, m) +end + +_apply_op(::typeof(identity), m::AbstractMatrix) = m +_apply_op(op, m::AbstractMatrix) = op.(m) + # ==================================== unmatricize ======================================= # This is the primary function that should be overloaded for new fusion styles. function unmatricize( diff --git a/test/test_basics.jl b/test/test_basics.jl index e48be7f..53f806c 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,3 +1,4 @@ +import TensorAlgebra using EllipsisNotation: var".." using StableRNGs: StableRNG using TensorAlgebra: BlockedTuple, ContractAlgorithm, bipermutedims, bipermutedims!, @@ -85,6 +86,40 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test a_fused ≈ ones(elt, 1, 1) end + @testset "matricizeop (eltype=$elt)" for elt in elts + rng = StableRNG(123) + a = randn(rng, elt, 2, 3, 4) + + # identity op: should match matricize exactly + m = TensorAlgebra.matricizeop(identity, a, (1,), (2, 3)) + m_ref = matricize(a, (1,), (2, 3)) + @test m ≈ m_ref + + m = TensorAlgebra.matricizeop(identity, a, (3, 1), (2,)) + m_ref = matricize(a, (3, 1), (2,)) + @test m ≈ m_ref + + m = TensorAlgebra.matricizeop(identity, a, (2, 3), (1,)) + m_ref = matricize(a, (2, 3), (1,)) + @test m ≈ m_ref + + # conj op (only for Complex eltypes) + if elt <: Complex + m = TensorAlgebra.matricizeop(conj, a, (1,), (2, 3)) + m_ref = conj.(matricize(a, (1,), (2, 3))) + @test m ≈ m_ref + + m = TensorAlgebra.matricizeop(conj, a, (3, 1), (2,)) + m_ref = conj.(matricize(a, (3, 1), (2,))) + @test m ≈ m_ref + end + + # general op + m = TensorAlgebra.matricizeop(abs, a, (1,), (2, 3)) + m_ref = abs.(matricize(a, (1,), (2, 3))) + @test m ≈ m_ref + end + @testset "unmatricize (eltype=$elt)" for elt in elts a0 = randn(elt, 2, 3, 4, 5) axes0 = axes(a0) From 1a47fd8aa805f8cb1c4f1ad0aae617b7683daec2 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 14 Apr 2026 20:14:41 -0400 Subject: [PATCH 03/26] Add contractopadd!, refactor contractadd! to wrap it, update TO extension --- .../TensorAlgebraTensorOperationsExt.jl | 40 +++------ src/contract/contract.jl | 81 ++++++++++++------- src/contract/contract_matricize.jl | 22 ++--- test/test_basics.jl | 35 ++++++++ 4 files changed, 113 insertions(+), 65 deletions(-) diff --git a/ext/TensorAlgebraTensorOperationsExt/TensorAlgebraTensorOperationsExt.jl b/ext/TensorAlgebraTensorOperationsExt/TensorAlgebraTensorOperationsExt.jl index 8e3e942..778b7f5 100644 --- a/ext/TensorAlgebraTensorOperationsExt/TensorAlgebraTensorOperationsExt.jl +++ b/ext/TensorAlgebraTensorOperationsExt/TensorAlgebraTensorOperationsExt.jl @@ -52,35 +52,22 @@ function TA.contract( end # in-place -function TA.contractadd!( +function TA.contractopadd!( algorithm::TensorOperationsAlgorithm, a_dest::AbstractArray, perm_dest_codomain, perm_dest_domain, - a1::AbstractArray, perm1_codomain, perm1_domain, - a2::AbstractArray, perm2_codomain, perm2_domain, + op1, a1::AbstractArray, perm1_codomain, perm1_domain, + op2, a2::AbstractArray, perm2_codomain, perm2_domain, α::Number, β::Number ) permblocks1 = Tuple.((perm1_codomain, perm1_domain)) permblocks2 = Tuple.((perm2_codomain, perm2_domain)) permblocks_dest = Tuple.((perm_dest_codomain, perm_dest_domain)) - conj1, conj2 = false, false - return TO.tensorcontract!( - a_dest, a1, permblocks1, conj1, a2, permblocks2, conj2, - permblocks_dest, α, β, algorithm.backend - ) -end - -function TA.contractadd!( - algorithm::TensorOperationsAlgorithm, - a_dest::AbstractArray, labels_dest, - a1::AbstractArray, labels1, - a2::AbstractArray, labels2, - α::Number, β::Number - ) - permblocks1, permblocks2, permblocks_dest = - TO.contract_indices(labels1, labels2, labels_dest) - conj1, conj2 = false, false + conj1 = op1 === conj + conj2 = op2 === conj + a1′ = (op1 === identity || op1 === conj) ? a1 : op1.(a1) + a2′ = (op2 === identity || op2 === conj) ? a2 : op2.(a2) return TO.tensorcontract!( - a_dest, a1, permblocks1, conj1, a2, permblocks2, conj2, + a_dest, a1′, permblocks1, conj1, a2′, permblocks2, conj2, permblocks_dest, α, β, algorithm.backend ) end @@ -96,14 +83,13 @@ function TO.tensorcontract!( backend::TA.ContractAlgorithm, allocator ) - # TODO: FIXME: Use `conjed` to do the conjugation lazily. - a1′ = conj1 ? conj(a1) : a1 - a2′ = conj2 ? conj(a2) : a2 - return TA.contractadd!( + op1 = conj1 ? conj : identity + op2 = conj2 ? conj : identity + return TA.contractopadd!( backend, a_dest, permblocks_dest..., - a1′, permblocks1..., - a2′, permblocks2..., + op1, a1, permblocks1..., + op2, a2, permblocks2..., α, β ) end diff --git a/src/contract/contract.jl b/src/contract/contract.jl index 87231b9..6fe5815 100644 --- a/src/contract/contract.jl +++ b/src/contract/contract.jl @@ -87,19 +87,48 @@ function contractadd!( α::Number, β::Number; kwargs... ) - biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2) - return contractadd!( - a_dest, blocks(biperm_dest)..., - a1, blocks(biperm1)..., - a2, blocks(biperm2)..., - α, β; kwargs... + return contractopadd!( + a_dest, labels_dest, identity, a1, labels1, identity, a2, labels2, α, β; kwargs... ) end +# contractadd! (bipartitioned permutations) function contractadd!( a_dest::AbstractArray, perm_dest_codomain, perm_dest_domain, a1::AbstractArray, perm1_codomain, perm1_domain, a2::AbstractArray, perm2_codomain, perm2_domain, α::Number, β::Number; + kwargs... + ) + return contractopadd!( + a_dest, perm_dest_codomain, perm_dest_domain, + identity, a1, perm1_codomain, perm1_domain, + identity, a2, perm2_codomain, perm2_domain, + α, β; kwargs... + ) +end + +# contractopadd! (labels) +function contractopadd!( + a_dest::AbstractArray, labels_dest, + op1, a1::AbstractArray, labels1, + op2, a2::AbstractArray, labels2, + α::Number, β::Number; + kwargs... + ) + biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2) + return contractopadd!( + a_dest, blocks(biperm_dest)..., + op1, a1, blocks(biperm1)..., + op2, a2, blocks(biperm2)..., + α, β; kwargs... + ) +end +# contractopadd! (bipartitioned permutations, algorithm selection) +function contractopadd!( + a_dest::AbstractArray, perm_dest_codomain, perm_dest_domain, + op1, a1::AbstractArray, perm1_codomain, perm1_domain, + op2, a2::AbstractArray, perm2_codomain, perm2_domain, + α::Number, β::Number; alg = DefaultContractAlgorithm(), kwargs... ) check_input( @@ -109,38 +138,38 @@ function contractadd!( a2, perm2_codomain, perm2_domain ) algorithm = select_contract_algorithm(alg, a1, a2; kwargs...) - return contractadd!( + return contractopadd!( algorithm, a_dest, perm_dest_codomain, perm_dest_domain, - a1, perm1_codomain, perm1_domain, - a2, perm2_codomain, perm2_domain, + op1, a1, perm1_codomain, perm1_domain, + op2, a2, perm2_codomain, perm2_domain, α, β ) end -# contractadd! (dispatched on the algorithm, bipartitioned permutations) +# contractopadd! (dispatched on the algorithm, bipartitioned permutations) # Required interface if not using matricized contraction -function contractadd!( +function contractopadd!( algorithm::ContractAlgorithm, a_dest::AbstractArray, perm_dest_codomain, perm_dest_domain, - a1::AbstractArray, perm1_codomain, perm1_domain, - a2::AbstractArray, perm2_codomain, perm2_domain, + op1, a1::AbstractArray, perm1_codomain, perm1_domain, + op2, a2::AbstractArray, perm2_codomain, perm2_domain, α::Number, β::Number ) return throw( MethodError( - contractadd!, + contractopadd!, ( algorithm, a_dest, perm_dest_codomain, perm_dest_domain, - a1, perm1_codomain, perm1_domain, - a2, perm2_codomain, perm2_domain, + op1, a1, perm1_codomain, perm1_domain, + op2, a2, perm2_codomain, perm2_domain, α, β, ) ) ) end -# BlockPermutation versions of contract[add][!] +# BlockPermutation versions of contract[opadd][!] function contract( a1::AbstractArray, biperm1::AbstractBlockPermutation{2}, a2::AbstractArray, biperm2::AbstractBlockPermutation{2}; @@ -187,18 +216,16 @@ function contractadd!( α, β; kwargs... ) end -function contractadd!( - algorithm::ContractAlgorithm, +function contractopadd!( a_dest::AbstractArray, biperm_dest::AbstractBlockPermutation{2}, - a1::AbstractArray, biperm1::AbstractBlockPermutation{2}, - a2::AbstractArray, biperm2::AbstractBlockPermutation{2}, - α::Number, β::Number + op1, a1::AbstractArray, biperm1::AbstractBlockPermutation{2}, + op2, a2::AbstractArray, biperm2::AbstractBlockPermutation{2}, + α::Number, β::Number; kwargs... ) - return contractadd!( - algorithm, + return contractopadd!( a_dest, blocks(biperm_dest)..., - a1, blocks(biperm1)..., - a2, blocks(biperm2)..., - α, β + op1, a1, blocks(biperm1)..., + op2, a2, blocks(biperm2)..., + α, β; kwargs... ) end diff --git a/src/contract/contract_matricize.jl b/src/contract/contract_matricize.jl index 701dd4f..ed31bcf 100644 --- a/src/contract/contract_matricize.jl +++ b/src/contract/contract_matricize.jl @@ -1,26 +1,26 @@ using LinearAlgebra: mul! -function contractadd!( +function contractopadd!( algorithm::Matricize, a_dest::AbstractArray, biperm_dest_codomain, biperm_dest_domain, - a1::AbstractArray, biperm1_codomain, biperm1_domain, - a2::AbstractArray, biperm2_codomain, biperm2_domain, + op1, a1::AbstractArray, biperm1_codomain, biperm1_domain, + op2, a2::AbstractArray, biperm2_codomain, biperm2_domain, α::Number, β::Number ) - return contractadd!_matricize( + return contractopadd!_matricize( algorithm, a_dest, biperm_dest_codomain, biperm_dest_domain, - a1, biperm1_codomain, biperm1_domain, - a2, biperm2_codomain, biperm2_domain, + op1, a1, biperm1_codomain, biperm1_domain, + op2, a2, biperm2_codomain, biperm2_domain, α, β ) end -function contractadd!_matricize( +function contractopadd!_matricize( algorithm::Matricize, a_dest::AbstractArray, perm_dest_codomain, perm_dest_domain, - a1::AbstractArray, perm1_codomain, perm1_domain, - a2::AbstractArray, perm2_codomain, perm2_domain, + op1, a1::AbstractArray, perm1_codomain, perm1_domain, + op2, a2::AbstractArray, perm2_codomain, perm2_domain, α::Number, β::Number ) perm_dest = (perm_dest_codomain..., perm_dest_domain...) @@ -32,8 +32,8 @@ function contractadd!_matricize( a1, perm1_codomain, perm1_domain, a2, perm2_codomain, perm2_domain ) - a1_mat = matricize(algorithm.fusion_style, a1, perm1_codomain, perm1_domain) - a2_mat = matricize(algorithm.fusion_style, a2, perm2_codomain, perm2_domain) + a1_mat = matricizeop(algorithm.fusion_style, op1, a1, perm1_codomain, perm1_domain) + a2_mat = matricizeop(algorithm.fusion_style, op2, a2, perm2_codomain, perm2_domain) a_dest_mat = a1_mat * a2_mat unmatricizeadd!( algorithm.fusion_style, a_dest, a_dest_mat, invperm_codomain, invperm_domain, α, β diff --git a/test/test_basics.jl b/test/test_basics.jl index 53f806c..298ef37 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -283,6 +283,41 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 4, 2, 3) ) end + @testset "contractopadd! (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts + elt_dest = promote_type(elt1, elt2) + dims = (2, 3, 4, 5, 6, 7, 8, 9, 10) + labels = (:a, :b, :c, :d, :e, :f, :g, :h, :i) + rng = StableRNG(123) + for (d1s, d2s, d_dests) in ( + ((1, 2), (2, 3), (1, 3)), + ((1, 2), (2, 3), (3, 1)), + ((1, 2, 3), (2, 3, 4), (1, 4)), + ((3, 2, 1), (4, 2, 3), (4, 1)), + ((1, 2, 3), (3, 4), (2, 4, 1)), + ) + a1 = randn(rng, elt1, map(i -> dims[i], d1s)) + labels1 = map(i -> labels[i], d1s) + a2 = randn(rng, elt2, map(i -> dims[i], d2s)) + labels2 = map(i -> labels[i], d2s) + labels_dest = map(i -> labels[i], d_dests) + + α = elt_dest(1.2) + β = elt_dest(2.4) + a_dest_init = randn(rng, elt_dest, map(i -> dims[i], d_dests)) + + # identity ops should match contractadd! + a_dest = copy(a_dest_init) + TensorAlgebra.contractopadd!( + a_dest, labels_dest, + identity, a1, labels1, + identity, a2, labels2, + α, β + ) + a_dest_ref = copy(a_dest_init) + contractadd!(a_dest_ref, labels_dest, a1, labels1, a2, labels2, α, β) + @test a_dest ≈ a_dest_ref + end + end @testset "scalar contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts From ce17cb0278c3d53a915dee85ffb6523baf7367d6 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 14 Apr 2026 20:51:55 -0400 Subject: [PATCH 04/26] Add contractopadd! tests with conj op --- test/test_basics.jl | 74 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/test/test_basics.jl b/test/test_basics.jl index 298ef37..96ff68b 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -316,6 +316,80 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) a_dest_ref = copy(a_dest_init) contractadd!(a_dest_ref, labels_dest, a1, labels1, a2, labels2, α, β) @test a_dest ≈ a_dest_ref + + # conj on first input (only for Complex elt1) + if elt1 <: Complex + a_dest = copy(a_dest_init) + TensorAlgebra.contractopadd!( + a_dest, labels_dest, + conj, a1, labels1, + identity, a2, labels2, + α, β + ) + a_dest_ref = copy(a_dest_init) + contractadd!(a_dest_ref, labels_dest, conj.(a1), labels1, a2, labels2, α, β) + @test a_dest ≈ a_dest_ref + + # compare against TensorOperations backend + a_dest_to = copy(a_dest_init) + TensorAlgebra.contractopadd!( + a_dest_to, labels_dest, + conj, a1, labels1, + identity, a2, labels2, + α, β; alg = alg_tensoroperations + ) + @test a_dest ≈ a_dest_to + end + + # conj on second input (only for Complex elt2) + if elt2 <: Complex + a_dest = copy(a_dest_init) + TensorAlgebra.contractopadd!( + a_dest, labels_dest, + identity, a1, labels1, + conj, a2, labels2, + α, β + ) + a_dest_ref = copy(a_dest_init) + contractadd!(a_dest_ref, labels_dest, a1, labels1, conj.(a2), labels2, α, β) + @test a_dest ≈ a_dest_ref + + # compare against TensorOperations backend + a_dest_to = copy(a_dest_init) + TensorAlgebra.contractopadd!( + a_dest_to, labels_dest, + identity, a1, labels1, + conj, a2, labels2, + α, β; alg = alg_tensoroperations + ) + @test a_dest ≈ a_dest_to + end + + # conj on both inputs (only for Complex elt1 and elt2) + if elt1 <: Complex && elt2 <: Complex + a_dest = copy(a_dest_init) + TensorAlgebra.contractopadd!( + a_dest, labels_dest, + conj, a1, labels1, + conj, a2, labels2, + α, β + ) + a_dest_ref = copy(a_dest_init) + contractadd!( + a_dest_ref, labels_dest, conj.(a1), labels1, conj.(a2), labels2, α, β + ) + @test a_dest ≈ a_dest_ref + + # compare against TensorOperations backend + a_dest_to = copy(a_dest_init) + TensorAlgebra.contractopadd!( + a_dest_to, labels_dest, + conj, a1, labels1, + conj, a2, labels2, + α, β; alg = alg_tensoroperations + ) + @test a_dest ≈ a_dest_to + end end end @testset "scalar contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, From 57fc068aa216a569c6e551263f98307f407c1309 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 14 Apr 2026 21:09:44 -0400 Subject: [PATCH 05/26] Export contractopadd! --- src/TensorAlgebra.jl | 5 +++-- test/test_basics.jl | 18 +++++++++--------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/TensorAlgebra.jl b/src/TensorAlgebra.jl index a1c966a..bf91c28 100644 --- a/src/TensorAlgebra.jl +++ b/src/TensorAlgebra.jl @@ -1,7 +1,8 @@ module TensorAlgebra -export contract, contract!, eigen, eigvals, factorize, left_null, left_orth, left_polar, - lq, qr, right_null, right_orth, right_polar, orth, polar, svd, svdvals +export contract, contract!, contractopadd!, eigen, eigvals, factorize, left_null, + left_orth, left_polar, lq, qr, right_null, right_orth, right_polar, orth, polar, svd, + svdvals include("MatrixAlgebra.jl") include("blockedtuple.jl") diff --git a/test/test_basics.jl b/test/test_basics.jl index 96ff68b..24c46da 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -2,8 +2,8 @@ import TensorAlgebra using EllipsisNotation: var".." using StableRNGs: StableRNG using TensorAlgebra: BlockedTuple, ContractAlgorithm, bipermutedims, bipermutedims!, - blockedpermvcat, contract, contract!, contractadd!, length_codomain, length_domain, - matricize, tuplemortar, unmatricize, unmatricize! + blockedpermvcat, contract, contract!, contractadd!, contractopadd!, length_codomain, + length_domain, matricize, tuplemortar, unmatricize, unmatricize! using TensorOperations: TensorOperations using Test: @test, @test_broken, @test_throws, @testset @@ -307,7 +307,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) # identity ops should match contractadd! a_dest = copy(a_dest_init) - TensorAlgebra.contractopadd!( + contractopadd!( a_dest, labels_dest, identity, a1, labels1, identity, a2, labels2, @@ -320,7 +320,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) # conj on first input (only for Complex elt1) if elt1 <: Complex a_dest = copy(a_dest_init) - TensorAlgebra.contractopadd!( + contractopadd!( a_dest, labels_dest, conj, a1, labels1, identity, a2, labels2, @@ -332,7 +332,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) # compare against TensorOperations backend a_dest_to = copy(a_dest_init) - TensorAlgebra.contractopadd!( + contractopadd!( a_dest_to, labels_dest, conj, a1, labels1, identity, a2, labels2, @@ -344,7 +344,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) # conj on second input (only for Complex elt2) if elt2 <: Complex a_dest = copy(a_dest_init) - TensorAlgebra.contractopadd!( + contractopadd!( a_dest, labels_dest, identity, a1, labels1, conj, a2, labels2, @@ -356,7 +356,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) # compare against TensorOperations backend a_dest_to = copy(a_dest_init) - TensorAlgebra.contractopadd!( + contractopadd!( a_dest_to, labels_dest, identity, a1, labels1, conj, a2, labels2, @@ -368,7 +368,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) # conj on both inputs (only for Complex elt1 and elt2) if elt1 <: Complex && elt2 <: Complex a_dest = copy(a_dest_init) - TensorAlgebra.contractopadd!( + contractopadd!( a_dest, labels_dest, conj, a1, labels1, conj, a2, labels2, @@ -382,7 +382,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) # compare against TensorOperations backend a_dest_to = copy(a_dest_init) - TensorAlgebra.contractopadd!( + contractopadd!( a_dest_to, labels_dest, conj, a1, labels1, conj, a2, labels2, From 41814e46aae7ae2f6214b383807234f38890f77d Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 14 Apr 2026 23:46:15 -0400 Subject: [PATCH 06/26] Make contractopadd! and matricizeop public (not exported) --- src/TensorAlgebra.jl | 7 ++++--- test/test_basics.jl | 18 +++++++++--------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/TensorAlgebra.jl b/src/TensorAlgebra.jl index bf91c28..166f782 100644 --- a/src/TensorAlgebra.jl +++ b/src/TensorAlgebra.jl @@ -1,8 +1,9 @@ module TensorAlgebra -export contract, contract!, contractopadd!, eigen, eigvals, factorize, left_null, - left_orth, left_polar, lq, qr, right_null, right_orth, right_polar, orth, polar, svd, - svdvals +export contract, contract!, eigen, eigvals, factorize, left_null, left_orth, left_polar, + lq, qr, right_null, right_orth, right_polar, orth, polar, svd, svdvals + +public contractopadd!, matricizeop include("MatrixAlgebra.jl") include("blockedtuple.jl") diff --git a/test/test_basics.jl b/test/test_basics.jl index 24c46da..96ff68b 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -2,8 +2,8 @@ import TensorAlgebra using EllipsisNotation: var".." using StableRNGs: StableRNG using TensorAlgebra: BlockedTuple, ContractAlgorithm, bipermutedims, bipermutedims!, - blockedpermvcat, contract, contract!, contractadd!, contractopadd!, length_codomain, - length_domain, matricize, tuplemortar, unmatricize, unmatricize! + blockedpermvcat, contract, contract!, contractadd!, length_codomain, length_domain, + matricize, tuplemortar, unmatricize, unmatricize! using TensorOperations: TensorOperations using Test: @test, @test_broken, @test_throws, @testset @@ -307,7 +307,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) # identity ops should match contractadd! a_dest = copy(a_dest_init) - contractopadd!( + TensorAlgebra.contractopadd!( a_dest, labels_dest, identity, a1, labels1, identity, a2, labels2, @@ -320,7 +320,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) # conj on first input (only for Complex elt1) if elt1 <: Complex a_dest = copy(a_dest_init) - contractopadd!( + TensorAlgebra.contractopadd!( a_dest, labels_dest, conj, a1, labels1, identity, a2, labels2, @@ -332,7 +332,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) # compare against TensorOperations backend a_dest_to = copy(a_dest_init) - contractopadd!( + TensorAlgebra.contractopadd!( a_dest_to, labels_dest, conj, a1, labels1, identity, a2, labels2, @@ -344,7 +344,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) # conj on second input (only for Complex elt2) if elt2 <: Complex a_dest = copy(a_dest_init) - contractopadd!( + TensorAlgebra.contractopadd!( a_dest, labels_dest, identity, a1, labels1, conj, a2, labels2, @@ -356,7 +356,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) # compare against TensorOperations backend a_dest_to = copy(a_dest_init) - contractopadd!( + TensorAlgebra.contractopadd!( a_dest_to, labels_dest, identity, a1, labels1, conj, a2, labels2, @@ -368,7 +368,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) # conj on both inputs (only for Complex elt1 and elt2) if elt1 <: Complex && elt2 <: Complex a_dest = copy(a_dest_init) - contractopadd!( + TensorAlgebra.contractopadd!( a_dest, labels_dest, conj, a1, labels1, conj, a2, labels2, @@ -382,7 +382,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) # compare against TensorOperations backend a_dest_to = copy(a_dest_init) - contractopadd!( + TensorAlgebra.contractopadd!( a_dest_to, labels_dest, conj, a1, labels1, conj, a2, labels2, From 0c379606c8e757652b65c246e17e8a9683eaacaa Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 14 Apr 2026 23:57:44 -0400 Subject: [PATCH 07/26] Add contractopadd! and matricizeop to expected public names --- test/test_exports.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_exports.jl b/test/test_exports.jl index f7f9bda..0068b43 100644 --- a/test/test_exports.jl +++ b/test/test_exports.jl @@ -6,6 +6,7 @@ using Test: @test, @testset :TensorAlgebra, :contract, :contract!, + :contractopadd!, :eigen, :eigvals, :factorize, @@ -13,6 +14,7 @@ using Test: @test, @testset :left_orth, :left_polar, :lq, + :matricizeop, :orth, :polar, :qr, From 0d4d2e2a76246f817c5de20cc0ceb62c79ece05b Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 15 Apr 2026 00:18:22 -0400 Subject: [PATCH 08/26] Use version-guarded public for Julia 1.10 compat --- src/TensorAlgebra.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/TensorAlgebra.jl b/src/TensorAlgebra.jl index 166f782..5202364 100644 --- a/src/TensorAlgebra.jl +++ b/src/TensorAlgebra.jl @@ -3,7 +3,9 @@ module TensorAlgebra export contract, contract!, eigen, eigvals, factorize, left_null, left_orth, left_polar, lq, qr, right_null, right_orth, right_polar, orth, polar, svd, svdvals -public contractopadd!, matricizeop +if VERSION >= v"1.11.0-DEV.469" + eval(Meta.parse("public contractopadd!, matricizeop")) +end include("MatrixAlgebra.jl") include("blockedtuple.jl") From d37aebf4134d8671f75b8214c01924219770aeed Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 15 Apr 2026 19:58:46 -0400 Subject: [PATCH 09/26] Make export test version-aware for public names --- test/test_exports.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_exports.jl b/test/test_exports.jl index 0068b43..0fb0d9b 100644 --- a/test/test_exports.jl +++ b/test/test_exports.jl @@ -6,7 +6,6 @@ using Test: @test, @testset :TensorAlgebra, :contract, :contract!, - :contractopadd!, :eigen, :eigvals, :factorize, @@ -14,7 +13,6 @@ using Test: @test, @testset :left_orth, :left_polar, :lq, - :matricizeop, :orth, :polar, :qr, @@ -24,6 +22,10 @@ using Test: @test, @testset :svd, :svdvals, ] + # `public` (Julia 1.11+) adds names to `names()`; include them on 1.11+. + if VERSION >= v"1.11.0-DEV.469" + append!(exports, [:contractopadd!, :matricizeop]) + end @test issetequal(names(TensorAlgebra), exports) exports = [ From 3c7f8f92c8df9abb7cee1dd7a4255180846af22d Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 15 Apr 2026 20:11:26 -0400 Subject: [PATCH 10/26] Invert matricize/matricizeop: matricizeop is primary, matricize wraps it --- src/matricize.jl | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/matricize.jl b/src/matricize.jl index cdb00b6..b76c805 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -165,15 +165,14 @@ function matricize( ) return matricize(FusionStyle(a), a, perm_codomain, perm_domain) end -# This is a more advanced version to overload where the permutation is actually performed. +# Thin wrapper around `matricizeop` with identity op — the actual matricization logic +# (and the fusion-style overload point for folding ops into matricization) lives in +# `matricizeop`. function matricize( style::FusionStyle, a::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}} ) - ndims(a) == length(perm_codomain) + length(perm_domain) || - throw(ArgumentError("Invalid bipermutation")) - a_perm = bipermutedims(a, perm_codomain, perm_domain) - return matricize(style, a_perm, Val(length(perm_codomain))) + return matricizeop(style, identity, a, perm_codomain, perm_domain) end # Process inputs such as `EllipsisNotation.Ellipsis`. @@ -238,11 +237,17 @@ function matricizeop( ) return matricizeop(style, op, a, to_permblocks(a, (perm_codomain, perm_domain))...) end +# This is the primary function that should be overloaded for new fusion styles to fold +# ops into matricization (e.g., fuse `conj` into the permutation copy, or use lazy +# wrappers like StridedView with op metadata for zero-copy). function matricizeop( style::FusionStyle, op, a::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}} ) - m = matricize(style, a, perm_codomain, perm_domain) + ndims(a) == length(perm_codomain) + length(perm_domain) || + throw(ArgumentError("Invalid bipermutation")) + a_perm = bipermutedims(a, perm_codomain, perm_domain) + m = matricize(style, a_perm, Val(length(perm_codomain))) return _apply_op(op, m) end From 32936fcb47b077158e775fd84b1bbe19c668a0d6 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 15 Apr 2026 21:35:24 -0400 Subject: [PATCH 11/26] Route matricizeop through new bipermutedimsop/bipermutedimsopadd! --- src/matricize.jl | 47 ++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/src/matricize.jl b/src/matricize.jl index b76c805..ab976a8 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -123,13 +123,50 @@ function matricize_axes(a::AbstractArray, ndims_codomain::Val) return matricize_axes(FusionStyle(a), a, ndims_codomain) end +# `bipermutedimsopadd!` / `bipermutedimsop` — bipermutation versions of +# `permutedimsopadd!` / `permuteddims` with an element-wise op folded in. +# +# These are intended to become the primary overload points for downstream array +# types that want to fold ops into a bipartitioned permutation copy (e.g., fuse +# `conj` into the copy, or use lazy wrappers like `StridedView` with op metadata). +# For now, `bipermutedimsopadd!` delegates to the flat-permutation `permutedimsopadd!`. +# In a future PR, the dependency will flip so that `permutedimsopadd!` wraps +# `bipermutedimsopadd!`. + +""" + bipermutedimsopadd!(dest, op, src, perm_codomain, perm_domain, α, β) + +Like `permutedimsopadd!`, but takes a bipartitioned permutation +`(perm_codomain, perm_domain)`. +""" +function bipermutedimsopadd!( + dest::AbstractArray, op, src::AbstractArray, + perm_codomain, perm_domain, + α::Number, β::Number + ) + return permutedimsopadd!(dest, op, src, (perm_codomain..., perm_domain...), α, β) +end + +""" + bipermutedimsop(op, src, perm_codomain, perm_domain) + +Non-mutating version of `bipermutedimsopadd!`: returns +`op.(permutedims(src, (perm_codomain..., perm_domain...)))`. Has "maybe alias" +semantics — the result may be a view/wrapper aliasing `src` or a fresh copy. +""" +function bipermutedimsop(op, src::AbstractArray, perm_codomain, perm_domain) + perm = (perm_codomain..., perm_domain...) + dest = similar(src, map(i -> size(src, i), perm)) + return bipermutedimsopadd!(dest, op, src, perm_codomain, perm_domain, true, false) +end + # Inner version takes a list of sub-permutations, overload this one if needed. # TODO: Remove _permutedims once support for Julia 1.10 is dropped # define permutedims with a BlockedPermuation. Default is to flatten it. # TODO: Deprecate `permuteblockeddims` in favor of `bipermutedims`. # Keeping it here for backwards compatibility. function bipermutedims(a::AbstractArray, perm1, perm2) - return _permutedims(a, (perm1..., perm2...)) + return bipermutedimsop(identity, a, perm1, perm2) end function bipermutedims!(a_dest::AbstractArray, a_src::AbstractArray, perm1, perm2) return _permutedims!(a_dest, a_src, (perm1..., perm2...)) @@ -246,14 +283,10 @@ function matricizeop( ) ndims(a) == length(perm_codomain) + length(perm_domain) || throw(ArgumentError("Invalid bipermutation")) - a_perm = bipermutedims(a, perm_codomain, perm_domain) - m = matricize(style, a_perm, Val(length(perm_codomain))) - return _apply_op(op, m) + a_perm_op = bipermutedimsop(op, a, perm_codomain, perm_domain) + return matricize(style, a_perm_op, Val(length(perm_codomain))) end -_apply_op(::typeof(identity), m::AbstractMatrix) = m -_apply_op(op, m::AbstractMatrix) = op.(m) - # ==================================== unmatricize ======================================= # This is the primary function that should be overloaded for new fusion styles. function unmatricize( From 4aaf3f820737a446068696c4bc7402230fcedfd5 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 15 Apr 2026 21:42:20 -0400 Subject: [PATCH 12/26] Rename bipermutedimsopadd! to permutedimsopadd! biperm overload --- src/matricize.jl | 28 ++-------------------------- src/permutedimsadd.jl | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 26 deletions(-) diff --git a/src/matricize.jl b/src/matricize.jl index ab976a8..c4f4484 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -123,41 +123,17 @@ function matricize_axes(a::AbstractArray, ndims_codomain::Val) return matricize_axes(FusionStyle(a), a, ndims_codomain) end -# `bipermutedimsopadd!` / `bipermutedimsop` — bipermutation versions of -# `permutedimsopadd!` / `permuteddims` with an element-wise op folded in. -# -# These are intended to become the primary overload points for downstream array -# types that want to fold ops into a bipartitioned permutation copy (e.g., fuse -# `conj` into the copy, or use lazy wrappers like `StridedView` with op metadata). -# For now, `bipermutedimsopadd!` delegates to the flat-permutation `permutedimsopadd!`. -# In a future PR, the dependency will flip so that `permutedimsopadd!` wraps -# `bipermutedimsopadd!`. - -""" - bipermutedimsopadd!(dest, op, src, perm_codomain, perm_domain, α, β) - -Like `permutedimsopadd!`, but takes a bipartitioned permutation -`(perm_codomain, perm_domain)`. -""" -function bipermutedimsopadd!( - dest::AbstractArray, op, src::AbstractArray, - perm_codomain, perm_domain, - α::Number, β::Number - ) - return permutedimsopadd!(dest, op, src, (perm_codomain..., perm_domain...), α, β) -end - """ bipermutedimsop(op, src, perm_codomain, perm_domain) -Non-mutating version of `bipermutedimsopadd!`: returns +Non-mutating version of bipermutation `permutedimsopadd!`: returns `op.(permutedims(src, (perm_codomain..., perm_domain...)))`. Has "maybe alias" semantics — the result may be a view/wrapper aliasing `src` or a fresh copy. """ function bipermutedimsop(op, src::AbstractArray, perm_codomain, perm_domain) perm = (perm_codomain..., perm_domain...) dest = similar(src, map(i -> size(src, i), perm)) - return bipermutedimsopadd!(dest, op, src, perm_codomain, perm_domain, true, false) + return permutedimsopadd!(dest, op, src, perm_codomain, perm_domain, true, false) end # Inner version takes a list of sub-permutations, overload this one if needed. diff --git a/src/permutedimsadd.jl b/src/permutedimsadd.jl index dac9490..8cf7ac3 100644 --- a/src/permutedimsadd.jl +++ b/src/permutedimsadd.jl @@ -58,6 +58,25 @@ function permutedimsopadd!( return dest end +# Bipartitioned permutation overload. Intended to become a primary overload point +# for downstream array types that want to fold ops into a bipartitioned permutation +# copy (e.g., fuse `conj` into the copy, or use lazy wrappers like `StridedView` +# with op metadata). For now it delegates to the flat-permutation version by +# concatenating the perms; in a future PR the dependency will flip. + +""" + permutedimsopadd!(dest, op, src, perm_codomain, perm_domain, α, β) + +Like `permutedimsopadd!`, but takes a bipartitioned permutation as two tuples. +""" +function permutedimsopadd!( + dest::AbstractArray, op, src::AbstractArray, + perm_codomain, perm_domain, + α::Number, β::Number + ) + return permutedimsopadd!(dest, op, src, (perm_codomain..., perm_domain...), α, β) +end + # ---------------------------------------------------------------------------- # # Convenience functions that lower to permutedimsopadd! # ---------------------------------------------------------------------------- # From d77bb13ae3ceff3b0d4100bc4219bed9475d8b98 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 15 Apr 2026 21:47:29 -0400 Subject: [PATCH 13/26] Unify naming: permutedimsopadd! biperm overload, permutedimsop, bipermutedims! via permutedimsopadd! --- src/matricize.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/matricize.jl b/src/matricize.jl index c4f4484..51b9661 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -124,13 +124,13 @@ function matricize_axes(a::AbstractArray, ndims_codomain::Val) end """ - bipermutedimsop(op, src, perm_codomain, perm_domain) + permutedimsop(op, src, perm_codomain, perm_domain) Non-mutating version of bipermutation `permutedimsopadd!`: returns `op.(permutedims(src, (perm_codomain..., perm_domain...)))`. Has "maybe alias" semantics — the result may be a view/wrapper aliasing `src` or a fresh copy. """ -function bipermutedimsop(op, src::AbstractArray, perm_codomain, perm_domain) +function permutedimsop(op, src::AbstractArray, perm_codomain, perm_domain) perm = (perm_codomain..., perm_domain...) dest = similar(src, map(i -> size(src, i), perm)) return permutedimsopadd!(dest, op, src, perm_codomain, perm_domain, true, false) @@ -142,10 +142,10 @@ end # TODO: Deprecate `permuteblockeddims` in favor of `bipermutedims`. # Keeping it here for backwards compatibility. function bipermutedims(a::AbstractArray, perm1, perm2) - return bipermutedimsop(identity, a, perm1, perm2) + return permutedimsop(identity, a, perm1, perm2) end function bipermutedims!(a_dest::AbstractArray, a_src::AbstractArray, perm1, perm2) - return _permutedims!(a_dest, a_src, (perm1..., perm2...)) + return permutedimsopadd!(a_dest, identity, a_src, perm1, perm2, true, false) end function bipermutedims(a::AbstractArray, biperm::AbstractBlockPermutation{2}) return bipermutedims(a, blocks(biperm)...) @@ -259,7 +259,7 @@ function matricizeop( ) ndims(a) == length(perm_codomain) + length(perm_domain) || throw(ArgumentError("Invalid bipermutation")) - a_perm_op = bipermutedimsop(op, a, perm_codomain, perm_domain) + a_perm_op = permutedimsop(op, a, perm_codomain, perm_domain) return matricize(style, a_perm_op, Val(length(perm_codomain))) end From 299f18108424a16f25d4189f06f68f7ec72de5a9 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 15 Apr 2026 23:06:36 -0400 Subject: [PATCH 14/26] Flip permutedimsopadd!: biperm is primary, flat wraps it. Add allocate_output for permutedimsop with bipartitioned axes. --- src/matricize.jl | 16 ++++++++++++++-- src/permutedimsadd.jl | 42 ++++++++++++++++++++++-------------------- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/src/matricize.jl b/src/matricize.jl index 51b9661..36b2c67 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -123,6 +123,12 @@ function matricize_axes(a::AbstractArray, ndims_codomain::Val) return matricize_axes(FusionStyle(a), a, ndims_codomain) end +# Default similar with bipartitioned axes: flatten to a plain tuple of axes. +# Downstream types (e.g., FusionTensor) can override to preserve bipartition. +function Base.similar(a::AbstractArray, T::Type, axes::AbstractBlockTuple{2}) + return similar(a, T, Tuple(axes)) +end + """ permutedimsop(op, src, perm_codomain, perm_domain) @@ -131,11 +137,17 @@ Non-mutating version of bipermutation `permutedimsopadd!`: returns semantics — the result may be a view/wrapper aliasing `src` or a fresh copy. """ function permutedimsop(op, src::AbstractArray, perm_codomain, perm_domain) - perm = (perm_codomain..., perm_domain...) - dest = similar(src, map(i -> size(src, i), perm)) + dest = allocate_output(permutedimsop, op, src, perm_codomain, perm_domain) return permutedimsopadd!(dest, op, src, perm_codomain, perm_domain, true, false) end +function allocate_output(::typeof(permutedimsop), op, src::AbstractArray, perm_co, perm_do) + T = Base.promote_op(op, eltype(src)) + axes_co = map(i -> axes(src, i), perm_co) + axes_do = map(i -> axes(src, i), perm_do) + return similar(src, T, tuplemortar((axes_co, axes_do))) +end + # Inner version takes a list of sub-permutations, overload this one if needed. # TODO: Remove _permutedims once support for Julia 1.10 is dropped # define permutedims with a BlockedPermuation. Default is to flatten it. diff --git a/src/permutedimsadd.jl b/src/permutedimsadd.jl index 8cf7ac3..21e1528 100644 --- a/src/permutedimsadd.jl +++ b/src/permutedimsadd.jl @@ -14,20 +14,27 @@ end # ---------------------------------------------------------------------------- # """ - permutedimsopadd!(dest, op, src, perm, α, β) + permutedimsopadd!(dest, op, src, perm_codomain, perm_domain, α, β) -`dest = β * dest + α * permutedims(op.(src), perm)`. +`dest = β * dest + α * permutedims(op.(src), (perm_codomain..., perm_domain...))`. -This is the single materialization primitive for `LinearBroadcasted` types. -Downstream array types should implement this function. The `op` is an element-wise -linear map (e.g., `identity`, `conj`, `adjoint`, `transpose`, `Float32`). +This is the primary overload point for downstream array types that want to +implement op-aware bipartitioned permutation + accumulation (e.g., fuse `conj` +into the copy, or use lazy wrappers like `StridedView` with op metadata). -The default implementation applies `op` element-wise, permutes, then accumulates -via broadcasting with Strided.jl optimization when possible. +The `op` is an element-wise linear map (e.g., `identity`, `conj`). + +The default implementation flattens the bipartitioned permutation, applies `op` +element-wise, permutes, then accumulates via broadcasting with Strided.jl +optimization when possible. """ function permutedimsopadd!( - dest::AbstractArray, op, src::AbstractArray, perm, α::Number, β::Number + dest::AbstractArray, op, src::AbstractArray, + perm_codomain, perm_domain, + α::Number, β::Number ) + perm = (perm_codomain..., perm_domain...) + # TODO: Remove this 0-dimensional special case once GradedArray is its own type # (not an alias for BlockSparseArray), so the GradedArray permutedimsopadd! overload # catches the 0-dimensional contraction result. @@ -58,23 +65,18 @@ function permutedimsopadd!( return dest end -# Bipartitioned permutation overload. Intended to become a primary overload point -# for downstream array types that want to fold ops into a bipartitioned permutation -# copy (e.g., fuse `conj` into the copy, or use lazy wrappers like `StridedView` -# with op metadata). For now it delegates to the flat-permutation version by -# concatenating the perms; in a future PR the dependency will flip. - """ - permutedimsopadd!(dest, op, src, perm_codomain, perm_domain, α, β) + permutedimsopadd!(dest, op, src, perm, α, β) + +`dest = β * dest + α * permutedims(op.(src), perm)`. -Like `permutedimsopadd!`, but takes a bipartitioned permutation as two tuples. +Flat-permutation convenience overload. Forwards to the bipartitioned version +with `perm_domain = ()`. """ function permutedimsopadd!( - dest::AbstractArray, op, src::AbstractArray, - perm_codomain, perm_domain, - α::Number, β::Number + dest::AbstractArray, op, src::AbstractArray, perm, α::Number, β::Number ) - return permutedimsopadd!(dest, op, src, (perm_codomain..., perm_domain...), α, β) + return permutedimsopadd!(dest, op, src, perm, (), α, β) end # ---------------------------------------------------------------------------- # From 57eeb02409ade218fd5623c72e456b2f72ce80fa Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 16 Apr 2026 08:56:54 -0400 Subject: [PATCH 15/26] Narrow Base.similar piracy to BlockedTuple{2} --- src/matricize.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/matricize.jl b/src/matricize.jl index 36b2c67..8ee89db 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -125,7 +125,7 @@ end # Default similar with bipartitioned axes: flatten to a plain tuple of axes. # Downstream types (e.g., FusionTensor) can override to preserve bipartition. -function Base.similar(a::AbstractArray, T::Type, axes::AbstractBlockTuple{2}) +function Base.similar(a::AbstractArray, T::Type, axes::BlockedTuple{2}) return similar(a, T, Tuple(axes)) end From 1cf15f169a1b78389ae9efd8949286d63ef96479 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 16 Apr 2026 10:23:11 -0400 Subject: [PATCH 16/26] Revert to bipermutedimsopadd! name for clarity --- src/matricize.jl | 6 +++--- src/permutedimsadd.jl | 20 +++++++++++++------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/matricize.jl b/src/matricize.jl index 8ee89db..8ce0097 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -132,13 +132,13 @@ end """ permutedimsop(op, src, perm_codomain, perm_domain) -Non-mutating version of bipermutation `permutedimsopadd!`: returns +Non-mutating version of `bipermutedimsopadd!`: returns `op.(permutedims(src, (perm_codomain..., perm_domain...)))`. Has "maybe alias" semantics — the result may be a view/wrapper aliasing `src` or a fresh copy. """ function permutedimsop(op, src::AbstractArray, perm_codomain, perm_domain) dest = allocate_output(permutedimsop, op, src, perm_codomain, perm_domain) - return permutedimsopadd!(dest, op, src, perm_codomain, perm_domain, true, false) + return bipermutedimsopadd!(dest, op, src, perm_codomain, perm_domain, true, false) end function allocate_output(::typeof(permutedimsop), op, src::AbstractArray, perm_co, perm_do) @@ -157,7 +157,7 @@ function bipermutedims(a::AbstractArray, perm1, perm2) return permutedimsop(identity, a, perm1, perm2) end function bipermutedims!(a_dest::AbstractArray, a_src::AbstractArray, perm1, perm2) - return permutedimsopadd!(a_dest, identity, a_src, perm1, perm2, true, false) + return bipermutedimsopadd!(a_dest, identity, a_src, perm1, perm2, true, false) end function bipermutedims(a::AbstractArray, biperm::AbstractBlockPermutation{2}) return bipermutedims(a, blocks(biperm)...) diff --git a/src/permutedimsadd.jl b/src/permutedimsadd.jl index 21e1528..5e2c304 100644 --- a/src/permutedimsadd.jl +++ b/src/permutedimsadd.jl @@ -10,11 +10,11 @@ function maybestrided(as::AbstractArray...) end # ---------------------------------------------------------------------------- # -# permutedimsopadd! — the single materialization primitive +# bipermutedimsopadd! — the primary materialization primitive # ---------------------------------------------------------------------------- # """ - permutedimsopadd!(dest, op, src, perm_codomain, perm_domain, α, β) + bipermutedimsopadd!(dest, op, src, perm_codomain, perm_domain, α, β) `dest = β * dest + α * permutedims(op.(src), (perm_codomain..., perm_domain...))`. @@ -28,7 +28,7 @@ The default implementation flattens the bipartitioned permutation, applies `op` element-wise, permutes, then accumulates via broadcasting with Strided.jl optimization when possible. """ -function permutedimsopadd!( +function bipermutedimsopadd!( dest::AbstractArray, op, src::AbstractArray, perm_codomain, perm_domain, α::Number, β::Number @@ -36,8 +36,8 @@ function permutedimsopadd!( perm = (perm_codomain..., perm_domain...) # TODO: Remove this 0-dimensional special case once GradedArray is its own type - # (not an alias for BlockSparseArray), so the GradedArray permutedimsopadd! overload - # catches the 0-dimensional contraction result. + # (not an alias for BlockSparseArray), so the GradedArray overload catches the + # 0-dimensional contraction result. if iszero(ndims(dest)) dest[] = β * dest[] + α * op(src[]) return dest @@ -65,18 +65,24 @@ function permutedimsopadd!( return dest end +# ---------------------------------------------------------------------------- # +# permutedimsopadd! — flat-permutation interface +# ---------------------------------------------------------------------------- # + """ permutedimsopadd!(dest, op, src, perm, α, β) `dest = β * dest + α * permutedims(op.(src), perm)`. -Flat-permutation convenience overload. Forwards to the bipartitioned version +This is the single materialization primitive for `LinearBroadcasted` types. +Downstream array types should implement `bipermutedimsopadd!` for the +bipartitioned permutation version; this flat-permutation overload forwards to it with `perm_domain = ()`. """ function permutedimsopadd!( dest::AbstractArray, op, src::AbstractArray, perm, α::Number, β::Number ) - return permutedimsopadd!(dest, op, src, perm, (), α, β) + return bipermutedimsopadd!(dest, op, src, perm, (), α, β) end # ---------------------------------------------------------------------------- # From cd9fb877b55490ea536cf0e6be0b1e04b4c5ae77 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 16 Apr 2026 10:45:00 -0400 Subject: [PATCH 17/26] Bump minor version to v0.9.0 (breaking: bipermutedimsopadd! is now primary) --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a7089bb..67f8556 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -version = "0.8.1" +version = "0.9.0" authors = ["ITensor developers and contributors"] [workspace] From f262919104c3e0b2c7dbe6b426dfad5ff4a71317 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 16 Apr 2026 11:28:52 -0400 Subject: [PATCH 18/26] Bump TensorAlgebra compat in test/ to 0.9 --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 5c0c342..d681c86 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -36,7 +36,7 @@ Random = "1.10" SafeTestsets = "0.1" StableRNGs = "1.0.2" Suppressor = "0.2" -TensorAlgebra = "0.8" +TensorAlgebra = "0.9" TensorOperations = "5.1.4" Test = "1.10" TestExtras = "0.3.1" From ea867afe0a313cbeccb961bf6af0c24b4905b8bd Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 16 Apr 2026 11:30:45 -0400 Subject: [PATCH 19/26] Revert to v0.8.1 (breaking only for GradedArrays, not ecosystem-wide) --- Project.toml | 2 +- test/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 67f8556..a7089bb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -version = "0.9.0" +version = "0.8.1" authors = ["ITensor developers and contributors"] [workspace] diff --git a/test/Project.toml b/test/Project.toml index d681c86..3f95e41 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -36,7 +36,7 @@ Random = "1.10" SafeTestsets = "0.1" StableRNGs = "1.0.2" Suppressor = "0.2" -TensorAlgebra = "0.9" +TensorAlgebra = "0.8.1" TensorOperations = "5.1.4" Test = "1.10" TestExtras = "0.3.1" From 296ca292e87333ce0c5c32d470017745b63dfbfd Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 16 Apr 2026 11:52:16 -0400 Subject: [PATCH 20/26] Add generic check_input for bipermutedimsopadd!, fix docstring attachment --- src/permutedimsadd.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/permutedimsadd.jl b/src/permutedimsadd.jl index 5e2c304..9e46b8d 100644 --- a/src/permutedimsadd.jl +++ b/src/permutedimsadd.jl @@ -13,6 +13,18 @@ end # bipermutedimsopadd! — the primary materialization primitive # ---------------------------------------------------------------------------- # +function check_input( + ::typeof(bipermutedimsopadd!), dest::AbstractArray, src::AbstractArray, + perm_codomain, perm_domain + ) + perm = (perm_codomain..., perm_domain...) + ndims(dest) == length(perm) || + throw(DimensionMismatch("destination ndims does not match permutation length")) + axes(dest) == ntuple(d -> axes(src, perm[d]), ndims(dest)) || + throw(DimensionMismatch("destination axes do not match permuted source axes")) + return nothing +end + """ bipermutedimsopadd!(dest, op, src, perm_codomain, perm_domain, α, β) @@ -34,6 +46,7 @@ function bipermutedimsopadd!( α::Number, β::Number ) perm = (perm_codomain..., perm_domain...) + check_input(bipermutedimsopadd!, dest, src, perm_codomain, perm_domain) # TODO: Remove this 0-dimensional special case once GradedArray is its own type # (not an alias for BlockSparseArray), so the GradedArray overload catches the From a3ab7c75bd75efeae53200a515cb345c567b2587 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 16 Apr 2026 11:59:04 -0400 Subject: [PATCH 21/26] Add forward declaration for bipermutedimsopadd! --- src/permutedimsadd.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/permutedimsadd.jl b/src/permutedimsadd.jl index 9e46b8d..eb21f0a 100644 --- a/src/permutedimsadd.jl +++ b/src/permutedimsadd.jl @@ -13,6 +13,8 @@ end # bipermutedimsopadd! — the primary materialization primitive # ---------------------------------------------------------------------------- # +function bipermutedimsopadd! end + function check_input( ::typeof(bipermutedimsopadd!), dest::AbstractArray, src::AbstractArray, perm_codomain, perm_domain From 092d71e8a43ee5e6d3f7f2ef2ec0c39041399cde Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 16 Apr 2026 14:05:16 -0400 Subject: [PATCH 22/26] Remove unnecessary Complex guards from conj tests --- test/test_basics.jl | 158 +++++++++++++++++++++----------------------- 1 file changed, 75 insertions(+), 83 deletions(-) diff --git a/test/test_basics.jl b/test/test_basics.jl index 96ff68b..ecc27b7 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -103,16 +103,14 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) m_ref = matricize(a, (2, 3), (1,)) @test m ≈ m_ref - # conj op (only for Complex eltypes) - if elt <: Complex - m = TensorAlgebra.matricizeop(conj, a, (1,), (2, 3)) - m_ref = conj.(matricize(a, (1,), (2, 3))) - @test m ≈ m_ref - - m = TensorAlgebra.matricizeop(conj, a, (3, 1), (2,)) - m_ref = conj.(matricize(a, (3, 1), (2,))) - @test m ≈ m_ref - end + # conj op + m = TensorAlgebra.matricizeop(conj, a, (1,), (2, 3)) + m_ref = conj.(matricize(a, (1,), (2, 3))) + @test m ≈ m_ref + + m = TensorAlgebra.matricizeop(conj, a, (3, 1), (2,)) + m_ref = conj.(matricize(a, (3, 1), (2,))) + @test m ≈ m_ref # general op m = TensorAlgebra.matricizeop(abs, a, (1,), (2, 3)) @@ -317,79 +315,73 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) contractadd!(a_dest_ref, labels_dest, a1, labels1, a2, labels2, α, β) @test a_dest ≈ a_dest_ref - # conj on first input (only for Complex elt1) - if elt1 <: Complex - a_dest = copy(a_dest_init) - TensorAlgebra.contractopadd!( - a_dest, labels_dest, - conj, a1, labels1, - identity, a2, labels2, - α, β - ) - a_dest_ref = copy(a_dest_init) - contractadd!(a_dest_ref, labels_dest, conj.(a1), labels1, a2, labels2, α, β) - @test a_dest ≈ a_dest_ref - - # compare against TensorOperations backend - a_dest_to = copy(a_dest_init) - TensorAlgebra.contractopadd!( - a_dest_to, labels_dest, - conj, a1, labels1, - identity, a2, labels2, - α, β; alg = alg_tensoroperations - ) - @test a_dest ≈ a_dest_to - end - - # conj on second input (only for Complex elt2) - if elt2 <: Complex - a_dest = copy(a_dest_init) - TensorAlgebra.contractopadd!( - a_dest, labels_dest, - identity, a1, labels1, - conj, a2, labels2, - α, β - ) - a_dest_ref = copy(a_dest_init) - contractadd!(a_dest_ref, labels_dest, a1, labels1, conj.(a2), labels2, α, β) - @test a_dest ≈ a_dest_ref - - # compare against TensorOperations backend - a_dest_to = copy(a_dest_init) - TensorAlgebra.contractopadd!( - a_dest_to, labels_dest, - identity, a1, labels1, - conj, a2, labels2, - α, β; alg = alg_tensoroperations - ) - @test a_dest ≈ a_dest_to - end - - # conj on both inputs (only for Complex elt1 and elt2) - if elt1 <: Complex && elt2 <: Complex - a_dest = copy(a_dest_init) - TensorAlgebra.contractopadd!( - a_dest, labels_dest, - conj, a1, labels1, - conj, a2, labels2, - α, β - ) - a_dest_ref = copy(a_dest_init) - contractadd!( - a_dest_ref, labels_dest, conj.(a1), labels1, conj.(a2), labels2, α, β - ) - @test a_dest ≈ a_dest_ref - - # compare against TensorOperations backend - a_dest_to = copy(a_dest_init) - TensorAlgebra.contractopadd!( - a_dest_to, labels_dest, - conj, a1, labels1, - conj, a2, labels2, - α, β; alg = alg_tensoroperations - ) - @test a_dest ≈ a_dest_to - end + # conj on first input + a_dest = copy(a_dest_init) + TensorAlgebra.contractopadd!( + a_dest, labels_dest, + conj, a1, labels1, + identity, a2, labels2, + α, β + ) + a_dest_ref = copy(a_dest_init) + contractadd!(a_dest_ref, labels_dest, conj.(a1), labels1, a2, labels2, α, β) + @test a_dest ≈ a_dest_ref + + # compare against TensorOperations backend + a_dest_to = copy(a_dest_init) + TensorAlgebra.contractopadd!( + a_dest_to, labels_dest, + conj, a1, labels1, + identity, a2, labels2, + α, β; alg = alg_tensoroperations + ) + @test a_dest ≈ a_dest_to + + # conj on second input + a_dest = copy(a_dest_init) + TensorAlgebra.contractopadd!( + a_dest, labels_dest, + identity, a1, labels1, + conj, a2, labels2, + α, β + ) + a_dest_ref = copy(a_dest_init) + contractadd!(a_dest_ref, labels_dest, a1, labels1, conj.(a2), labels2, α, β) + @test a_dest ≈ a_dest_ref + + # compare against TensorOperations backend + a_dest_to = copy(a_dest_init) + TensorAlgebra.contractopadd!( + a_dest_to, labels_dest, + identity, a1, labels1, + conj, a2, labels2, + α, β; alg = alg_tensoroperations + ) + @test a_dest ≈ a_dest_to + + # conj on both inputs + a_dest = copy(a_dest_init) + TensorAlgebra.contractopadd!( + a_dest, labels_dest, + conj, a1, labels1, + conj, a2, labels2, + α, β + ) + a_dest_ref = copy(a_dest_init) + contractadd!( + a_dest_ref, labels_dest, conj.(a1), labels1, conj.(a2), labels2, α, β + ) + @test a_dest ≈ a_dest_ref + + # compare against TensorOperations backend + a_dest_to = copy(a_dest_init) + TensorAlgebra.contractopadd!( + a_dest_to, labels_dest, + conj, a1, labels1, + conj, a2, labels2, + α, β; alg = alg_tensoroperations + ) + @test a_dest ≈ a_dest_to end end @testset "scalar contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, From 8b5c0f310990f641cc9ad70954c0e126d3d9ce4e Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 16 Apr 2026 14:08:12 -0400 Subject: [PATCH 23/26] Drop aliasing language from permutedimsop docstring --- src/matricize.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/matricize.jl b/src/matricize.jl index 8ce0097..3cd5bc4 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -133,8 +133,7 @@ end permutedimsop(op, src, perm_codomain, perm_domain) Non-mutating version of `bipermutedimsopadd!`: returns -`op.(permutedims(src, (perm_codomain..., perm_domain...)))`. Has "maybe alias" -semantics — the result may be a view/wrapper aliasing `src` or a fresh copy. +`op.(permutedims(src, (perm_codomain..., perm_domain...)))`. """ function permutedimsop(op, src::AbstractArray, perm_codomain, perm_domain) dest = allocate_output(permutedimsop, op, src, perm_codomain, perm_domain) From cab0306828ff14b5ced59bbc53ecc7fe269c14d1 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 16 Apr 2026 14:22:35 -0400 Subject: [PATCH 24/26] Bump to v0.9.0 (breaking: bipermutedimsopadd! is now primary permutation primitive) --- Project.toml | 2 +- test/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index a7089bb..67f8556 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -version = "0.8.1" +version = "0.9.0" authors = ["ITensor developers and contributors"] [workspace] diff --git a/test/Project.toml b/test/Project.toml index 3f95e41..d681c86 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -36,7 +36,7 @@ Random = "1.10" SafeTestsets = "0.1" StableRNGs = "1.0.2" Suppressor = "0.2" -TensorAlgebra = "0.8.1" +TensorAlgebra = "0.9" TensorOperations = "5.1.4" Test = "1.10" TestExtras = "0.3.1" From d6acabf662aa8e080deb8dbe2055ad053cd0fa41 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 16 Apr 2026 14:45:07 -0400 Subject: [PATCH 25/26] Remove TensorAlgebra compat from test/ (workspace + [sources] handles it) --- test/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index d681c86..b2c7ae4 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -36,7 +36,6 @@ Random = "1.10" SafeTestsets = "0.1" StableRNGs = "1.0.2" Suppressor = "0.2" -TensorAlgebra = "0.9" TensorOperations = "5.1.4" Test = "1.10" TestExtras = "0.3.1" From 3e3352358f4c273f3336415c3fe8f785a85de539 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 16 Apr 2026 14:48:41 -0400 Subject: [PATCH 26/26] Bump TensorAlgebra compat to 0.9 in all workspace subdirs (docs, examples, test) --- docs/Project.toml | 2 +- examples/Project.toml | 2 +- test/Project.toml | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 20650d9..85fdaac 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -11,4 +11,4 @@ path = ".." Documenter = "1.8.1" ITensorFormatter = "0.2.27" Literate = "2.20.1" -TensorAlgebra = "0.8" +TensorAlgebra = "0.9" diff --git a/examples/Project.toml b/examples/Project.toml index a800625..87b72a8 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -5,4 +5,4 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" path = ".." [compat] -TensorAlgebra = "0.8" +TensorAlgebra = "0.9" diff --git a/test/Project.toml b/test/Project.toml index b2c7ae4..d681c86 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -36,6 +36,7 @@ Random = "1.10" SafeTestsets = "0.1" StableRNGs = "1.0.2" Suppressor = "0.2" +TensorAlgebra = "0.9" TensorOperations = "5.1.4" Test = "1.10" TestExtras = "0.3.1"