1+ abstract type AbstractNamedArrayPartition{T, A, NT} <: AbstractVector{T} end
2+
13"""
24 NamedArrayPartition(; kwargs...)
35 NamedArrayPartition(x::NamedTuple)
@@ -6,137 +8,155 @@ Similar to an `ArrayPartition` but the individual arrays can be accessed via the
68constructor-specified names. However, unlike `ArrayPartition`, each individual array
79must have the same element type.
810"""
9- struct NamedArrayPartition{T, A <: ArrayPartition{T} , NT <: NamedTuple } <: AbstractVector{T }
11+ struct NamedArrayPartition{T, A <: ArrayPartition{T} , NT <: NamedTuple } <: AbstractNamedArrayPartition{T, A, NT }
1012 array_partition:: A
1113 names_to_indices:: NT
1214end
13- NamedArrayPartition ( ; kwargs... ) = NamedArrayPartition (NamedTuple (kwargs))
14- function NamedArrayPartition ( x:: NamedTuple )
15+ ( :: Type{T} )( ; kwargs... ) where {T <: AbstractNamedArrayPartition } = T (NamedTuple (kwargs))
16+ function ( :: Type{T} )( x:: NamedTuple ) where {T <: AbstractNamedArrayPartition }
1517 names_to_indices = NamedTuple (Pair (symbol, index)
1618 for (index, symbol) in enumerate (keys (x)))
1719
1820 # enforce homogeneity of eltypes
1921 @assert all (eltype .(values (x)) .== eltype (first (x)))
20- T = eltype (first (x))
22+ R = eltype (first (x))
2123 S = typeof (values (x))
22- return NamedArrayPartition (ArrayPartition {T, S} (values (x)), names_to_indices)
24+ return T (ArrayPartition {R, S} (values (x)), names_to_indices)
25+ end
26+
27+ function named_partition_constructor (X:: T ) where {T<: AbstractNamedArrayPartition }
28+ getfield (parentmodule (T), nameof (T))
2329end
2430
2531# Note: overloading `getproperty` means we cannot access `NamedArrayPartition`
2632# fields except through `getfield` and accessor functions.
27- ArrayPartition (x:: NamedArrayPartition ) = getfield (x, :array_partition )
33+ ArrayPartition (x:: AbstractNamedArrayPartition ) = getfield (x, :array_partition )
2834
29- function Base. similar (A:: NamedArrayPartition )
30- NamedArrayPartition (
35+ # With new type structure this function does the same as Base.similar(x::AbstractNamedArrayPartition{T, S, NT}) where {T, S, NT}
36+ #= function Base.similar(A::T) where {T<:AbstractNamedArrayPartition}
37+ Tconstr = named_partition_constructor(A)
38+ Tconstr(
3139 similar(getfield(A, :array_partition)), getfield(A, :names_to_indices))
32- end
40+ end =#
3341
3442# return ArrayPartition when possible, otherwise next best thing of the correct size
35- function Base. similar (A:: NamedArrayPartition , dims:: NTuple{N, Int} ) where {N}
36- NamedArrayPartition (
43+ function Base. similar (A:: T , dims:: NTuple{N, Int} ) where {T<: AbstractNamedArrayPartition , N}
44+ Tconstr = named_partition_constructor (A)
45+ Tconstr (
3746 similar (getfield (A, :array_partition ), dims), getfield (A, :names_to_indices ))
3847end
3948
4049# similar array partition of common type
41- @inline function Base. similar (A:: NamedArrayPartition , :: Type{T} ) where {T}
42- NamedArrayPartition (
50+ @inline function Base. similar (A:: S , :: Type{T} ) where {S<: AbstractNamedArrayPartition , T}
51+ Tconstr = named_partition_constructor (A)
52+ Tconstr (
4353 similar (getfield (A, :array_partition ), T), getfield (A, :names_to_indices ))
4454end
4555
4656# return ArrayPartition when possible, otherwise next best thing of the correct size
47- function Base. similar (A:: NamedArrayPartition , :: Type{T} , dims:: NTuple{N, Int} ) where {T, N}
48- NamedArrayPartition (
57+ function Base. similar (A:: S , :: Type{T} , dims:: NTuple{N, Int} ) where {T, N, S<: AbstractNamedArrayPartition }
58+ Tconstr = named_partition_constructor (A)
59+ Tconstr (
4960 similar (getfield (A, :array_partition ), T, dims), getfield (A, :names_to_indices ))
5061end
5162
5263# similar array partition with different types
5364function Base. similar (
54- A:: NamedArrayPartition , :: Type{T} , :: Type{S} , R:: DataType... ) where {T, S}
55- NamedArrayPartition (
65+ A:: U , :: Type{T} , :: Type{S} , R:: DataType... ) where {T, S, U<: AbstractNamedArrayPartition }
66+ Tconstr = named_partition_constructor (A)
67+ Tconstr (
5668 similar (getfield (A, :array_partition ), T, S, R), getfield (A, :names_to_indices ))
5769end
5870
59- Base. Array (x:: NamedArrayPartition ) = Array (ArrayPartition (x))
71+ Base. Array (x:: AbstractNamedArrayPartition ) = Array (ArrayPartition (x))
6072
61- function Base. zero (x:: NamedArrayPartition{T, S, TN} ) where {T, S, TN }
62- NamedArrayPartition {T, S, TN} (zero (ArrayPartition (x)), getfield (x, :names_to_indices ))
73+ function Base. zero (x:: R ) where {R <: AbstractNamedArrayPartition }
74+ R (zero (ArrayPartition (x)), getfield (x, :names_to_indices ))
6375end
64- Base. zero (A:: NamedArrayPartition , dims:: NTuple{N, Int} ) where {N} = zero (A) # ignore dims since named array partitions are vectors
76+ Base. zero (A:: AbstractNamedArrayPartition , dims:: NTuple{N, Int} ) where {N} = zero (A) # ignore dims since named array partitions are vectors
6577
66- Base. propertynames (x:: NamedArrayPartition ) = propertynames (getfield (x, :names_to_indices ))
67- function Base. getproperty (x:: NamedArrayPartition , s:: Symbol )
78+ Base. propertynames (x:: AbstractNamedArrayPartition ) = propertynames (getfield (x, :names_to_indices ))
79+ function Base. getproperty (x:: AbstractNamedArrayPartition , s:: Symbol )
6880 getindex (ArrayPartition (x). x, getproperty (getfield (x, :names_to_indices ), s))
6981end
7082
7183# this enables x.s = some_array.
72- @inline function Base. setproperty! (x:: NamedArrayPartition , s:: Symbol , v)
84+ @inline function Base. setproperty! (x:: AbstractNamedArrayPartition , s:: Symbol , v)
7385 index = getproperty (getfield (x, :names_to_indices ), s)
7486 ArrayPartition (x). x[index] .= v
7587end
7688
7789# print out NamedArrayPartition as a NamedTuple
78- Base. summary (x:: NamedArrayPartition ) = string (typeof (x), " with arrays:" )
79- function Base. show (io:: IO , m:: MIME"text/plain" , x:: NamedArrayPartition )
90+ Base. summary (x:: AbstractNamedArrayPartition ) = string (typeof (x), " with arrays:" )
91+ function Base. show (io:: IO , m:: MIME"text/plain" , x:: AbstractNamedArrayPartition )
8092 show (
8193 io, m, NamedTuple (Pair .(keys (getfield (x, :names_to_indices )), ArrayPartition (x). x)))
8294end
8395
84- Base. size (x:: NamedArrayPartition ) = size (ArrayPartition (x))
85- Base. length (x:: NamedArrayPartition ) = length (ArrayPartition (x))
86- Base. getindex (x:: NamedArrayPartition , args... ) = getindex (ArrayPartition (x), args... )
96+ Base. size (x:: AbstractNamedArrayPartition ) = size (ArrayPartition (x))
97+ Base. length (x:: AbstractNamedArrayPartition ) = length (ArrayPartition (x))
98+ Base. getindex (x:: AbstractNamedArrayPartition , args... ) = getindex (ArrayPartition (x), args... )
8799
88- Base. setindex! (x:: NamedArrayPartition , args... ) = setindex! (ArrayPartition (x), args... )
89- function Base. map (f, x:: NamedArrayPartition )
90- NamedArrayPartition (map (f, ArrayPartition (x)), getfield (x, :names_to_indices ))
100+ Base. setindex! (x:: AbstractNamedArrayPartition , args... ) = setindex! (ArrayPartition (x), args... )
101+ function Base. map (f, x:: T ) where {T<: AbstractNamedArrayPartition }
102+ Tconstr = named_partition_constructor (x)
103+ Tconstr (map (f, ArrayPartition (x)), getfield (x, :names_to_indices ))
91104end
92- Base. mapreduce (f, op, x:: NamedArrayPartition ) = mapreduce (f, op, ArrayPartition (x))
93- # Base.filter(f, x::NamedArrayPartition ) = filter(f, ArrayPartition(x))
105+ Base. mapreduce (f, op, x:: AbstractNamedArrayPartition ) = mapreduce (f, op, ArrayPartition (x))
106+ # Base.filter(f, x::AbstractNamedArrayPartition ) = filter(f, ArrayPartition(x))
94107
95- function Base. similar (x:: NamedArrayPartition{T, S, NT} ) where {T, S, NT}
96- NamedArrayPartition {T, S, NT} (
97- similar (ArrayPartition (x)), getfield (x, :names_to_indices ))
98- end
108+ function Base. similar (x:: AbstractNamedArrayPartition{T, A, NT} ) where {T, A, NT}
109+ # Safely extract the concrete type parameters
99110
111+ Tconstr = named_partition_constructor (x)
112+ return Tconstr {T, A, NT} (
113+ similar (getfield (x, :array_partition )),
114+ getfield (x, :names_to_indices )
115+ )
116+ end
100117# broadcasting
101- function Base. BroadcastStyle (:: Type{<:NamedArrayPartition} )
102- Broadcast. ArrayStyle {NamedArrayPartition } ()
118+ function Base. BroadcastStyle (:: Type{T} ) where {T <: AbstractNamedArrayPartition }
119+ Broadcast. ArrayStyle {T } ()
103120end
104- function Base. similar (bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition }} ,
105- :: Type{ElType} ) where {ElType}
121+ function Base. similar (bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{T }} ,
122+ :: Type{ElType} ) where {ElType, T <: AbstractNamedArrayPartition }
106123 x = find_NamedArrayPartition (bc)
107- return NamedArrayPartition (similar (ArrayPartition (x)), getfield (x, :names_to_indices ))
124+ Tconstr = named_partition_constructor (x)
125+ return Tconstr (similar (ArrayPartition (x)), getfield (x, :names_to_indices ))
108126end
109127
110128# when broadcasting with ArrayPartition + another array type, the output is the other array tupe
111129function Base. BroadcastStyle (
112- :: Broadcast.ArrayStyle{NamedArrayPartition } , :: Broadcast.DefaultArrayStyle{1} )
130+ :: Broadcast.ArrayStyle{<:AbstractNamedArrayPartition } , :: Broadcast.DefaultArrayStyle{1} )
113131 Broadcast. DefaultArrayStyle {1} ()
114132end
115133
116134# hook into ArrayPartition broadcasting routines
117- @inline RecursiveArrayTools. npartitions (x:: NamedArrayPartition ) = npartitions (ArrayPartition (x))
118- @inline RecursiveArrayTools. unpack (bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition }} , i) = Broadcast. Broadcasted (
135+ @inline RecursiveArrayTools. npartitions (x:: AbstractNamedArrayPartition ) = npartitions (ArrayPartition (x))
136+ @inline RecursiveArrayTools. unpack (bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{<:AbstractNamedArrayPartition }} , i) = Broadcast. Broadcasted (
119137 bc. f, RecursiveArrayTools. unpack_args (i, bc. args))
120- @inline RecursiveArrayTools. unpack (x:: NamedArrayPartition , i) = unpack (ArrayPartition (x), i)
138+ @inline RecursiveArrayTools. unpack (x:: AbstractNamedArrayPartition , i) = unpack (ArrayPartition (x), i)
121139
122- function Base. copy (A:: NamedArrayPartition{T, S, NT} ) where {T, S, NT}
123- NamedArrayPartition {T, S, NT} (copy (ArrayPartition (A)), getfield (A, :names_to_indices ))
140+ function Base. copy (A:: AbstractNamedArrayPartition{T, S, NT} ) where {T, S, NT}
141+ Tconstr = named_partition_constructor (A)
142+ Tconstr {T, S, NT} (copy (ArrayPartition (A)), getfield (A, :names_to_indices ))
124143end
125144
126- @inline NamedArrayPartition ( f:: F , N, names_to_indices) where {F <: Function } = NamedArrayPartition (
145+ @inline ( :: Type{T} )( f:: F , N, names_to_indices) where {F <: Function , T <: AbstractNamedArrayPartition } = T (
127146 ArrayPartition (ntuple (f, Val (N))), names_to_indices)
128147
129- @inline function Base. copy (bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition }} )
148+ @inline function Base. copy (bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{T }} ) where {T <: AbstractNamedArrayPartition }
130149 N = npartitions (bc)
131150 @inline function f (i)
132151 copy (unpack (bc, i))
133152 end
134153 x = find_NamedArrayPartition (bc)
135- NamedArrayPartition (f, N, getfield (x, :names_to_indices ))
154+ Tconstr = named_partition_constructor (x)
155+ Tconstr (f, N, getfield (x, :names_to_indices ))
136156end
137157
138- @inline function Base. copyto! (dest:: NamedArrayPartition ,
139- bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition }} )
158+ @inline function Base. copyto! (dest:: AbstractNamedArrayPartition ,
159+ bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{<:AbstractNamedArrayPartition }} )
140160 N = npartitions (dest, bc)
141161 @inline function f (i)
142162 copyto! (ArrayPartition (dest). x[i], unpack (bc, i))
146166end
147167
148168# Overwrite ArrayInterface zeromatrix to work with NamedArrayPartitions & implicit solvers within OrdinaryDiffEq
149- function ArrayInterface. zeromatrix (A:: NamedArrayPartition )
169+ function ArrayInterface. zeromatrix (A:: AbstractNamedArrayPartition )
150170 B = ArrayPartition (A)
151171 x = reduce (vcat,vec .(B. x))
152172 x .* x' .* false
@@ -159,5 +179,5 @@ function find_NamedArrayPartition(args::Tuple)
159179end
160180find_NamedArrayPartition (x) = x
161181find_NamedArrayPartition (:: Tuple{} ) = nothing
162- find_NamedArrayPartition (x:: NamedArrayPartition , rest) = x
182+ find_NamedArrayPartition (x:: AbstractNamedArrayPartition , rest) = x
163183find_NamedArrayPartition (:: Any , rest) = find_NamedArrayPartition (rest)
0 commit comments