1+ using Adapt
12using CUDA
23using LinearAlgebra:
34 Adjoint, Transpose, LowerTriangular, UpperTriangular, UniformScaling, Cholesky
@@ -13,45 +14,50 @@ export BatchedStruct
1314# =============================================================================
1415
1516"""
16- BatchedCuArray{T, NE, NB, NT, M} <: AbstractArray{CuArray{T,NE,M}, NB}
17+ BatchedCuArray{T,NE,NB,NT,A<: AbstractArray{T,NT}} <: AbstractArray{Any, NB}
1718
18- An `NB`-dimensional batch of `NE`-dimensional CuArrays , stored as a single contiguous
19- `CuArray{T, NT, M} ` where `NT = NE + NB`.
19+ An `NB`-dimensional batch of `NE`-dimensional arrays, stored as a single contiguous
20+ `NT`-dimensional array `data` where `NT = NE + NB`.
2021
2122- `NE`: number of element dimensions (the "inner" array shape)
2223- `NB`: number of batch dimensions
2324- `NT`: total number of dimensions (`NE + NB`); required explicitly because Julia's type
2425 system cannot express arithmetic on type parameters
26+ - `A`: storage array type
2527
2628The first `NE` dimensions index within each element; the last `NB` dimensions index across
2729the batch.
2830
31+ This type is generic over the storage array type so that it can participate in `Adapt.jl`
32+ transformations. In the user-facing intended usage, `data` is a `CuArray{T, NT, M}`.
33+
2934# Common aliases
30- - `BatchedCuMatrix{T,M }` = `BatchedCuArray{T,2,1,3,M }` — a vector of matrices
31- - `BatchedCuVector{T,M }` = `BatchedCuArray{T,1,1,2,M }` — a vector of vectors
35+ - `BatchedCuMatrix{T,A}` = `BatchedCuArray{T,2,1,3,A}` — a vector of matrices
36+ - `BatchedCuVector{T,A}` = `BatchedCuArray{T,1,1,2,A}` — a vector of vectors
3237"""
33- struct BatchedCuArray{T,NE,NB,NT,M} <: AbstractArray{CuArray{T,NE,M} ,NB}
34- data:: CuArray{T,NT,M}
38+ struct BatchedCuArray{T,NE,NB,NT,A <: AbstractArray{T,NT} } <: AbstractArray{Any ,NB}
39+ data:: A
3540
36- function BatchedCuArray {T,NE,NB,NT,M } (data:: CuArray{T,NT,M} ) where {T,NE,NB,NT,M }
41+ function BatchedCuArray {T,NE,NB,NT,A } (data:: A ) where {T,NE,NB,NT,A <: AbstractArray{T,NT} }
3742 NE + NB == NT || error (" NE ($NE ) + NB ($NB ) must equal ndims(data) ($NT )" )
38- return new {T,NE,NB,NT,M } (data)
43+ return new {T,NE,NB,NT,A } (data)
3944 end
4045end
4146
4247# Convenience constructor: infer T and the storage type, require explicit NE and NB
43- function BatchedCuArray {T,NE,NB} (data:: CuArray{T,NT,M} ) where {T,NE,NB,NT,M}
48+ function BatchedCuArray {T,NE,NB} (data:: A ) where {T,NE,NB,A<: AbstractArray{T} }
49+ NT = ndims (data)
4450 NE + NB == NT || error (" NE ($NE ) + NB ($NB ) must equal ndims(data) ($NT )" )
45- return BatchedCuArray {T,NE,NB,NT,M } (data)
51+ return BatchedCuArray {T,NE,NB,NT,A } (data)
4652end
4753
4854# Common case aliases
49- const BatchedCuMatrix{T,M} = BatchedCuArray{T,2 ,1 ,3 ,M }
50- const BatchedCuVector{T,M} = BatchedCuArray{T,1 ,1 ,2 ,M }
55+ const BatchedCuMatrix{T,A <: AbstractArray{T,3} } = BatchedCuArray{T,2 ,1 ,3 ,A }
56+ const BatchedCuVector{T,A <: AbstractArray{T,2} } = BatchedCuArray{T,1 ,1 ,2 ,A }
5157
5258# Constructors for aliased cases
53- BatchedCuMatrix (data:: CuArray{T,3,M} ) where {T,M} = BatchedCuArray {T,2,1,3,M } (data)
54- BatchedCuVector (data:: CuArray{T,2,M} ) where {T,M} = BatchedCuArray {T,1,1,2,M } (data)
59+ BatchedCuMatrix (data:: A ) where {T,A <: AbstractArray{T,3} } = BatchedCuArray {T,2,1,3,A } (data)
60+ BatchedCuVector (data:: A ) where {T,A <: AbstractArray{T,2} } = BatchedCuArray {T,1,1,2,A } (data)
5561
5662const BatchedArray = BatchedCuArray
5763
7177
7278batch_size (x:: BatchedCuArray ) = length (x)
7379
80+ # Adapt.jl support: adapt the wrapped `data` and rebuild the BatchedCuArray around it
81+ function Adapt. adapt_structure (
82+ to,
83+ x:: BatchedCuArray{T,NE,NB,NT,A} ,
84+ ) where {T,NE,NB,NT,A}
85+ data_adapted = Adapt. adapt (to, x. data)
86+ return BatchedCuArray {T,NE,NB,NT,typeof(data_adapted)} (data_adapted)
87+ end
88+
7489# =============================================================================
7590# Shared Types (same data reused across all batch elements)
7691# =============================================================================
7792
7893"""
79- SharedCuArray{T, InnerN, BatchN, M} <: AbstractArray{CuArray{ T,InnerN,M}, BatchN}
94+ SharedCuArray{T,InnerN,BatchN,A<: AbstractArray{T,InnerN}} <: AbstractArray{Any, BatchN}
8095
81- A batch of CuArrays where every element is the same underlying `CuArray{T,InnerN,M}` .
96+ A batch of arrays where every element is the same underlying array.
8297Unlike `Ref(array)`, this type carries an explicit batch size and satisfies the
8398`AbstractArray` contract honestly.
8499
85100Use `Ref(array)` when the batch size is unknown or irrelevant (e.g. during broadcast
86101setup). Use `SharedCuArray` when you need a proper `AbstractArray` with a known size.
87102
103+ This type is generic over the storage array type so that it can participate in `Adapt.jl`
104+ transformations. In the user-facing intended usage, `data` is a `CuArray{T,InnerN,M}`.
105+
88106# Common aliases
89- - `SharedCuMatrix{T,M }` = `SharedCuArray{T,2,1,M }`
90- - `SharedCuVector{T,M }` = `SharedCuArray{T,1,1,M }`
107+ - `SharedCuMatrix{T,A}` = `SharedCuArray{T,2,1,A}`
108+ - `SharedCuVector{T,A}` = `SharedCuArray{T,1,1,A}`
91109"""
92- struct SharedCuArray{T,InnerN,BatchN,M} <: AbstractArray{CuArray{ T,InnerN,M} ,BatchN}
93- data:: CuArray{T,InnerN,M}
110+ struct SharedCuArray{T,InnerN,BatchN,A <: AbstractArray{T,InnerN} } <: AbstractArray{Any ,BatchN}
111+ data:: A
94112 batchsize:: NTuple{BatchN,Int}
95113end
96114
97115# Outer constructor: accept a plain Int for the common 1D-batch case
98- function SharedCuArray {T,InnerN,1,M } (data:: CuArray{T,InnerN,M} , N:: Int ) where {T,InnerN,M }
99- return SharedCuArray {T,InnerN,1,M } (data, (N,))
116+ function SharedCuArray {T,InnerN,1,A } (data:: A , N:: Int ) where {T,InnerN,A <: AbstractArray{T,InnerN} }
117+ return SharedCuArray {T,InnerN,1,A } (data, (N,))
100118end
101119
102- const SharedCuMatrix{T,M} = SharedCuArray{T,2 ,1 ,M}
103- const SharedCuVector{T,M} = SharedCuArray{T,1 ,1 ,M}
120+ const SharedCuMatrix{T,A<: AbstractArray{T,2} } = SharedCuArray{T,2 ,1 ,A}
121+ const SharedCuVector{T,A<: AbstractArray{T,1} } = SharedCuArray{T,1 ,1 ,A}
122+
123+ # Constructors for aliased cases
124+ SharedCuMatrix (data:: A , N:: Int ) where {T,A<: AbstractArray{T,2} } = SharedCuArray {T,2,1,A} (data, N)
125+ SharedCuVector (data:: A , N:: Int ) where {T,A<: AbstractArray{T,1} } = SharedCuArray {T,1,1,A} (data, N)
104126
105127const SharedArray = SharedCuArray
106128
129+ Base. eltype (:: Type{<:BatchedCuArray{T,NE}} ) where {T,NE} = AbstractArray{T,NE}
130+ Base. eltype (:: Type{<:SharedCuArray{T,InnerN}} ) where {T,InnerN} = AbstractArray{T,InnerN}
131+
107132"""
108- Shared(data::CuArray , N::Int) -> SharedCuArray
133+ Shared(data::AbstractArray , N::Int) -> SharedCuArray
109134
110- Convenience constructor: create a `SharedCuArray` from a CuArray with an explicit
135+ Convenience constructor: create a `SharedCuArray` from an array with an explicit
1111361D batch size `N`.
137+
138+ The underlying storage is generic to support `Adapt.jl` transformations, but in
139+ the intended user-facing interface `data` is a `CuArray`.
112140"""
113- Shared (x:: CuArray{T,2,M} , N:: Int ) where {T,M} = SharedCuArray {T,2,1,M } (x, (N,))
114- Shared (x:: CuArray{T,1,M} , N:: Int ) where {T,M} = SharedCuArray {T,1,1,M } (x, (N,))
141+ Shared (x:: A , N:: Int ) where {T,A <: AbstractArray{T,2} } = SharedCuArray {T,2,1,A } (x, (N,))
142+ Shared (x:: A , N:: Int ) where {T,A <: AbstractArray{T,1} } = SharedCuArray {T,1,1,A } (x, (N,))
115143
116144Base. IndexStyle (:: Type{<:SharedCuArray} ) = Base. IndexCartesian ()
117145
129157
130158batch_size (x:: SharedCuArray ) = length (x)
131159
160+ # Adapt.jl support: adapt the wrapped `data` and rebuild the SharedCuArray, keeping `batchsize`
161+ function Adapt. adapt_structure (
162+ to,
163+ x:: SharedCuArray{T,InnerN,BatchN,A} ,
164+ ) where {T,InnerN,BatchN,A<: AbstractArray{T,InnerN} }
165+ data_adapted = Adapt. adapt (to, x. data)
166+ return SharedCuArray {T,InnerN,BatchN,typeof(data_adapted)} (data_adapted, x. batchsize)
167+ end
168+
132169# =============================================================================
133170# SharedScalar: a scalar value shared across all batch elements
134171# =============================================================================
@@ -149,6 +186,9 @@ Base.:(==)(x::SharedScalar, y) = x.value == y
149186Base.:(== )(x, y:: SharedScalar ) = x == y. value
150187Base.:(== )(x:: SharedScalar , y:: SharedScalar ) = x. value == y. value
151188
189+ # Adapt.jl support: adapt SharedScalar field by field via `Adapt.@adapt_structure`
190+ Adapt. @adapt_structure SharedScalar
191+
152192# =============================================================================
153193# BatchedStruct - Custom wrapper for batched composite types
154194# =============================================================================
@@ -271,6 +311,15 @@ function Base.show(io::IO, ::MIME"text/plain", x::BatchedStruct{T}) where {T}
271311 end
272312end
273313
314+ # Adapt.jl support: adapt the `components` NamedTuple and rebuild the BatchedStruct
315+ function Adapt. adapt_structure (
316+ to,
317+ x:: BatchedStruct{T,C} ,
318+ ) where {T,C<: NamedTuple }
319+ comps_adapted = Adapt. adapt (to, x. components)
320+ return BatchedStruct {T} (comps_adapted)
321+ end
322+
274323# =============================================================================
275324# Union Types for Dispatch
276325# =============================================================================
0 commit comments