Skip to content

Commit 818616e

Browse files
authored
Drop the AbstractArray supertype from AbstractITensor (#169)
## Summary `AbstractITensor` is now a standalone abstract type rather than `<: AbstractArray{Any, Any}`. The array-like surface it relied on is supplied directly instead of inherited: `broadcastable` (feeding the existing `ITensorStyle`), elementwise and scalar arithmetic (`+`, `-`, `*`, `/`), iteration, `keys`/`pairs`, `CartesianIndices`, `zero`, `isapprox`, and `normalize!`. The named-tensor verbs that were written on `::AbstractArray` only to ride the subtyping now dispatch on `::AbstractITensor`, which is what their bodies already required (they call `denamed`/`dimnames`): `inds`, `dim`, `dims`, `to_inds`, `aligndims`, `aligneddims`, and the contraction and factorization `*_nameddims` helpers. The methods that build a named tensor from a plain array stay on `::AbstractArray`.
1 parent 0614329 commit 818616e

4 files changed

Lines changed: 116 additions & 41 deletions

File tree

src/abstractitensor.jl

Lines changed: 84 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@ using TypeParameterAccessors: unspecify_type_parameters
99
# https://github.com/mcabbott/NamedPlus.jl
1010
# https://pytorch.org/docs/stable/named_tensor.html
1111

12-
abstract type AbstractITensor{DimName} <: AbstractArray{Any, Any} end
12+
abstract type AbstractITensor{DimName} end
1313

14-
# Rank and element type live in the data, not the type. The `<: AbstractArray`
15-
# subtyping is kept for now to reuse generic array functionality, with `Any` for
16-
# both parameters since neither is fixed by the type.
14+
# Rank and element type live in the data, not the type, so the type-level `ndims`
15+
# is `Any` (like `eltype(Array)`). `AbstractITensor` is not an `AbstractArray`: the
16+
# array-like surface it needs (indexing, broadcasting, arithmetic, iteration) is
17+
# supplied directly below rather than inherited.
1718
Base.ndims(::Type{<:AbstractITensor}) = Any
1819

1920
dimnames(a::AbstractITensor) = throw(MethodError(dimnames, a))
@@ -53,15 +54,15 @@ denamed(a::AbstractITensor, inds) = denamed(aligneddims(a, inds))
5354
dename(a::AbstractITensor, inds) = denamed(aligndims(a, inds))
5455

5556
# Output the named axes/indices of the named dims array.
56-
inds(a::AbstractArray) = LittleSet(named.(axes(denamed(a)), dimnames(a)))
57-
inds(a::AbstractArray, dim::Int) = inds(a)[dim]
57+
inds(a::AbstractITensor) = LittleSet(named.(axes(denamed(a)), dimnames(a)))
58+
inds(a::AbstractITensor, dim::Int) = inds(a)[dim]
5859

5960
isnamed(::Type{<:AbstractITensor}) = true
6061

61-
function dim(a::AbstractArray, n)
62+
function dim(a::AbstractITensor, n)
6263
return findfirst(==(name(n)), dimnames(a))
6364
end
64-
dims(a::AbstractArray, ns) = Base.Fix1(dim, a).(ns)
65+
dims(a::AbstractITensor, ns) = Base.Fix1(dim, a).(ns)
6566

6667
dimname_isequal(x) = Base.Fix1(dimname_isequal, x)
6768
dimname_isequal(x, y) = isequal(x, y)
@@ -80,7 +81,7 @@ dimname_isequal(r1, r2::AbstractNamedUnitRange) = r1 == name(r2)
8081
dimname_isequal(r1::AbstractNamedUnitRange, r2::Name) = name(r1) == name(r2)
8182
dimname_isequal(r1::Name, r2::AbstractNamedUnitRange) = name(r1) == name(r2)
8283

83-
function to_inds(a::AbstractArray, dims)
84+
function to_inds(a::AbstractITensor, dims)
8485
is = Base.Fix1(dim, a).(name.(dims))
8586
return Base.Fix1(inds, a).(is)
8687
end
@@ -157,6 +158,11 @@ function checked_indexin(x::AbstractUnitRange, y::AbstractUnitRange)
157158
end
158159

159160
Base.copy(a::AbstractITensor) = nameddimsof(a, copy(denamed(a)))
161+
Base.zero(a::AbstractITensor) = nameddimsof(a, zero(denamed(a)))
162+
163+
# `CartesianIndices` of a named tensor is the parent's, via the named axes (as the
164+
# `AbstractArray` fallback did through `axes`).
165+
Base.CartesianIndices(a::AbstractITensor) = CartesianIndices(axes(a))
160166

161167
# Forward `conj` to the underlying so that graded axes flip their sector arrows.
162168
# The default `AbstractArray` fallback would broadcast `conj` over elements without
@@ -169,6 +175,19 @@ Base.conj(a::AbstractITensor) = nameddimsof(a, conj(denamed(a)))
169175
function LinearAlgebra.normalize(a::AbstractITensor, p::Real = 2)
170176
return a / LinearAlgebra.norm(a, p)
171177
end
178+
function LinearAlgebra.normalize!(a::AbstractITensor, p::Real = 2)
179+
LinearAlgebra.normalize!(denamed(a), p)
180+
return a
181+
end
182+
183+
# Elementwise and scalar arithmetic. `AbstractArray` routes these through
184+
# broadcasting; supply them directly now that the supertype is gone.
185+
Base.:+(a1::AbstractITensor, a2::AbstractITensor) = a1 .+ a2
186+
Base.:-(a1::AbstractITensor, a2::AbstractITensor) = a1 .- a2
187+
Base.:-(a::AbstractITensor) = broadcast(-, a)
188+
Base.:*(a::AbstractITensor, x::Number) = a .* x
189+
Base.:*(x::Number, a::AbstractITensor) = x .* a
190+
Base.:/(a::AbstractITensor, x::Number) = a ./ x
172191

173192
# Forward `Random.randn!` / `Random.rand!` to the concrete storage so they
174193
# see the runtime eltype.
@@ -286,6 +305,26 @@ function Base.similar(a::AbstractArray, elt::Type, inds::LittleSet)
286305
return similar_nameddims(a, elt, inds)
287306
end
288307

308+
# Same entry points with a named-tensor prototype. An `AbstractITensor` is no longer
309+
# an `AbstractArray`, so the methods above (which build a named tensor from a plain
310+
# array prototype) no longer cover it.
311+
function Base.similar(
312+
a::AbstractITensor,
313+
inds::Tuple{AbstractNamedUnitRange, Vararg{AbstractNamedUnitRange}}
314+
)
315+
return similar(a, eltype(a), inds)
316+
end
317+
function Base.similar(
318+
a::AbstractITensor, elt::Type,
319+
inds::Tuple{AbstractNamedUnitRange, Vararg{AbstractNamedUnitRange}}
320+
)
321+
return similar_nameddims(a, elt, inds)
322+
end
323+
Base.similar(a::AbstractITensor, inds::LittleSet) = similar_nameddims(a, eltype(a), inds)
324+
function Base.similar(a::AbstractITensor, elt::Type, inds::LittleSet)
325+
return similar_nameddims(a, elt, inds)
326+
end
327+
289328
function setinds(a::AbstractITensor, inds)
290329
return nameddimsconstructorof(a)(denamed(a), inds)
291330
end
@@ -320,6 +359,18 @@ Base.isempty(a::AbstractITensor) = isempty(denamed(a))
320359
Base.IndexStyle(a::AbstractITensor) = IndexStyle(denamed(a))
321360
Base.eachindex(a::AbstractITensor) = eachindex(denamed(a))
322361

362+
# Iteration, keys, and pairs forward to the parent (these were previously inherited
363+
# from `AbstractArray`).
364+
Base.iterate(a::AbstractITensor, state...) = iterate(denamed(a), state...)
365+
Base.keys(a::AbstractITensor) = keys(denamed(a))
366+
Base.pairs(a::AbstractITensor) = pairs(denamed(a))
367+
368+
# Multi-argument `eachindex` dispatches on the named index style, as the
369+
# `AbstractArray` version did.
370+
function Base.eachindex(a1::AbstractITensor, a_rest::AbstractITensor...)
371+
return eachindex(IndexStyle(a1, a_rest...), a1, a_rest...)
372+
end
373+
323374
# Cartesian indices with names attached.
324375
struct NamedIndexCartesian <: IndexStyle end
325376

@@ -387,7 +438,21 @@ function denamed(I::NamedDimsCartesianIndices)
387438
return CartesianIndices(denamed.(I.indices))
388439
end
389440

390-
function Base.eachindex(::NamedIndexCartesian, a1::AbstractArray, a_rest::AbstractArray...)
441+
# Iterating yields `NamedDimsCartesianIndex`es. The generic `AbstractITensor`
442+
# iteration forwards to `denamed`, which here is a plain `CartesianIndices`, so
443+
# convert each parent index back through `getindex`.
444+
function Base.iterate(I::NamedDimsCartesianIndices, state...)
445+
y = iterate(denamed(I), state...)
446+
isnothing(y) && return nothing
447+
cartesian, next_state = y
448+
return I[Tuple(cartesian)...], next_state
449+
end
450+
451+
function Base.eachindex(
452+
::NamedIndexCartesian,
453+
a1::AbstractITensor,
454+
a_rest::AbstractITensor...
455+
)
391456
all(a -> issetequal(inds(a1), inds(a)), a_rest) ||
392457
throw(NameMismatch("Dimension name mismatch $(inds.((a1, a_rest...)))."))
393458
# TODO: Check the shapes match.
@@ -409,6 +474,12 @@ function Base.:(==)(a1::AbstractITensor, a2::AbstractITensor)
409474
return denamed(a1) == denamed(a2, inds(a1))
410475
end
411476

477+
# Base version ignores dimension names.
478+
function Base.isapprox(a1::AbstractITensor, a2::AbstractITensor; kwargs...)
479+
(inds(a1) == inds(a2)) || return false
480+
return isapprox(denamed(a1), denamed(a2, inds(a1)); kwargs...)
481+
end
482+
412483
# Generalization of `Base.sort` to Tuples for Julia v1.10 compatibility.
413484
# TODO: Remove when we drop support for Julia v1.10.
414485
_sort(x; kwargs...) = sort(x; kwargs...)
@@ -641,7 +712,7 @@ function Base.view(a::AbstractITensor, I::ViewIndex...)
641712
return view_nameddims(a, I...)
642713
end
643714

644-
function getindex_nameddims(a::AbstractArray, I...)
715+
function getindex_nameddims(a::AbstractITensor, I...)
645716
return copy(view(a, I...))
646717
end
647718

@@ -686,7 +757,7 @@ end
686757

687758
# Permute/align dimensions
688759

689-
function aligndims(a::AbstractArray, dims)
760+
function aligndims(a::AbstractITensor, dims)
690761
new_dimnames = name.(dims)
691762
perm = Tuple(getperm(dimnames(a), new_dimnames))
692763
isperm(perm) || throw(
@@ -697,7 +768,7 @@ function aligndims(a::AbstractArray, dims)
697768
return nameddimsconstructorof(a)(permutedims(denamed(a), perm), new_dimnames)
698769
end
699770

700-
function aligneddims(a::AbstractArray, dims)
771+
function aligneddims(a::AbstractITensor, dims)
701772
new_dimnames = name.(dims)
702773
perm = getperm(dimnames(a), new_dimnames)
703774
isperm(perm) || throw(

src/abstractnamedunitrange.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,13 @@ denamed(c::NamedColon) = Colon()
9797
name(c::NamedColon) = c.name
9898
named(::Colon, name) = NamedColon(name)
9999

100-
struct FirstIndex{Arr <: AbstractArray, Dim}
100+
struct FirstIndex{Arr, Dim}
101101
array::Arr
102102
dim::Dim
103103
end
104104
Base.to_index(i::FirstIndex) = Int(first(axes(i.array, i.dim)))
105105

106-
struct LastIndex{Arr <: AbstractArray, Dim}
106+
struct LastIndex{Arr, Dim}
107107
array::Arr
108108
dim::Dim
109109
end

src/broadcast.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ function BC.BroadcastStyle(arraytype::Type{<:AbstractITensor})
1919
return ITensorStyle{ndims(arraytype), nameddimsconstructorof(arraytype)}()
2020
end
2121

22+
# An `AbstractITensor` broadcasts as itself (previously inherited from
23+
# `AbstractArray`); without this the default `broadcastable` wraps it in a `Ref`.
24+
BC.broadcastable(a::AbstractITensor) = a
25+
2226
function BC.combine_axes(
2327
a1::AbstractITensor, a_rest::AbstractITensor...
2428
)
@@ -95,7 +99,7 @@ end
9599
set_check_broadcast_shape(ax1::Tuple{}, ax2::Tuple{}) = nothing
96100

97101
broadcasted_denamed(x::Number, inds) = x
98-
broadcasted_denamed(a::AbstractArray, inds) = denamed(a, inds)
102+
broadcasted_denamed(a::AbstractITensor, inds) = denamed(a, inds)
99103
function broadcasted_denamed(bc::Broadcasted, inds)
100104
return broadcasted(bc.f, Base.Fix2(broadcasted_denamed, inds).(bc.args)...)
101105
end
@@ -142,7 +146,7 @@ function Base.copy(bc::Broadcasted{<:AbstractITensorStyle})
142146
return nameddimstype(bc.style)(dest_denamed, inds_dest)
143147
end
144148

145-
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractITensorStyle})
149+
function Base.copyto!(dest::AbstractITensor, bc::Broadcasted{<:AbstractITensorStyle})
146150
dest_denamed = denamed(dest)
147151
inds_dest = inds(dest)
148152
bc_denamed = broadcasted_denamed(bc, inds_dest)

0 commit comments

Comments
 (0)