Skip to content

Commit cd6ac32

Browse files
committed
Made BatchedCuArray, SharedCuArray, SharedScalar and BatchedStruct Adapt.jl compatible
1 parent a63c5e1 commit cd6ac32

3 files changed

Lines changed: 88 additions & 37 deletions

File tree

GeneralisedFilters/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ authors = ["THargreaves <tim.hargreaves@icloud.com>", "Charles Knipp <charleskni
66
[deps]
77
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
88
AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
9+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
910
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1011
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1112
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -25,6 +26,7 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
2526
[compat]
2627
AbstractMCMC = "5"
2728
AcceleratedKernels = "0.3, 0.4"
29+
Adapt = "4.5.0"
2830
Aqua = "0.8"
2931
CUDA = "5"
3032
DataStructures = "0.18.20, 0.19"

GeneralisedFilters/src/batching/operations.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,17 @@ import LinearAlgebra: norm
66
# =============================================================================
77

88
# Type aliases for BatchedStruct-wrapped matrices
9-
const BatchedAdjoint{T,M} = BatchedStruct{
10-
Adjoint{T,CuArray{T,2,M}},@NamedTuple{parent::BatchedCuMatrix{T,M}}
9+
const BatchedAdjoint{T,A<:AbstractArray{T,3}} = BatchedStruct{
10+
<:Adjoint{T,<:AbstractArray{T,2}},@NamedTuple{parent::BatchedCuMatrix{T,A}}
1111
}
12-
const BatchedTranspose{T,M} = BatchedStruct{
13-
Transpose{T,CuArray{T,2,M}},@NamedTuple{parent::BatchedCuMatrix{T,M}}
12+
const BatchedTranspose{T,A<:AbstractArray{T,3}} = BatchedStruct{
13+
<:Transpose{T,<:AbstractArray{T,2}},@NamedTuple{parent::BatchedCuMatrix{T,A}}
1414
}
15-
const SharedAdjoint{T,M} = BatchedStruct{
16-
Adjoint{T,CuArray{T,2,M}},@NamedTuple{parent::SharedCuMatrix{T,M}}
15+
const SharedAdjoint{T,A<:AbstractArray{T,2}} = BatchedStruct{
16+
<:Adjoint{T,<:AbstractArray{T,2}},@NamedTuple{parent::SharedCuMatrix{T,A}}
1717
}
18-
const SharedTranspose{T,M} = BatchedStruct{
19-
Transpose{T,CuArray{T,2,M}},@NamedTuple{parent::SharedCuMatrix{T,M}}
18+
const SharedTranspose{T,A<:AbstractArray{T,2}} = BatchedStruct{
19+
<:Transpose{T,<:AbstractArray{T,2}},@NamedTuple{parent::SharedCuMatrix{T,A}}
2020
}
2121

2222
# Union of all GEMM-compatible matrix types

GeneralisedFilters/src/batching/types.jl

Lines changed: 78 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using Adapt
12
using CUDA
23
using LinearAlgebra:
34
Adjoint, Transpose, LowerTriangular, UpperTriangular, UniformScaling, Cholesky
@@ -13,45 +14,50 @@ export BatchedStruct
1314
# =============================================================================
1415

1516
"""
16-
BatchedCuArray{T, NE, NB, NT, M} <: AbstractArray{CuArray{T,NE,M}, NB}
17+
BatchedCuArray{T,NE,NB,NT,A<:AbstractArray{T,NT}} <: AbstractArray{Any,NB}
1718
18-
An `NB`-dimensional batch of `NE`-dimensional CuArrays, stored as a single contiguous
19-
`CuArray{T, NT, M}` where `NT = NE + NB`.
19+
An `NB`-dimensional batch of `NE`-dimensional arrays, stored as a single contiguous
20+
`NT`-dimensional array `data` where `NT = NE + NB`.
2021
2122
- `NE`: number of element dimensions (the "inner" array shape)
2223
- `NB`: number of batch dimensions
2324
- `NT`: total number of dimensions (`NE + NB`); required explicitly because Julia's type
2425
system cannot express arithmetic on type parameters
26+
- `A`: storage array type
2527
2628
The first `NE` dimensions index within each element; the last `NB` dimensions index across
2729
the batch.
2830
31+
This type is generic over the storage array type so that it can participate in `Adapt.jl`
32+
transformations. In the intended user-facing usage, `data` is a `CuArray{T, NT, M}`.
33+
2934
# Common aliases
30-
- `BatchedCuMatrix{T,M}` = `BatchedCuArray{T,2,1,3,M}` — a vector of matrices
31-
- `BatchedCuVector{T,M}` = `BatchedCuArray{T,1,1,2,M}` — a vector of vectors
35+
- `BatchedCuMatrix{T,A}` = `BatchedCuArray{T,2,1,3,A}` — a vector of matrices
36+
- `BatchedCuVector{T,A}` = `BatchedCuArray{T,1,1,2,A}` — a vector of vectors
3237
"""
33-
struct BatchedCuArray{T,NE,NB,NT,M} <: AbstractArray{CuArray{T,NE,M},NB}
34-
data::CuArray{T,NT,M}
38+
struct BatchedCuArray{T,NE,NB,NT,A<:AbstractArray{T,NT}} <: AbstractArray{Any,NB}
39+
data::A
3540

36-
function BatchedCuArray{T,NE,NB,NT,M}(data::CuArray{T,NT,M}) where {T,NE,NB,NT,M}
41+
function BatchedCuArray{T,NE,NB,NT,A}(data::A) where {T,NE,NB,NT,A<:AbstractArray{T,NT}}
3742
NE + NB == NT || error("NE ($NE) + NB ($NB) must equal ndims(data) ($NT)")
38-
return new{T,NE,NB,NT,M}(data)
43+
return new{T,NE,NB,NT,A}(data)
3944
end
4045
end
4146

4247
# Convenience constructor: infer NT and the storage type A from `data`; T, NE and NB are given explicitly
43-
function BatchedCuArray{T,NE,NB}(data::CuArray{T,NT,M}) where {T,NE,NB,NT,M}
48+
function BatchedCuArray{T,NE,NB}(data::A) where {T,NE,NB,A<:AbstractArray{T}}
49+
NT = ndims(data)
4450
NE + NB == NT || error("NE ($NE) + NB ($NB) must equal ndims(data) ($NT)")
45-
return BatchedCuArray{T,NE,NB,NT,M}(data)
51+
return BatchedCuArray{T,NE,NB,NT,A}(data)
4652
end
4753

4854
# Common case aliases
49-
const BatchedCuMatrix{T,M} = BatchedCuArray{T,2,1,3,M}
50-
const BatchedCuVector{T,M} = BatchedCuArray{T,1,1,2,M}
55+
const BatchedCuMatrix{T,A<:AbstractArray{T,3}} = BatchedCuArray{T,2,1,3,A}
56+
const BatchedCuVector{T,A<:AbstractArray{T,2}} = BatchedCuArray{T,1,1,2,A}
5157

5258
# Constructors for aliased cases
53-
BatchedCuMatrix(data::CuArray{T,3,M}) where {T,M} = BatchedCuArray{T,2,1,3,M}(data)
54-
BatchedCuVector(data::CuArray{T,2,M}) where {T,M} = BatchedCuArray{T,1,1,2,M}(data)
59+
BatchedCuMatrix(data::A) where {T,A<:AbstractArray{T,3}} = BatchedCuArray{T,2,1,3,A}(data)
60+
BatchedCuVector(data::A) where {T,A<:AbstractArray{T,2}} = BatchedCuArray{T,1,1,2,A}(data)
5561

5662
const BatchedArray = BatchedCuArray
5763

@@ -71,47 +77,69 @@ end
7177

7278
batch_size(x::BatchedCuArray) = length(x)
7379

80+
# Adapt.jl support: adapt the underlying storage and rebuild the BatchedCuArray around it (e.g. to obtain an isbits type)
81+
function Adapt.adapt_structure(
82+
to,
83+
x::BatchedCuArray{T,NE,NB,NT,A},
84+
) where {T,NE,NB,NT,A}
85+
data_adapted = Adapt.adapt(to, x.data)
86+
return BatchedCuArray{T,NE,NB,NT,typeof(data_adapted)}(data_adapted)
87+
end
88+
7489
# =============================================================================
7590
# Shared Types (same data reused across all batch elements)
7691
# =============================================================================
7792

7893
"""
79-
SharedCuArray{T, InnerN, BatchN, M} <: AbstractArray{CuArray{T,InnerN,M}, BatchN}
94+
SharedCuArray{T,InnerN,BatchN,A<:AbstractArray{T,InnerN}} <: AbstractArray{Any,BatchN}
8095
81-
A batch of CuArrays where every element is the same underlying `CuArray{T,InnerN,M}`.
96+
A batch of arrays where every element is the same underlying array.
8297
Unlike `Ref(array)`, this type carries an explicit batch size and satisfies the
8398
`AbstractArray` contract honestly.
8499
85100
Use `Ref(array)` when the batch size is unknown or irrelevant (e.g. during broadcast
86101
setup). Use `SharedCuArray` when you need a proper `AbstractArray` with a known size.
87102
103+
This type is generic over the storage array type so that it can participate in `Adapt.jl`
104+
transformations. In the intended user-facing usage, `data` is a `CuArray{T,InnerN,M}`.
105+
88106
# Common aliases
89-
- `SharedCuMatrix{T,M}` = `SharedCuArray{T,2,1,M}`
90-
- `SharedCuVector{T,M}` = `SharedCuArray{T,1,1,M}`
107+
- `SharedCuMatrix{T,A}` = `SharedCuArray{T,2,1,A}`
108+
- `SharedCuVector{T,A}` = `SharedCuArray{T,1,1,A}`
91109
"""
92-
struct SharedCuArray{T,InnerN,BatchN,M} <: AbstractArray{CuArray{T,InnerN,M},BatchN}
93-
data::CuArray{T,InnerN,M}
110+
struct SharedCuArray{T,InnerN,BatchN,A<:AbstractArray{T,InnerN}} <: AbstractArray{Any,BatchN}
111+
data::A
94112
batchsize::NTuple{BatchN,Int}
95113
end
96114

97115
# Outer constructor: accept a plain Int for the common 1D-batch case
98-
function SharedCuArray{T,InnerN,1,M}(data::CuArray{T,InnerN,M}, N::Int) where {T,InnerN,M}
99-
return SharedCuArray{T,InnerN,1,M}(data, (N,))
116+
function SharedCuArray{T,InnerN,1,A}(data::A, N::Int) where {T,InnerN,A<:AbstractArray{T,InnerN}}
117+
return SharedCuArray{T,InnerN,1,A}(data, (N,))
100118
end
101119

102-
const SharedCuMatrix{T,M} = SharedCuArray{T,2,1,M}
103-
const SharedCuVector{T,M} = SharedCuArray{T,1,1,M}
120+
const SharedCuMatrix{T,A<:AbstractArray{T,2}} = SharedCuArray{T,2,1,A}
121+
const SharedCuVector{T,A<:AbstractArray{T,1}} = SharedCuArray{T,1,1,A}
122+
123+
# Constructors for aliased cases
124+
SharedCuMatrix(data::A, N::Int) where {T,A<:AbstractArray{T,2}} = SharedCuArray{T,2,1,A}(data, N)
125+
SharedCuVector(data::A, N::Int) where {T,A<:AbstractArray{T,1}} = SharedCuArray{T,1,1,A}(data, N)
104126

105127
const SharedArray = SharedCuArray
106128

129+
Base.eltype(::Type{<:BatchedCuArray{T,NE}}) where {T,NE} = AbstractArray{T,NE}
130+
Base.eltype(::Type{<:SharedCuArray{T,InnerN}}) where {T,InnerN} = AbstractArray{T,InnerN}
131+
107132
"""
108-
Shared(data::CuArray, N::Int) -> SharedCuArray
133+
Shared(data::AbstractArray, N::Int) -> SharedCuArray
109134
110-
Convenience constructor: create a `SharedCuArray` from a CuArray with an explicit
135+
Convenience constructor: create a `SharedCuArray` from an array with an explicit
111136
1D batch size `N`.
137+
138+
The underlying storage is generic to support `Adapt.jl` transformations, but in
139+
the intended user-facing interface `A` is a `CuArray`.
112140
"""
113-
Shared(x::CuArray{T,2,M}, N::Int) where {T,M} = SharedCuArray{T,2,1,M}(x, (N,))
114-
Shared(x::CuArray{T,1,M}, N::Int) where {T,M} = SharedCuArray{T,1,1,M}(x, (N,))
141+
Shared(x::A, N::Int) where {T,A<:AbstractArray{T,2}} = SharedCuArray{T,2,1,A}(x, (N,))
142+
Shared(x::A, N::Int) where {T,A<:AbstractArray{T,1}} = SharedCuArray{T,1,1,A}(x, (N,))
115143

116144
Base.IndexStyle(::Type{<:SharedCuArray}) = Base.IndexCartesian()
117145

@@ -129,6 +157,15 @@ end
129157

130158
batch_size(x::SharedCuArray) = length(x)
131159

160+
# Adapt.jl support: adapt the underlying storage and rebuild the SharedCuArray with the same batch size (e.g. to obtain an isbits type)
161+
function Adapt.adapt_structure(
162+
to,
163+
x::SharedCuArray{T,InnerN,BatchN,A},
164+
) where {T,InnerN,BatchN,A<:AbstractArray{T,InnerN}}
165+
data_adapted = Adapt.adapt(to, x.data)
166+
return SharedCuArray{T,InnerN,BatchN,typeof(data_adapted)}(data_adapted, x.batchsize)
167+
end
168+
132169
# =============================================================================
133170
# SharedScalar: a scalar value shared across all batch elements
134171
# =============================================================================
@@ -149,6 +186,9 @@ Base.:(==)(x::SharedScalar, y) = x.value == y
149186
Base.:(==)(x, y::SharedScalar) = x == y.value
150187
Base.:(==)(x::SharedScalar, y::SharedScalar) = x.value == y.value
151188

189+
# Adapt.jl support: adapt the fields of SharedScalar (e.g. to obtain an isbits type)
190+
Adapt.@adapt_structure SharedScalar
191+
152192
# =============================================================================
153193
# BatchedStruct - Custom wrapper for batched composite types
154194
# =============================================================================
@@ -271,6 +311,15 @@ function Base.show(io::IO, ::MIME"text/plain", x::BatchedStruct{T}) where {T}
271311
end
272312
end
273313

314+
# Adapt.jl support: adapt the components and rebuild the BatchedStruct with the same `T` (e.g. to obtain an isbits type)
315+
function Adapt.adapt_structure(
316+
to,
317+
x::BatchedStruct{T,C},
318+
) where {T,C<:NamedTuple}
319+
comps_adapted = Adapt.adapt(to, x.components)
320+
return BatchedStruct{T}(comps_adapted)
321+
end
322+
274323
# =============================================================================
275324
# Union Types for Dispatch
276325
# =============================================================================

0 commit comments

Comments
 (0)