Skip to content

Commit 652e4ea

Browse files
committed
Add bipermutedimsopadd! overloads for GradedArrays (companion to ITensor/TensorAlgebra.jl#158)
1 parent d470cf8 commit 652e4ea

3 files changed

Lines changed: 44 additions & 12 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "GradedArrays"
22
uuid = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
3-
version = "0.8.3"
3+
version = "0.8.4"
44
authors = ["ITensor developers <support@itensor.org> and contributors"]
55

66
[workspace]

src/GradedArrays.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ using BlockSparseArrays: BlockSparseArrays, blockdiagindices, blockstoredlength,
3030
using KroneckerArrays: KroneckerArrays, kroneckerfactors, ×,
3131
using LinearAlgebra: LinearAlgebra, Adjoint, mul!
3232
using SparseArraysBase: SparseArraysBase
33-
using TensorAlgebra: TensorAlgebra, BlockedTuple, FusionStyle, matricize, matricize_axes,
34-
permutedimsadd!, permutedimsopadd!, tensor_product_axis, trivial_axis, trivialbiperm,
35-
tryflattenlinear, unmatricize
33+
using TensorAlgebra: TensorAlgebra, BlockedTuple, FusionStyle, bipermutedimsopadd!,
34+
matricize, matricize_axes, permutedimsadd!, permutedimsopadd!, tensor_product_axis,
35+
trivial_axis, trivialbiperm, tryflattenlinear, unmatricize
3636
using TensorKitSectors: TensorKitSectors as TKS
3737
using TypeParameterAccessors: type_parameters, unspecify_type_parameters
3838

src/tensoralgebra.jl

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -151,26 +151,33 @@ function tensor_product(r::SectorOneTo, s::TKS.Sector)
151151
return tensor_product(to_gradedrange(r), to_gradedrange(s))
152152
end
153153

154-
# ======================== permutedimsopadd! ========================
154+
# ======================== bipermutedimsopadd! ========================
155+
# Primary overloads. The flat-perm permutedimsopadd! overloads forward here.
155156

156-
function TensorAlgebra.permutedimsopadd!(
157-
y::AbstractSectorArray, op, x::AbstractSectorArray, perm,
157+
function TensorAlgebra.bipermutedimsopadd!(
158+
y::AbstractSectorArray, op, x::AbstractSectorArray,
159+
perm_codomain, perm_domain,
158160
α::Number, β::Number
159161
)
162+
perm = (perm_codomain..., perm_domain...)
160163
sector(y) == permutedims(sector(x), perm) || throw(DimensionMismatch())
161164
phase = fermion_permutation_phase(sector(x), perm)
162-
TensorAlgebra.permutedimsopadd!(data(y), op, data(x), perm, phase * α, β)
165+
TensorAlgebra.bipermutedimsopadd!(
166+
data(y), op, data(x), perm_codomain, perm_domain, phase * α, β
167+
)
163168
return y
164169
end
165170

166-
function TensorAlgebra.permutedimsopadd!(
167-
y::AbstractGradedArray{<:Any, N}, op, x::AbstractGradedArray{<:Any, N}, perm,
171+
function TensorAlgebra.bipermutedimsopadd!(
172+
y::AbstractGradedArray{<:Any, N}, op, x::AbstractGradedArray{<:Any, N},
173+
perm_codomain, perm_domain,
168174
α::Number, β::Number
169175
) where {N}
176+
perm = (perm_codomain..., perm_domain...)
170177
# `scale!(y, 0)` doesn't reliably zero `y`: if any block of `y` holds
171178
# `NaN`/`Inf` (uninitialized memory from `undef` allocation or a stale
172179
# garbage value), `NaN * 0 == NaN` keeps it poisoned, and subsequent
173-
# `permutedimsopadd!(..., α, one(α))` calls on a block of `y` that
180+
# `bipermutedimsopadd!(..., α, one(α))` calls on a block of `y` that
174181
# doesn't get visited by the loop below would leak that garbage into the
175182
# result. Allocating broadcasts like `3 * a` go through this path (they
176183
# call with β == 0 on a fresh `similar`-allocated array); before this
@@ -182,7 +189,32 @@ function TensorAlgebra.permutedimsopadd!(
182189
b_dest = Block(ntuple(i -> b[perm[i]], N))
183190
y_b = view(y, Tuple(b_dest)...)
184191
x_b = x[bI]
185-
TensorAlgebra.permutedimsopadd!(y_b, op, x_b, perm, α, one(α))
192+
TensorAlgebra.bipermutedimsopadd!(
193+
y_b,
194+
op,
195+
x_b,
196+
perm_codomain,
197+
perm_domain,
198+
α,
199+
one(α)
200+
)
186201
end
187202
return y
188203
end
204+
205+
# ======================== permutedimsopadd! ========================
206+
# Flat-perm overloads forward to bipermutedimsopadd! with perm_domain = ().
207+
208+
function TensorAlgebra.permutedimsopadd!(
209+
y::AbstractSectorArray, op, x::AbstractSectorArray, perm,
210+
α::Number, β::Number
211+
)
212+
return TensorAlgebra.bipermutedimsopadd!(y, op, x, perm, (), α, β)
213+
end
214+
215+
function TensorAlgebra.permutedimsopadd!(
216+
y::AbstractGradedArray{<:Any, N}, op, x::AbstractGradedArray{<:Any, N}, perm,
217+
α::Number, β::Number
218+
) where {N}
219+
return TensorAlgebra.bipermutedimsopadd!(y, op, x, perm, (), α, β)
220+
end

0 commit comments

Comments
 (0)