@@ -9,11 +9,12 @@ using TypeParameterAccessors: unspecify_type_parameters
99# https://github.com/mcabbott/NamedPlus.jl
1010# https://pytorch.org/docs/stable/named_tensor.html
1111
12- abstract type AbstractITensor{DimName} <: AbstractArray{Any, Any} end
12+ abstract type AbstractITensor{DimName} end
1313
14- # Rank and element type live in the data, not the type. The `<: AbstractArray`
15- # subtyping is kept for now to reuse generic array functionality, with `Any` for
16- # both parameters since neither is fixed by the type.
14+ # Rank and element type live in the data, not the type, so the type-level `ndims`
15+ # is `Any` (like `eltype(Array)`). `AbstractITensor` is not an `AbstractArray`: the
16+ # array-like surface it needs (indexing, broadcasting, arithmetic, iteration) is
17+ # supplied directly below rather than inherited.
1718Base. ndims (:: Type{<:AbstractITensor} ) = Any
1819
1920dimnames (a:: AbstractITensor ) = throw (MethodError (dimnames, a))
@@ -53,15 +54,15 @@ denamed(a::AbstractITensor, inds) = denamed(aligneddims(a, inds))
5354dename (a:: AbstractITensor , inds) = denamed (aligndims (a, inds))
5455
5556# Output the named axes/indices of the named dims array.
56- inds (a:: AbstractArray ) = LittleSet (named .(axes (denamed (a)), dimnames (a)))
57- inds (a:: AbstractArray , dim:: Int ) = inds (a)[dim]
57+ inds (a:: AbstractITensor ) = LittleSet (named .(axes (denamed (a)), dimnames (a)))
58+ inds (a:: AbstractITensor , dim:: Int ) = inds (a)[dim]
5859
5960isnamed (:: Type{<:AbstractITensor} ) = true
6061
61- function dim (a:: AbstractArray , n)
62+ function dim (a:: AbstractITensor , n)
6263 return findfirst (== (name (n)), dimnames (a))
6364end
64- dims (a:: AbstractArray , ns) = Base. Fix1 (dim, a).(ns)
65+ dims (a:: AbstractITensor , ns) = Base. Fix1 (dim, a).(ns)
6566
6667dimname_isequal (x) = Base. Fix1 (dimname_isequal, x)
6768dimname_isequal (x, y) = isequal (x, y)
@@ -80,7 +81,7 @@ dimname_isequal(r1, r2::AbstractNamedUnitRange) = r1 == name(r2)
8081dimname_isequal (r1:: AbstractNamedUnitRange , r2:: Name ) = name (r1) == name (r2)
8182dimname_isequal (r1:: Name , r2:: AbstractNamedUnitRange ) = name (r1) == name (r2)
8283
83- function to_inds (a:: AbstractArray , dims)
84+ function to_inds (a:: AbstractITensor , dims)
8485 is = Base. Fix1 (dim, a).(name .(dims))
8586 return Base. Fix1 (inds, a).(is)
8687end
@@ -157,6 +158,11 @@ function checked_indexin(x::AbstractUnitRange, y::AbstractUnitRange)
157158end
158159
159160Base. copy (a:: AbstractITensor ) = nameddimsof (a, copy (denamed (a)))
161+ Base. zero (a:: AbstractITensor ) = nameddimsof (a, zero (denamed (a)))
162+
163+ # `CartesianIndices` of a named tensor is the parent's, via the named axes (as the
164+ # `AbstractArray` fallback did through `axes`).
165+ Base. CartesianIndices (a:: AbstractITensor ) = CartesianIndices (axes (a))
160166
161167# Forward `conj` to the underlying so that graded axes flip their sector arrows.
162168# The default `AbstractArray` fallback would broadcast `conj` over elements without
@@ -169,6 +175,19 @@ Base.conj(a::AbstractITensor) = nameddimsof(a, conj(denamed(a)))
169175function LinearAlgebra. normalize (a:: AbstractITensor , p:: Real = 2 )
170176 return a / LinearAlgebra. norm (a, p)
171177end
178+ function LinearAlgebra. normalize! (a:: AbstractITensor , p:: Real = 2 )
179+ LinearAlgebra. normalize! (denamed (a), p)
180+ return a
181+ end
182+
183+ # Elementwise and scalar arithmetic. `AbstractArray` routes these through
184+ # broadcasting; supply them directly now that the supertype is gone.
185+ Base.:+ (a1:: AbstractITensor , a2:: AbstractITensor ) = a1 .+ a2
186+ Base.:- (a1:: AbstractITensor , a2:: AbstractITensor ) = a1 .- a2
187+ Base.:- (a:: AbstractITensor ) = broadcast (- , a)
188+ Base.:* (a:: AbstractITensor , x:: Number ) = a .* x
189+ Base.:* (x:: Number , a:: AbstractITensor ) = x .* a
190+ Base.:/ (a:: AbstractITensor , x:: Number ) = a ./ x
172191
173192# Forward `Random.randn!` / `Random.rand!` to the concrete storage so they
174193# see the runtime eltype.
@@ -286,6 +305,26 @@ function Base.similar(a::AbstractArray, elt::Type, inds::LittleSet)
286305 return similar_nameddims (a, elt, inds)
287306end
288307
308+ # Same entry points with a named-tensor prototype. An `AbstractITensor` is no longer
309+ # an `AbstractArray`, so the methods above (which build a named tensor from a plain
310+ # array prototype) no longer cover it.
311+ function Base. similar (
312+ a:: AbstractITensor ,
313+ inds:: Tuple{AbstractNamedUnitRange, Vararg{AbstractNamedUnitRange}}
314+ )
315+ return similar (a, eltype (a), inds)
316+ end
317+ function Base. similar (
318+ a:: AbstractITensor , elt:: Type ,
319+ inds:: Tuple{AbstractNamedUnitRange, Vararg{AbstractNamedUnitRange}}
320+ )
321+ return similar_nameddims (a, elt, inds)
322+ end
323+ Base. similar (a:: AbstractITensor , inds:: LittleSet ) = similar_nameddims (a, eltype (a), inds)
324+ function Base. similar (a:: AbstractITensor , elt:: Type , inds:: LittleSet )
325+ return similar_nameddims (a, elt, inds)
326+ end
327+
289328function setinds (a:: AbstractITensor , inds)
290329 return nameddimsconstructorof (a)(denamed (a), inds)
291330end
@@ -320,6 +359,18 @@ Base.isempty(a::AbstractITensor) = isempty(denamed(a))
320359Base. IndexStyle (a:: AbstractITensor ) = IndexStyle (denamed (a))
321360Base. eachindex (a:: AbstractITensor ) = eachindex (denamed (a))
322361
362+ # Iteration, keys, and pairs forward to the parent (these were previously inherited
363+ # from `AbstractArray`).
364+ Base. iterate (a:: AbstractITensor , state... ) = iterate (denamed (a), state... )
365+ Base. keys (a:: AbstractITensor ) = keys (denamed (a))
366+ Base. pairs (a:: AbstractITensor ) = pairs (denamed (a))
367+
368+ # Multi-argument `eachindex` dispatches on the named index style, as the
369+ # `AbstractArray` version did.
370+ function Base. eachindex (a1:: AbstractITensor , a_rest:: AbstractITensor... )
371+ return eachindex (IndexStyle (a1, a_rest... ), a1, a_rest... )
372+ end
373+
323374# Cartesian indices with names attached.
324375struct NamedIndexCartesian <: IndexStyle end
325376
@@ -387,7 +438,21 @@ function denamed(I::NamedDimsCartesianIndices)
387438 return CartesianIndices (denamed .(I. indices))
388439end
389440
390- function Base. eachindex (:: NamedIndexCartesian , a1:: AbstractArray , a_rest:: AbstractArray... )
441+ # Iterating yields `NamedDimsCartesianIndex`es. The generic `AbstractITensor`
442+ # iteration forwards to `denamed`, which here is a plain `CartesianIndices`, so
443+ # convert each parent index back through `getindex`.
444+ function Base. iterate (I:: NamedDimsCartesianIndices , state... )
445+ y = iterate (denamed (I), state... )
446+ isnothing (y) && return nothing
447+ cartesian, next_state = y
448+ return I[Tuple (cartesian)... ], next_state
449+ end
450+
451+ function Base. eachindex (
452+ :: NamedIndexCartesian ,
453+ a1:: AbstractITensor ,
454+ a_rest:: AbstractITensor...
455+ )
391456 all (a -> issetequal (inds (a1), inds (a)), a_rest) ||
392457 throw (NameMismatch (" Dimension name mismatch $(inds .((a1, a_rest... ))) ." ))
393458 # TODO : Check the shapes match.
@@ -409,6 +474,12 @@ function Base.:(==)(a1::AbstractITensor, a2::AbstractITensor)
409474 return denamed (a1) == denamed (a2, inds (a1))
410475end
411476
477+ # Base version ignores dimension names.
478+ function Base. isapprox (a1:: AbstractITensor , a2:: AbstractITensor ; kwargs... )
479+ (inds (a1) == inds (a2)) || return false
480+ return isapprox (denamed (a1), denamed (a2, inds (a1)); kwargs... )
481+ end
482+
412483# Generalization of `Base.sort` to Tuples for Julia v1.10 compatibility.
413484# TODO : Remove when we drop support for Julia v1.10.
414485_sort (x; kwargs... ) = sort (x; kwargs... )
@@ -641,7 +712,7 @@ function Base.view(a::AbstractITensor, I::ViewIndex...)
641712 return view_nameddims (a, I... )
642713end
643714
644- function getindex_nameddims (a:: AbstractArray , I... )
715+ function getindex_nameddims (a:: AbstractITensor , I... )
645716 return copy (view (a, I... ))
646717end
647718
686757
687758# Permute/align dimensions
688759
689- function aligndims (a:: AbstractArray , dims)
760+ function aligndims (a:: AbstractITensor , dims)
690761 new_dimnames = name .(dims)
691762 perm = Tuple (getperm (dimnames (a), new_dimnames))
692763 isperm (perm) || throw (
@@ -697,7 +768,7 @@ function aligndims(a::AbstractArray, dims)
697768 return nameddimsconstructorof (a)(permutedims (denamed (a), perm), new_dimnames)
698769end
699770
700- function aligneddims (a:: AbstractArray , dims)
771+ function aligneddims (a:: AbstractITensor , dims)
701772 new_dimnames = name .(dims)
702773 perm = getperm (dimnames (a), new_dimnames)
703774 isperm (perm) || throw (
0 commit comments