diff --git a/Project.toml b/Project.toml index 3c65eb0d..67f85561 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -version = "0.8.0" +version = "0.9.0" authors = ["ITensor developers and contributors"] [workspace] diff --git a/docs/Project.toml b/docs/Project.toml index 20650d9d..85fdaac1 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -11,4 +11,4 @@ path = ".." Documenter = "1.8.1" ITensorFormatter = "0.2.27" Literate = "2.20.1" -TensorAlgebra = "0.8" +TensorAlgebra = "0.9" diff --git a/examples/Project.toml b/examples/Project.toml index a8006256..87b72a8d 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -5,4 +5,4 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" path = ".." [compat] -TensorAlgebra = "0.8" +TensorAlgebra = "0.9" diff --git a/ext/TensorAlgebraTensorOperationsExt/TensorAlgebraTensorOperationsExt.jl b/ext/TensorAlgebraTensorOperationsExt/TensorAlgebraTensorOperationsExt.jl index 8e3e9421..778b7f58 100644 --- a/ext/TensorAlgebraTensorOperationsExt/TensorAlgebraTensorOperationsExt.jl +++ b/ext/TensorAlgebraTensorOperationsExt/TensorAlgebraTensorOperationsExt.jl @@ -52,35 +52,22 @@ function TA.contract( end # in-place -function TA.contractadd!( +function TA.contractopadd!( algorithm::TensorOperationsAlgorithm, a_dest::AbstractArray, perm_dest_codomain, perm_dest_domain, - a1::AbstractArray, perm1_codomain, perm1_domain, - a2::AbstractArray, perm2_codomain, perm2_domain, + op1, a1::AbstractArray, perm1_codomain, perm1_domain, + op2, a2::AbstractArray, perm2_codomain, perm2_domain, α::Number, β::Number ) permblocks1 = Tuple.((perm1_codomain, perm1_domain)) permblocks2 = Tuple.((perm2_codomain, perm2_domain)) permblocks_dest = Tuple.((perm_dest_codomain, perm_dest_domain)) - conj1, conj2 = false, false - return TO.tensorcontract!( - a_dest, a1, permblocks1, conj1, a2, permblocks2, conj2, - permblocks_dest, α, β, algorithm.backend - ) -end - -function TA.contractadd!( - algorithm::TensorOperationsAlgorithm, - a_dest::AbstractArray, labels_dest, - a1::AbstractArray, labels1, - a2::AbstractArray, labels2, - α::Number, β::Number - ) - permblocks1, permblocks2, permblocks_dest = - TO.contract_indices(labels1, labels2, labels_dest) - conj1, conj2 = false, false + conj1 = op1 === conj + conj2 = op2 === conj + a1′ = (op1 === identity || op1 === conj) ? a1 : op1.(a1) + a2′ = (op2 === identity || op2 === conj) ? a2 : op2.(a2) return TO.tensorcontract!( - a_dest, a1, permblocks1, conj1, a2, permblocks2, conj2, + a_dest, a1′, permblocks1, conj1, a2′, permblocks2, conj2, permblocks_dest, α, β, algorithm.backend ) end @@ -96,14 +83,13 @@ function TO.tensorcontract!( backend::TA.ContractAlgorithm, allocator ) - # TODO: FIXME: Use `conjed` to do the conjugation lazily. - a1′ = conj1 ? conj(a1) : a1 - a2′ = conj2 ? conj(a2) : a2 - return TA.contractadd!( + op1 = conj1 ? conj : identity + op2 = conj2 ? conj : identity + return TA.contractopadd!( backend, a_dest, permblocks_dest..., - a1′, permblocks1..., - a2′, permblocks2..., + op1, a1, permblocks1..., + op2, a2, permblocks2..., α, β ) end diff --git a/src/TensorAlgebra.jl b/src/TensorAlgebra.jl index a1c966a0..5202364f 100644 --- a/src/TensorAlgebra.jl +++ b/src/TensorAlgebra.jl @@ -3,6 +3,10 @@ module TensorAlgebra export contract, contract!, eigen, eigvals, factorize, left_null, left_orth, left_polar, lq, qr, right_null, right_orth, right_polar, orth, polar, svd, svdvals +if VERSION >= v"1.11.0-DEV.469" + eval(Meta.parse("public contractopadd!, matricizeop")) +end + include("MatrixAlgebra.jl") include("blockedtuple.jl") include("blockedpermutation.jl") diff --git a/src/contract/contract.jl b/src/contract/contract.jl index 87231b95..6fe5815d 100644 --- a/src/contract/contract.jl +++ b/src/contract/contract.jl @@ -87,19 +87,48 @@ function contractadd!( α::Number, β::Number; kwargs... ) - biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2) - return contractadd!( - a_dest, blocks(biperm_dest)..., - a1, blocks(biperm1)..., - a2, blocks(biperm2)..., - α, β; kwargs... + return contractopadd!( + a_dest, labels_dest, identity, a1, labels1, identity, a2, labels2, α, β; kwargs... ) end +# contractadd! (bipartitioned permutations) function contractadd!( a_dest::AbstractArray, perm_dest_codomain, perm_dest_domain, a1::AbstractArray, perm1_codomain, perm1_domain, a2::AbstractArray, perm2_codomain, perm2_domain, α::Number, β::Number; + kwargs... + ) + return contractopadd!( + a_dest, perm_dest_codomain, perm_dest_domain, + identity, a1, perm1_codomain, perm1_domain, + identity, a2, perm2_codomain, perm2_domain, + α, β; kwargs... + ) +end + +# contractopadd! (labels) +function contractopadd!( + a_dest::AbstractArray, labels_dest, + op1, a1::AbstractArray, labels1, + op2, a2::AbstractArray, labels2, + α::Number, β::Number; + kwargs... + ) + biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2) + return contractopadd!( + a_dest, blocks(biperm_dest)..., + op1, a1, blocks(biperm1)..., + op2, a2, blocks(biperm2)..., + α, β; kwargs... + ) +end +# contractopadd! (bipartitioned permutations, algorithm selection) +function contractopadd!( + a_dest::AbstractArray, perm_dest_codomain, perm_dest_domain, + op1, a1::AbstractArray, perm1_codomain, perm1_domain, + op2, a2::AbstractArray, perm2_codomain, perm2_domain, + α::Number, β::Number; alg = DefaultContractAlgorithm(), kwargs... ) check_input( @@ -109,38 +138,38 @@ function contractadd!( a2, perm2_codomain, perm2_domain ) algorithm = select_contract_algorithm(alg, a1, a2; kwargs...) - return contractadd!( + return contractopadd!( algorithm, a_dest, perm_dest_codomain, perm_dest_domain, - a1, perm1_codomain, perm1_domain, - a2, perm2_codomain, perm2_domain, + op1, a1, perm1_codomain, perm1_domain, + op2, a2, perm2_codomain, perm2_domain, α, β ) end -# contractadd! (dispatched on the algorithm, bipartitioned permutations) +# contractopadd! (dispatched on the algorithm, bipartitioned permutations) # Required interface if not using matricized contraction -function contractadd!( +function contractopadd!( algorithm::ContractAlgorithm, a_dest::AbstractArray, perm_dest_codomain, perm_dest_domain, - a1::AbstractArray, perm1_codomain, perm1_domain, - a2::AbstractArray, perm2_codomain, perm2_domain, + op1, a1::AbstractArray, perm1_codomain, perm1_domain, + op2, a2::AbstractArray, perm2_codomain, perm2_domain, α::Number, β::Number ) return throw( MethodError( - contractadd!, + contractopadd!, ( algorithm, a_dest, perm_dest_codomain, perm_dest_domain, - a1, perm1_codomain, perm1_domain, - a2, perm2_codomain, perm2_domain, + op1, a1, perm1_codomain, perm1_domain, + op2, a2, perm2_codomain, perm2_domain, α, β, ) ) ) end -# BlockPermutation versions of contract[add][!] +# BlockPermutation versions of contract[opadd][!] function contract( a1::AbstractArray, biperm1::AbstractBlockPermutation{2}, a2::AbstractArray, biperm2::AbstractBlockPermutation{2}; @@ -187,18 +216,16 @@ function contractadd!( α, β; kwargs... ) end -function contractadd!( - algorithm::ContractAlgorithm, +function contractopadd!( a_dest::AbstractArray, biperm_dest::AbstractBlockPermutation{2}, - a1::AbstractArray, biperm1::AbstractBlockPermutation{2}, - a2::AbstractArray, biperm2::AbstractBlockPermutation{2}, - α::Number, β::Number + op1, a1::AbstractArray, biperm1::AbstractBlockPermutation{2}, + op2, a2::AbstractArray, biperm2::AbstractBlockPermutation{2}, + α::Number, β::Number; kwargs... ) - return contractadd!( - algorithm, + return contractopadd!( a_dest, blocks(biperm_dest)..., - a1, blocks(biperm1)..., - a2, blocks(biperm2)..., - α, β + op1, a1, blocks(biperm1)..., + op2, a2, blocks(biperm2)..., + α, β; kwargs... ) end diff --git a/src/contract/contract_matricize.jl b/src/contract/contract_matricize.jl index 701dd4f2..ed31bcf6 100644 --- a/src/contract/contract_matricize.jl +++ b/src/contract/contract_matricize.jl @@ -1,26 +1,26 @@ using LinearAlgebra: mul! -function contractadd!( +function contractopadd!( algorithm::Matricize, a_dest::AbstractArray, biperm_dest_codomain, biperm_dest_domain, - a1::AbstractArray, biperm1_codomain, biperm1_domain, - a2::AbstractArray, biperm2_codomain, biperm2_domain, + op1, a1::AbstractArray, biperm1_codomain, biperm1_domain, + op2, a2::AbstractArray, biperm2_codomain, biperm2_domain, α::Number, β::Number ) - return contractadd!_matricize( + return contractopadd!_matricize( algorithm, a_dest, biperm_dest_codomain, biperm_dest_domain, - a1, biperm1_codomain, biperm1_domain, - a2, biperm2_codomain, biperm2_domain, + op1, a1, biperm1_codomain, biperm1_domain, + op2, a2, biperm2_codomain, biperm2_domain, α, β ) end -function contractadd!_matricize( +function contractopadd!_matricize( algorithm::Matricize, a_dest::AbstractArray, perm_dest_codomain, perm_dest_domain, - a1::AbstractArray, perm1_codomain, perm1_domain, - a2::AbstractArray, perm2_codomain, perm2_domain, + op1, a1::AbstractArray, perm1_codomain, perm1_domain, + op2, a2::AbstractArray, perm2_codomain, perm2_domain, α::Number, β::Number ) perm_dest = (perm_dest_codomain..., perm_dest_domain...) @@ -32,8 +32,8 @@ function contractadd!_matricize( a1, perm1_codomain, perm1_domain, a2, perm2_codomain, perm2_domain ) - a1_mat = matricize(algorithm.fusion_style, a1, perm1_codomain, perm1_domain) - a2_mat = matricize(algorithm.fusion_style, a2, perm2_codomain, perm2_domain) + a1_mat = matricizeop(algorithm.fusion_style, op1, a1, perm1_codomain, perm1_domain) + a2_mat = matricizeop(algorithm.fusion_style, op2, a2, perm2_codomain, perm2_domain) a_dest_mat = a1_mat * a2_mat unmatricizeadd!( algorithm.fusion_style, a_dest, a_dest_mat, invperm_codomain, invperm_domain, α, β diff --git a/src/matricize.jl b/src/matricize.jl index dc2c6de1..3cd5bc47 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -123,16 +123,40 @@ function matricize_axes(a::AbstractArray, ndims_codomain::Val) return matricize_axes(FusionStyle(a), a, ndims_codomain) end +# Default similar with bipartitioned axes: flatten to a plain tuple of axes. +# Downstream types (e.g., FusionTensor) can override to preserve bipartition. +function Base.similar(a::AbstractArray, T::Type, axes::BlockedTuple{2}) + return similar(a, T, Tuple(axes)) +end + +""" + permutedimsop(op, src, perm_codomain, perm_domain) + +Non-mutating version of `bipermutedimsopadd!`: returns +`op.(permutedims(src, (perm_codomain..., perm_domain...)))`. +""" +function permutedimsop(op, src::AbstractArray, perm_codomain, perm_domain) + dest = allocate_output(permutedimsop, op, src, perm_codomain, perm_domain) + return bipermutedimsopadd!(dest, op, src, perm_codomain, perm_domain, true, false) +end + +function allocate_output(::typeof(permutedimsop), op, src::AbstractArray, perm_co, perm_do) + T = Base.promote_op(op, eltype(src)) + axes_co = map(i -> axes(src, i), perm_co) + axes_do = map(i -> axes(src, i), perm_do) + return similar(src, T, tuplemortar((axes_co, axes_do))) +end + # Inner version takes a list of sub-permutations, overload this one if needed. # TODO: Remove _permutedims once support for Julia 1.10 is dropped # define permutedims with a BlockedPermuation. Default is to flatten it. # TODO: Deprecate `permuteblockeddims` in favor of `bipermutedims`. # Keeping it here for backwards compatibility. function bipermutedims(a::AbstractArray, perm1, perm2) - return _permutedims(a, (perm1..., perm2...)) + return permutedimsop(identity, a, perm1, perm2) end function bipermutedims!(a_dest::AbstractArray, a_src::AbstractArray, perm1, perm2) - return _permutedims!(a_dest, a_src, (perm1..., perm2...)) + return bipermutedimsopadd!(a_dest, identity, a_src, perm1, perm2, true, false) end function bipermutedims(a::AbstractArray, biperm::AbstractBlockPermutation{2}) return bipermutedims(a, blocks(biperm)...) @@ -165,15 +189,14 @@ function matricize( ) return matricize(FusionStyle(a), a, perm_codomain, perm_domain) end -# This is a more advanced version to overload where the permutation is actually performed. +# Thin wrapper around `matricizeop` with identity op — the actual matricization logic +# (and the fusion-style overload point for folding ops into matricization) lives in +# `matricizeop`. function matricize( style::FusionStyle, a::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}} ) - ndims(a) == length(perm_codomain) + length(perm_domain) || - throw(ArgumentError("Invalid bipermutation")) - a_perm = bipermutedims(a, perm_codomain, perm_domain) - return matricize(style, a_perm, Val(length(perm_codomain))) + return matricizeop(style, identity, a, perm_codomain, perm_domain) end # Process inputs such as `EllipsisNotation.Ellipsis`. @@ -218,6 +241,39 @@ function matricize( return matricize(style, a, blocks(biperm_dest)...) end +# ==================================== matricizeop ======================================= + +""" + matricizeop(op, a, perm_codomain, perm_domain) + +Matricize `a` with element-wise operation `op` folded in. Returns a matrix representing +`op.(matricize(a, perm_codomain, perm_domain))`. + +Has "maybe alias" semantics: the result may be a view/wrapper aliasing `a` or a fresh +copy, depending on the fusion style and array type. The caller should treat the result +as read-only. +""" +function matricizeop(op, a::AbstractArray, perm_codomain, perm_domain) + return matricizeop(FusionStyle(a), op, a, perm_codomain, perm_domain) +end +function matricizeop( + style::FusionStyle, op, a::AbstractArray, perm_codomain, perm_domain + ) + return matricizeop(style, op, a, to_permblocks(a, (perm_codomain, perm_domain))...) +end +# This is the primary function that should be overloaded for new fusion styles to fold +# ops into matricization (e.g., fuse `conj` into the permutation copy, or use lazy +# wrappers like StridedView with op metadata for zero-copy). +function matricizeop( + style::FusionStyle, op, a::AbstractArray, + perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}} + ) + ndims(a) == length(perm_codomain) + length(perm_domain) || + throw(ArgumentError("Invalid bipermutation")) + a_perm_op = permutedimsop(op, a, perm_codomain, perm_domain) + return matricize(style, a_perm_op, Val(length(perm_codomain))) +end + # ==================================== unmatricize ======================================= # This is the primary function that should be overloaded for new fusion styles. function unmatricize( diff --git a/src/permutedimsadd.jl b/src/permutedimsadd.jl index dac9490a..eb21f0a9 100644 --- a/src/permutedimsadd.jl +++ b/src/permutedimsadd.jl @@ -10,27 +10,49 @@ function maybestrided(as::AbstractArray...) end # ---------------------------------------------------------------------------- # -# permutedimsopadd! — the single materialization primitive +# bipermutedimsopadd! — the primary materialization primitive # ---------------------------------------------------------------------------- # +function bipermutedimsopadd! end + +function check_input( + ::typeof(bipermutedimsopadd!), dest::AbstractArray, src::AbstractArray, + perm_codomain, perm_domain + ) + perm = (perm_codomain..., perm_domain...) + ndims(dest) == length(perm) || + throw(DimensionMismatch("destination ndims does not match permutation length")) + axes(dest) == ntuple(d -> axes(src, perm[d]), ndims(dest)) || + throw(DimensionMismatch("destination axes do not match permuted source axes")) + return nothing +end + """ - permutedimsopadd!(dest, op, src, perm, α, β) + bipermutedimsopadd!(dest, op, src, perm_codomain, perm_domain, α, β) -`dest = β * dest + α * permutedims(op.(src), perm)`. +`dest = β * dest + α * permutedims(op.(src), (perm_codomain..., perm_domain...))`. -This is the single materialization primitive for `LinearBroadcasted` types. -Downstream array types should implement this function. The `op` is an element-wise -linear map (e.g., `identity`, `conj`, `adjoint`, `transpose`, `Float32`). +This is the primary overload point for downstream array types that want to +implement op-aware bipartitioned permutation + accumulation (e.g., fuse `conj` +into the copy, or use lazy wrappers like `StridedView` with op metadata). -The default implementation applies `op` element-wise, permutes, then accumulates -via broadcasting with Strided.jl optimization when possible. +The `op` is an element-wise linear map (e.g., `identity`, `conj`). + +The default implementation flattens the bipartitioned permutation, applies `op` +element-wise, permutes, then accumulates via broadcasting with Strided.jl +optimization when possible. """ -function permutedimsopadd!( - dest::AbstractArray, op, src::AbstractArray, perm, α::Number, β::Number +function bipermutedimsopadd!( + dest::AbstractArray, op, src::AbstractArray, + perm_codomain, perm_domain, + α::Number, β::Number ) + perm = (perm_codomain..., perm_domain...) + check_input(bipermutedimsopadd!, dest, src, perm_codomain, perm_domain) + # TODO: Remove this 0-dimensional special case once GradedArray is its own type - # (not an alias for BlockSparseArray), so the GradedArray permutedimsopadd! overload - # catches the 0-dimensional contraction result. + # (not an alias for BlockSparseArray), so the GradedArray overload catches the + # 0-dimensional contraction result. if iszero(ndims(dest)) dest[] = β * dest[] + α * op(src[]) return dest @@ -58,6 +80,26 @@ function permutedimsopadd!( return dest end +# ---------------------------------------------------------------------------- # +# permutedimsopadd! — flat-permutation interface +# ---------------------------------------------------------------------------- # + +""" + permutedimsopadd!(dest, op, src, perm, α, β) + +`dest = β * dest + α * permutedims(op.(src), perm)`. + +This is the single materialization primitive for `LinearBroadcasted` types. +Downstream array types should implement `bipermutedimsopadd!` for the +bipartitioned permutation version; this flat-permutation overload forwards to it +with `perm_domain = ()`. +""" +function permutedimsopadd!( + dest::AbstractArray, op, src::AbstractArray, perm, α::Number, β::Number + ) + return bipermutedimsopadd!(dest, op, src, perm, (), α, β) +end + # ---------------------------------------------------------------------------- # # Convenience functions that lower to permutedimsopadd! # ---------------------------------------------------------------------------- # diff --git a/test/Project.toml b/test/Project.toml index 5c0c3426..d681c867 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -36,7 +36,7 @@ Random = "1.10" SafeTestsets = "0.1" StableRNGs = "1.0.2" Suppressor = "0.2" -TensorAlgebra = "0.8" +TensorAlgebra = "0.9" TensorOperations = "5.1.4" Test = "1.10" TestExtras = "0.3.1" diff --git a/test/test_basics.jl b/test/test_basics.jl index e48be7ff..ecc27b71 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,3 +1,4 @@ +import TensorAlgebra using EllipsisNotation: var".." using StableRNGs: StableRNG using TensorAlgebra: BlockedTuple, ContractAlgorithm, bipermutedims, bipermutedims!, @@ -85,6 +86,38 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test a_fused ≈ ones(elt, 1, 1) end + @testset "matricizeop (eltype=$elt)" for elt in elts + rng = StableRNG(123) + a = randn(rng, elt, 2, 3, 4) + + # identity op: should match matricize exactly + m = TensorAlgebra.matricizeop(identity, a, (1,), (2, 3)) + m_ref = matricize(a, (1,), (2, 3)) + @test m ≈ m_ref + + m = TensorAlgebra.matricizeop(identity, a, (3, 1), (2,)) + m_ref = matricize(a, (3, 1), (2,)) + @test m ≈ m_ref + + m = TensorAlgebra.matricizeop(identity, a, (2, 3), (1,)) + m_ref = matricize(a, (2, 3), (1,)) + @test m ≈ m_ref + + # conj op + m = TensorAlgebra.matricizeop(conj, a, (1,), (2, 3)) + m_ref = conj.(matricize(a, (1,), (2, 3))) + @test m ≈ m_ref + + m = TensorAlgebra.matricizeop(conj, a, (3, 1), (2,)) + m_ref = conj.(matricize(a, (3, 1), (2,))) + @test m ≈ m_ref + + # general op + m = TensorAlgebra.matricizeop(abs, a, (1,), (2, 3)) + m_ref = abs.(matricize(a, (1,), (2, 3))) + @test m ≈ m_ref + end + @testset "unmatricize (eltype=$elt)" for elt in elts a0 = randn(elt, 2, 3, 4, 5) axes0 = axes(a0) @@ -248,6 +281,109 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 4, 2, 3) ) end + @testset "contractopadd! (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts + elt_dest = promote_type(elt1, elt2) + dims = (2, 3, 4, 5, 6, 7, 8, 9, 10) + labels = (:a, :b, :c, :d, :e, :f, :g, :h, :i) + rng = StableRNG(123) + for (d1s, d2s, d_dests) in ( + ((1, 2), (2, 3), (1, 3)), + ((1, 2), (2, 3), (3, 1)), + ((1, 2, 3), (2, 3, 4), (1, 4)), + ((3, 2, 1), (4, 2, 3), (4, 1)), + ((1, 2, 3), (3, 4), (2, 4, 1)), + ) + a1 = randn(rng, elt1, map(i -> dims[i], d1s)) + labels1 = map(i -> labels[i], d1s) + a2 = randn(rng, elt2, map(i -> dims[i], d2s)) + labels2 = map(i -> labels[i], d2s) + labels_dest = map(i -> labels[i], d_dests) + + α = elt_dest(1.2) + β = elt_dest(2.4) + a_dest_init = randn(rng, elt_dest, map(i -> dims[i], d_dests)) + + # identity ops should match contractadd! + a_dest = copy(a_dest_init) + TensorAlgebra.contractopadd!( + a_dest, labels_dest, + identity, a1, labels1, + identity, a2, labels2, + α, β + ) + a_dest_ref = copy(a_dest_init) + contractadd!(a_dest_ref, labels_dest, a1, labels1, a2, labels2, α, β) + @test a_dest ≈ a_dest_ref + + # conj on first input + a_dest = copy(a_dest_init) + TensorAlgebra.contractopadd!( + a_dest, labels_dest, + conj, a1, labels1, + identity, a2, labels2, + α, β + ) + a_dest_ref = copy(a_dest_init) + contractadd!(a_dest_ref, labels_dest, conj.(a1), labels1, a2, labels2, α, β) + @test a_dest ≈ a_dest_ref + + # compare against TensorOperations backend + a_dest_to = copy(a_dest_init) + TensorAlgebra.contractopadd!( + a_dest_to, labels_dest, + conj, a1, labels1, + identity, a2, labels2, + α, β; alg = alg_tensoroperations + ) + @test a_dest ≈ a_dest_to + + # conj on second input + a_dest = copy(a_dest_init) + TensorAlgebra.contractopadd!( + a_dest, labels_dest, + identity, a1, labels1, + conj, a2, labels2, + α, β + ) + a_dest_ref = copy(a_dest_init) + contractadd!(a_dest_ref, labels_dest, a1, labels1, conj.(a2), labels2, α, β) + @test a_dest ≈ a_dest_ref + + # compare against TensorOperations backend + a_dest_to = copy(a_dest_init) + TensorAlgebra.contractopadd!( + a_dest_to, labels_dest, + identity, a1, labels1, + conj, a2, labels2, + α, β; alg = alg_tensoroperations + ) + @test a_dest ≈ a_dest_to + + # conj on both inputs + a_dest = copy(a_dest_init) + TensorAlgebra.contractopadd!( + a_dest, labels_dest, + conj, a1, labels1, + conj, a2, labels2, + α, β + ) + a_dest_ref = copy(a_dest_init) + contractadd!( + a_dest_ref, labels_dest, conj.(a1), labels1, conj.(a2), labels2, α, β + ) + @test a_dest ≈ a_dest_ref + + # compare against TensorOperations backend + a_dest_to = copy(a_dest_init) + TensorAlgebra.contractopadd!( + a_dest_to, labels_dest, + conj, a1, labels1, + conj, a2, labels2, + α, β; alg = alg_tensoroperations + ) + @test a_dest ≈ a_dest_to + end + end @testset "scalar contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts diff --git a/test/test_exports.jl b/test/test_exports.jl index f7f9bda9..0fb0d9b4 100644 --- a/test/test_exports.jl +++ b/test/test_exports.jl @@ -22,6 +22,10 @@ using Test: @test, @testset :svd, :svdvals, ] + # `public` (Julia 1.11+) adds names to `names()`; include them on 1.11+. + if VERSION >= v"1.11.0-DEV.469" + append!(exports, [:contractopadd!, :matricizeop]) + end @test issetequal(names(TensorAlgebra), exports) exports = [