Skip to content

Commit fd324eb

Browse files
mtfishmanclaude
andauthored
Return dimnames as a Vector and the named axes as Tuples, dropping LittleSet (#171)
## Summary `dimnames` now returns the stored dimension names as a `Vector`, and the named axes (`inds`, `axes`, `size`) return `Tuple`s, replacing the `LittleSet` wrapper the accessors previously returned. `LittleSet` was an ordered-set type whose real jobs were a set-valued `==` and a custom broadcast style for name-aligned shape combination. With the accessors returning plain `Vector`s and `Tuple`s, the spots that relied on the set-valued equality (`==`, `isequal`, `isapprox`, and the multi-argument `eachindex`) now call `issetequal` directly, and the broadcast shape machinery (`promote_shape`, `broadcast_shape`, `check_broadcast_shape`) dispatches on the tuple of named ranges. `LittleSet` is removed. Returning `dimnames` as the stored `Vector` exposed a latent autodiff issue. The accessor carried a `@zero_derivative` Mooncake rule that was only correct because the old `dimnames` built a fresh object, and `@zero_derivative` is documented to be unsafe when a function's output aliases one of its fields. The rule is dropped, so Mooncake differentiates the field access through its built-in `getfield` rule, which preserves the aliasing and yields a zero gradient for the non-differentiable names. Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
1 parent 5044ef9 commit fd324eb

7 files changed

Lines changed: 43 additions & 131 deletions

File tree

ext/ITensorBaseMooncakeExt/ITensorBaseMooncakeExt.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ Mooncake.tangent_type(::Type{<:AbstractNamedUnitRange}) = Mooncake.NoTangent
1111
@zero_derivative DefaultCtx Tuple{typeof(blockedperm), AbstractITensor, Any, Any}
1212
@zero_derivative DefaultCtx Tuple{typeof(blockedperm_nameddims), Any, Any, Any}
1313
@zero_derivative DefaultCtx Tuple{typeof(combine_nameddimsconstructors), Any, Any}
14-
@zero_derivative DefaultCtx Tuple{typeof(dimnames), Any}
14+
# `dimnames(::ITensor)` returns the stored names `Vector` directly, so its output
15+
# aliases a field, where `@zero_derivative` is documented to be incorrect. Let
16+
# Mooncake differentiate it through the underlying `getfield`, whose built-in rule
17+
# preserves the aliasing (the names are non-differentiable, so the result is zero).
1518
@zero_derivative DefaultCtx Tuple{typeof(dimnames), Any, Any}
1619
@zero_derivative DefaultCtx Tuple{typeof(dimnames_setdiff), Any, Any}
1720
@zero_derivative DefaultCtx Tuple{typeof(inds), Any}

src/ITensorBase.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ using Compat: @compat
77
@compat public @names
88

99
# Named-array machinery (relocated from NamedDimsArrays.jl).
10-
include("littleset.jl")
1110
include("isnamed.jl")
1211
include("randname.jl")
1312
include("abstractnamedinteger.jl")

src/abstractitensor.jl

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,9 @@ denamed(a::AbstractITensor) = throw(MethodError(denamed, a))
5353
denamed(a::AbstractITensor, inds) = denamed(aligneddims(a, inds))
5454
dename(a::AbstractITensor, inds) = denamed(aligndims(a, inds))
5555

56-
# Output the named axes/indices of the named dims array.
57-
inds(a::AbstractITensor) = LittleSet(named.(axes(denamed(a)), dimnames(a)))
56+
# Output the named axes/indices of the named dims array, as a `Tuple` (even though
57+
# the dimension names are stored as a `Vector`).
58+
inds(a::AbstractITensor) = named.(axes(denamed(a)), Tuple(dimnames(a)))
5859
inds(a::AbstractITensor, dim::Int) = inds(a)[dim]
5960

6061
isnamed(::Type{<:AbstractITensor}) = true
@@ -297,14 +298,6 @@ function Base.similar(
297298
return similar_nameddims(a, elt, inds)
298299
end
299300

300-
function Base.similar(a::AbstractArray, inds::LittleSet)
301-
return similar_nameddims(a, eltype(a), inds)
302-
end
303-
304-
function Base.similar(a::AbstractArray, elt::Type, inds::LittleSet)
305-
return similar_nameddims(a, elt, inds)
306-
end
307-
308301
# Same entry points with a named-tensor prototype. An `AbstractITensor` is no longer
309302
# an `AbstractArray`, so the methods above (which build a named tensor from a plain
310303
# array prototype) no longer cover it.
@@ -320,11 +313,6 @@ function Base.similar(
320313
)
321314
return similar_nameddims(a, elt, inds)
322315
end
323-
Base.similar(a::AbstractITensor, inds::LittleSet) = similar_nameddims(a, eltype(a), inds)
324-
function Base.similar(a::AbstractITensor, elt::Type, inds::LittleSet)
325-
return similar_nameddims(a, elt, inds)
326-
end
327-
328316
function setinds(a::AbstractITensor, inds)
329317
return nameddimsconstructorof(a)(denamed(a), inds)
330318
end
@@ -412,10 +400,6 @@ struct NamedDimsCartesianIndices{
412400
)
413401
end
414402
end
415-
function NamedDimsCartesianIndices(indices::LittleSet)
416-
return NamedDimsCartesianIndices(Tuple(indices))
417-
end
418-
419403
# The element type is no longer carried by the (rank-erased) supertype, so recover
420404
# it from the stored index-tuple parameter.
421405
function Base.eltype(
@@ -462,21 +446,21 @@ end
462446
# Base version ignores dimension names.
463447
# TODO: Use `mapreduce(isequal, &&, a1, a2)`?
464448
function Base.isequal(a1::AbstractITensor, a2::AbstractITensor)
465-
(inds(a1) == inds(a2)) || return false
449+
issetequal(inds(a1), inds(a2)) || return false
466450
return isequal(denamed(a1), denamed(a2, inds(a1)))
467451
end
468452

469453
# Base version ignores dimension names.
470454
# TODO: Use `mapreduce(==, &&, a1, a2)`?
471455
# TODO: Handle `missing` values properly.
472456
function Base.:(==)(a1::AbstractITensor, a2::AbstractITensor)
473-
(inds(a1) == inds(a2)) || return false
457+
issetequal(inds(a1), inds(a2)) || return false
474458
return denamed(a1) == denamed(a2, inds(a1))
475459
end
476460

477461
# Base version ignores dimension names.
478462
function Base.isapprox(a1::AbstractITensor, a2::AbstractITensor; kwargs...)
479-
(inds(a1) == inds(a2)) || return false
463+
issetequal(inds(a1), inds(a2)) || return false
480464
return isapprox(denamed(a1), denamed(a2, inds(a1)); kwargs...)
481465
end
482466

src/broadcast.jl

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
using ..ITensorBase: AbstractITensor, AbstractNamedUnitRange, ITensorBase, LittleSet,
2-
dename, denamed, getperm, inds, name, named, nameddimsconstructorof
1+
using ..ITensorBase: AbstractITensor, AbstractNamedUnitRange, ITensorBase, dename, denamed,
2+
getperm, inds, name, named, nameddimsconstructorof
33
using Base.Broadcast: Broadcast as BC, Broadcasted, broadcast_shape, broadcasted,
44
check_broadcast_shape, combine_axes
55
using TensorAlgebra: TensorAlgebra as TA
@@ -33,26 +33,41 @@ function BC.combine_axes(a1::AbstractITensor, a2::AbstractITensor)
3333
end
3434
BC.combine_axes(a::AbstractITensor) = axes(a)
3535

36+
# The named axes are a `Tuple` of `AbstractNamedUnitRange`s. Dispatch the
37+
# name-aware shape combination on that tuple form (the elements are not
38+
# `AbstractUnitRange`s, so Base's positional tuple-shape methods do not apply).
3639
function BC.broadcast_shape(
37-
ax1::LittleSet, ax2::LittleSet, ax_rest::LittleSet...
40+
ax1::Tuple{AbstractNamedUnitRange, Vararg{AbstractNamedUnitRange}},
41+
ax2::Tuple{AbstractNamedUnitRange, Vararg{AbstractNamedUnitRange}},
42+
ax_rest::Tuple{AbstractNamedUnitRange, Vararg{AbstractNamedUnitRange}}...
3843
)
3944
return broadcast_shape(broadcast_shape(ax1, ax2), ax_rest...)
4045
end
4146

42-
function BC.broadcast_shape(ax1::LittleSet, ax2::LittleSet)
47+
function BC.broadcast_shape(
48+
ax1::Tuple{AbstractNamedUnitRange, Vararg{AbstractNamedUnitRange}},
49+
ax2::Tuple{AbstractNamedUnitRange, Vararg{AbstractNamedUnitRange}}
50+
)
4351
return promote_shape(ax1, ax2)
4452
end
4553

4654
# Handle scalar values.
47-
function BC.broadcast_shape(ax1::Tuple{}, ax2::LittleSet)
55+
function BC.broadcast_shape(
56+
ax1::Tuple{}, ax2::Tuple{AbstractNamedUnitRange, Vararg{AbstractNamedUnitRange}}
57+
)
4858
return ax2
4959
end
50-
function BC.broadcast_shape(ax1::LittleSet, ax2::Tuple{})
60+
function BC.broadcast_shape(
61+
ax1::Tuple{AbstractNamedUnitRange, Vararg{AbstractNamedUnitRange}}, ax2::Tuple{}
62+
)
5163
return ax1
5264
end
5365

54-
function Base.promote_shape(ax1::LittleSet, ax2::LittleSet)
55-
return LittleSet(set_promote_shape(Tuple(ax1), Tuple(ax2)))
66+
function Base.promote_shape(
67+
ax1::Tuple{AbstractNamedUnitRange, Vararg{AbstractNamedUnitRange}},
68+
ax2::Tuple{AbstractNamedUnitRange, Vararg{AbstractNamedUnitRange}}
69+
)
70+
return set_promote_shape(ax1, ax2)
5671
end
5772

5873
function set_promote_shape(
@@ -83,8 +98,11 @@ function set_promote_shape(
8398
return ax1
8499
end
85100

86-
function BC.check_broadcast_shape(ax1::LittleSet, ax2::LittleSet)
87-
return set_check_broadcast_shape(Tuple(ax1), Tuple(ax2))
101+
function BC.check_broadcast_shape(
102+
ax1::Tuple{AbstractNamedUnitRange, Vararg{AbstractNamedUnitRange}},
103+
ax2::Tuple{AbstractNamedUnitRange, Vararg{AbstractNamedUnitRange}}
104+
)
105+
return set_check_broadcast_shape(ax1, ax2)
88106
end
89107

90108
function set_check_broadcast_shape(

src/itensor.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,8 @@ end
1616
ITensor(a::AbstractITensor, inds) = throw(ArgumentError("Already named."))
1717
ITensor(a::AbstractITensor) = ITensor(denamed(a), dimnames(a))
1818

19-
# Minimal interface. The dimnames are stored as a `Vector{DimName}`, but the
20-
# accessor still returns a `LittleSet` over a `Tuple` (unchanged public behavior).
21-
dimnames(a::ITensor) = LittleSet(Tuple(a.dimnames))
19+
# Minimal interface. The dimnames are stored as (and returned as) a `Vector`.
20+
dimnames(a::ITensor) = a.dimnames
2221
denamed(a::ITensor) = a.denamed
2322
Base.parent(a::ITensor) = denamed(a)
2423

src/littleset.jl

Lines changed: 0 additions & 54 deletions
This file was deleted.

test/test_nameddims_basics.jl

Lines changed: 3 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using Combinatorics: Combinatorics
2-
using ITensorBase: @names, AbstractITensor, ITensor, LittleSet, Name, NameMismatch,
2+
using ITensorBase: @names, AbstractITensor, ITensor, Name, NameMismatch,
33
NamedDimsCartesianIndex, NamedDimsCartesianIndices, aligndims, aligneddims, apply,
44
dename, denamed, denamedtype, dim, dimnames, dimnametype, dims, fusednames, inds,
55
isnamed, mapinds, name, named, nameddims, namedoneto, product, replacedimnames,
@@ -37,7 +37,7 @@ end
3737
@test inds(na) == (i, j)
3838
@test inds(na, 1) == i
3939
@test inds(na, 2) == j
40-
@test dimnames(na) == ("i", "j")
40+
@test dimnames(na) == ["i", "j"]
4141
@test dimnames(na, 1) == "i"
4242
@test dimnames(na, 2) == "j"
4343
@test dim(na, "i") == 1
@@ -108,9 +108,7 @@ end
108108
j = namedoneto(4, "j")
109109
for na′ in (
110110
similar(na, Float32, (j, i)),
111-
similar(na, Float32, LittleSet((j, i))),
112111
similar(a, Float32, (j, i)),
113-
similar(a, Float32, LittleSet((j, i))),
114112
)
115113
@test eltype(na′) Float32
116114
@test all(inds(na′) .== (j, i))
@@ -123,9 +121,7 @@ end
123121
j = namedoneto(4, "j")
124122
for na′ in (
125123
similar(na, (j, i)),
126-
similar(na, LittleSet((j, i))),
127124
similar(a, (j, i)),
128-
similar(a, LittleSet((j, i))),
129125
)
130126
@test eltype(na′) eltype(na)
131127
@test all(inds(na′) .== (j, i))
@@ -154,7 +150,7 @@ end
154150
a = randn(elt, 2)
155151
na = a[i]
156152
@test na isa ITensor{String}
157-
@test dimnames(na) == ("i",)
153+
@test dimnames(na) == ["i"]
158154
@test denamed(na) == a
159155

160156
# slicing
@@ -416,39 +412,6 @@ end
416412
@test !iszero(na)
417413
end
418414
end
419-
@testset "LittleSet" begin
420-
# Broadcasting
421-
s = LittleSet((1, 2))
422-
@test eltype(s) == Int
423-
@test s .+ [3, 4] == [4, 6]
424-
@test s .+ (3, 4) (4, 6)
425-
426-
s = LittleSet(("a", "b", "c"))
427-
@test all(s .== ("a", "b", "c"))
428-
@test values(s) == ("a", "b", "c")
429-
@test Tuple(s) == ("a", "b", "c")
430-
@test s[1] == "a"
431-
@test s[2] == "b"
432-
@test s[3] == "c"
433-
for s′ in (
434-
replace(x -> x == "b" ? "x" : x, s),
435-
replace(s, "b" => "x"),
436-
map(x -> x == "b" ? "x" : x, s),
437-
)
438-
@test s′ isa LittleSet
439-
@test Tuple(s′) == ("a", "x", "c")
440-
@test s′[1] == "a"
441-
@test s′[2] == "x"
442-
@test s′[3] == "c"
443-
end
444-
445-
s = LittleSet((1, 2, 3))
446-
@test LittleSet(s).values isa Tuple
447-
@test LittleSet(s) == s
448-
sp = LittleSet{NTuple{3, Float64}}(s)
449-
@test eltype(sp) === Float64
450-
@test values(sp) == (1.0, 2.0, 3.0)
451-
end
452415
@testset "show" begin
453416
a = ITensor([1 2; 3 4], ("i", "j"))
454417
@test sprint(show, "text/plain", a) ==

0 commit comments

Comments
 (0)