Commit 5a0dfab

mtfishman and claude authored
Guard β=0 in 0-dim bipermutedimsopadd! and drop empty-array workaround (#173)
## Summary

- The 0-dim branch of `bipermutedimsopadd!` read `dest[]` unconditionally (`β * dest[] + α * op(src[])`), even when `β = 0`. By BLAS convention, `β = 0` means `dest` is treated as write-only, so its contents need not be defined. With element types whose `undef` storage is unreadable, the unguarded read crashed: `bipermutedimsopadd!(Array{BigFloat, 0}(undef), identity, src, (), (), true, false)` threw `UndefRefError`, since the unassigned slot of a 0-dim array of mutable `BigFloat` cannot be read. This surfaced via callers that allocate a 0-dim destination with `similar` and call into `bipermutedimsopadd!` with `β = 0` (e.g. scalar contractions producing a 0-dim result).
- Adds an `iszero(β)` guard to the 0-dim branch to skip reading `dest` in the write-only case. The 0-dim short-circuit itself is kept: it lets downstream array types (e.g. `BlockSparseArray{T, 0}`) use direct `dest[]`/`src[]` accesses instead of having to support `getindex` on a 0-dim `PermutedDimsArray` wrapper around their type.
- The companion `isempty(dest) && return dest` early return worked around a `Strided.jl` broadcasting bug fixed by [QuantumKitHub/Strided.jl#50](https://github.com/QuantumKitHub/Strided.jl/pull/50), released in `Strided` 2.3.5. Adds `Strided` to `[deps]` (with `using Strided: Strided` so it isn't reported as stale) and a `Strided = "2.3.5"` compat floor, then drops the workaround.
- Adds a regression test covering `Array{T, 0}` for `T ∈ {Float64, BigFloat}` with both `β = 0` (write-only) and `β ≠ 0` (accumulating) cases.

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
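The difference between readable and unreadable `undef` storage can be seen in a short sketch that uses only Base Julia and does not involve TensorAlgebra. The variable names are illustrative assumptions, not taken from the package.

```julia
# Illustrative only: plain Base Julia, independent of TensorAlgebra.

# Float64 is an isbits type: an `undef` slot holds arbitrary bits but can be read.
dest_f64 = Array{Float64, 0}(undef)
readable = dest_f64[] isa Float64

# BigFloat is a mutable type stored by reference: an `undef` slot is unassigned.
dest_big = Array{BigFloat, 0}(undef)
unassigned = !isassigned(dest_big, 1)

# Reading the unassigned slot throws `UndefRefError`.
threw = try
    dest_big[]
    false
catch err
    err isa UndefRefError
end

# `similar` produces the same kind of unassigned 0-dim destination.
dest_similar = similar(fill(big(1.0)))
similar_unassigned = !isassigned(dest_similar, 1)

# Writing without reading first is always safe.
dest_big[] = big(7.0)
```

This is why an expression of the form `β * dest[] + ...` can fail before the multiplication by `β` ever happens: the read itself throws.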
1 parent b4f43b6 commit 5a0dfab

3 files changed

Lines changed: 37 additions & 11 deletions

Project.toml

Lines changed: 3 additions & 1 deletion
@@ -1,6 +1,6 @@
 name = "TensorAlgebra"
 uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
-version = "0.9.1"
+version = "0.9.2"
 authors = ["ITensor developers <support@itensor.org> and contributors"]
 
 [workspace]
@@ -11,6 +11,7 @@ EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
 FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
+Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
 StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143"
 TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
 TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
@@ -35,6 +36,7 @@ GPUArraysCore = "0.2"
 LinearAlgebra = "1.10"
 MatrixAlgebraKit = "0.2, 0.3, 0.4, 0.5, 0.6"
 Mooncake = "0.4.202, 0.5"
+Strided = "2.3.5"
 StridedViews = "0.4.1, 0.5"
 TensorOperations = "5"
 TupleTools = "1.6"

src/permutedimsadd.jl

Lines changed: 13 additions & 9 deletions
@@ -1,5 +1,6 @@
 import StridedViews as SV
 using FunctionImplementations: permuteddims
+using Strided: Strided
 
 # Specify if an array is on CPU. This is helpful for backends that don't support
 # operations on GPU, such as Strided.jl.
@@ -50,19 +51,22 @@ function bipermutedimsopadd!(
     perm = (perm_codomain..., perm_domain...)
     check_input(bipermutedimsopadd!, dest, src, perm_codomain, perm_domain)
 
-    # TODO: Remove this 0-dimensional special case once GradedArray is its own type
-    # (not an alias for BlockSparseArray), so the GradedArray overload catches the
-    # 0-dimensional contraction result.
+    # 0-dim short-circuit: avoid the permute-broadcast path entirely so that
+    # downstream array types (e.g. `BlockSparseArray{T, 0}`) don't have to define
+    # `getindex` on a 0-dim `PermutedDimsArray` wrapper around them.
+    # The `iszero(β)` guard follows the BLAS convention that `β = 0` means `dest`
+    # is write-only — its slot need not be defined. This matters for element types
+    # whose `undef` storage is unreadable, e.g. `Array{BigFloat, 0}(undef)[]` throws
+    # `UndefRefError`.
     if iszero(ndims(dest))
-        dest[] = β * dest[] + α * op(src[])
+        if iszero(β)
+            dest[] = α * op(src[])
+        else
+            dest[] = β * dest[] + α * op(src[])
+        end
         return dest
     end
 
-    # This works around a bug in Strided.jl v2.3.4 and below when broadcasting
-    # empty StridedViews: https://github.com/QuantumKitHub/Strided.jl/pull/50
-    # TODO: Delete this and bump the version of Strided.jl once that is fixed.
-    isempty(dest) && return dest
-
     dest′, src′ = maybestrided(dest, permuteddims(src, perm))
     if op === identity
         if iszero(β)
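The guarded 0-dim branch can be tried in isolation with a minimal sketch. The function name `scalar_opadd!` is a hypothetical stand-in introduced here for illustration; it mirrors only the 0-dim branch shown in the diff, not the full `bipermutedimsopadd!` signature or its permutation handling.

```julia
# Hypothetical stand-in for the 0-dim branch; `scalar_opadd!` is not a TensorAlgebra function.
function scalar_opadd!(dest::AbstractArray{<:Any, 0}, op, src::AbstractArray{<:Any, 0}, α, β)
    if iszero(β)
        # Write-only case: never touch the (possibly unassigned) destination slot.
        dest[] = α * op(src[])
    else
        # Accumulating case: the destination must already hold a value.
        dest[] = β * dest[] + α * op(src[])
    end
    return dest
end

src = fill(BigFloat(7))

# β = false: succeeds even though the destination slot is unassigned.
write_only = scalar_opadd!(Array{BigFloat, 0}(undef), conj, src, true, false)

# α = 3, β = 5, dest = 2: the result is 5 * 2 + 3 * 7.
accumulated = scalar_opadd!(fill(BigFloat(2)), identity, src, BigFloat(3), BigFloat(5))
```

Branching on `iszero(β)` rather than relying on `β * dest[]` evaluating to zero is the point of the design: the multiplication cannot happen if the read itself throws.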

test/test_permutedimsadd.jl

Lines changed: 21 additions & 1 deletion
@@ -1,6 +1,6 @@
 using Adapt: adapt
 using JLArrays: JLArray
-using TensorAlgebra: add!, permutedimsadd!, permutedimsopadd!
+using TensorAlgebra: add!, bipermutedimsopadd!, permutedimsadd!, permutedimsopadd!
 using Test: @test, @testset
 
 @testset "[permutedims]add!" begin
@@ -47,6 +47,26 @@ using Test: @test, @testset
             @test b′ ≈ β * b + α * permutedims(a, perm)
         end
     end
+    @testset "bipermutedimsopadd! 0-dim with β=0 must not read dest (eltype=$T)" for T in
+        (
+            Float64,
+            BigFloat,
+        )
+        # With β=0, `dest` is write-only by BLAS convention; its contents need not be
+        # defined. For element types whose `undef` storage is unreadable (e.g. mutable
+        # `BigFloat`), reading the slot would throw `UndefRefError`.
+        src = fill(T(7))
+        for op in (identity, conj)
+            dest = Array{T, 0}(undef)
+            bipermutedimsopadd!(dest, op, src, (), (), true, false)
+            @test dest[] == op(src[])
+        end
+        # With β nonzero, both reads and writes go through with the accumulating
+        # semantics `dest = β * dest + α * op(src)`.
+        dest = fill(T(2))
+        bipermutedimsopadd!(dest, identity, src, (), (), T(3), T(5))
+        @test dest[] == 3 * 7 + 5 * 2
+    end
     @testset "permutedimsopadd! (arraytype=$arrayt)" for arrayt in (Array,)
         dev = adapt(arrayt)
         a = dev(randn(ComplexF64, 2, 2, 2))
