Skip to content

Commit 8b5c42c

Browse files
committed
More functionality and tests
1 parent aed6887 commit 8b5c42c

6 files changed

Lines changed: 335 additions & 9 deletions

File tree

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,25 @@
11
module SerializedArraysLinearAlgebraExt
22

33
using LinearAlgebra: LinearAlgebra, mul!
4-
using SerializedArrays: SerializedArray
4+
using SerializedArrays: AbstractSerializedMatrix
55

66
function LinearAlgebra.mul!(
7-
a_dest::AbstractMatrix, a1::SerializedArray, a2::SerializedArray, α::Number, β::Number
7+
a_dest::AbstractMatrix,
8+
a1::AbstractSerializedMatrix,
9+
a2::AbstractSerializedMatrix,
10+
α::Number,
11+
β::Number,
812
)
913
mul!(a_dest, copy(a1), copy(a2), α, β)
1014
return a_dest
1115
end
1216

17+
for f in [:eigen, :qr, :svd]
18+
@eval begin
19+
function LinearAlgebra.$f(a::AbstractSerializedMatrix; kwargs...)
20+
return LinearAlgebra.$f(copy(a))
21+
end
22+
end
23+
end
24+
1325
end

src/SerializedArrays.jl

Lines changed: 195 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,69 @@
11
module SerializedArrays
22

3+
using Base.PermutedDimsArrays: genperm
34
using ConstructionBase: constructorof
4-
using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked
5+
using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked, readblock!, writeblock!
56
using Serialization: deserialize, serialize
67

7-
struct SerializedArray{T,N,A<:AbstractArray{T,N},Axes} <: AbstractDiskArray{T,N}
8+
abstract type AbstractSerializedArray{T,N} <: AbstractDiskArray{T,N} end
9+
const AbstractSerializedMatrix{T} = AbstractSerializedArray{T,2}
10+
const AbstractSerializedVector{T} = AbstractSerializedArray{T,1}
11+
12+
function _copyto_write!(dst, src)
13+
writeblock!(dst, src, axes(src)...)
14+
return dst
15+
end
16+
function _copyto_read!(dst, src)
17+
readblock!(src, dst, axes(src)...)
18+
return dst
19+
end
20+
21+
function Base.copyto!(dst::AbstractSerializedArray, src::AbstractArray)
22+
return _copyto_write!(dst, src)
23+
end
24+
function Base.copyto!(dst::AbstractArray, src::AbstractSerializedArray)
25+
return _copyto_read!(dst, src)
26+
end
27+
# Fix ambiguity error.
28+
function Base.copyto!(dst::AbstractSerializedArray, src::AbstractSerializedArray)
29+
return copyto!(dst, copy(src))
30+
end
31+
# Fix ambiguity error.
32+
function Base.copyto!(dst::AbstractDiskArray, src::AbstractSerializedArray)
33+
return copyto!(dst, copy(src))
34+
end
35+
# Fix ambiguity error.
36+
function Base.copyto!(dst::AbstractSerializedArray, src::AbstractDiskArray)
37+
return _copyto_write!(dst, src)
38+
end
39+
# Fix ambiguity error.
40+
function Base.copyto!(dst::PermutedDimsArray, src::AbstractSerializedArray)
41+
return _copyto_read!(dst, src)
42+
end
43+
44+
function Base.:(==)(a1::AbstractSerializedArray, a2::AbstractSerializedArray)
45+
return copy(a1) == copy(a2)
46+
end
47+
function Base.:(==)(a1::AbstractArray, a2::AbstractSerializedArray)
48+
return a1 == copy(a2)
49+
end
50+
function Base.:(==)(a1::AbstractSerializedArray, a2::AbstractArray)
51+
return copy(a1) == a2
52+
end
53+
54+
# # These cause too many ambiguity errors, try bringing them back.
55+
# function Base.convert(arrayt::Type{<:AbstractSerializedArray}, a::AbstractArray)
56+
# return arrayt(a)
57+
# end
58+
# function Base.convert(arrayt::Type{<:AbstractArray}, a::AbstractSerializedArray)
59+
# return convert(arrayt, copy(a))
60+
# end
61+
# # Fixes ambiguity error.
62+
# function Base.convert(arrayt::Type{<:Array}, a::AbstractSerializedArray)
63+
# return convert(arrayt, copy(a))
64+
# end
65+
66+
struct SerializedArray{T,N,A<:AbstractArray{T,N},Axes} <: AbstractSerializedArray{T,N}
867
file::String
968
axes::Axes
1069
end
@@ -13,7 +72,7 @@ Base.axes(a::SerializedArray) = getfield(a, :axes)
1372
arraytype(a::SerializedArray{<:Any,<:Any,A}) where {A} = A
1473

1574
function SerializedArray(file::String, a::AbstractArray)
16-
serialize(file, vec(a))
75+
serialize(file, a)
1776
ax = axes(a)
1877
return SerializedArray{eltype(a),ndims(a),typeof(a),typeof(ax)}(file, ax)
1978
end
@@ -29,13 +88,18 @@ function Base.similar(a::SerializedArray, elt::Type, dims::Tuple{Vararg{Int}})
2988
return constructorof(arraytype(a)){elt}(undef, dims...)
3089
end
3190

91+
function materialize(a::SerializedArray)
92+
return deserialize(file(a))::arraytype(a)
93+
end
3294
function Base.copy(a::SerializedArray)
33-
arrayt = arraytype(a)
34-
return convert(arrayt, reshape(deserialize(file(a)), axes(a)))::arrayt
95+
return materialize(a)
3596
end
3697

3798
Base.size(a::SerializedArray) = length.(axes(a))
3899

100+
to_axis(r::AbstractUnitRange) = r
101+
to_axis(d::Integer) = Base.OneTo(d)
102+
39103
#
40104
# DiskArrays
41105
#
@@ -67,6 +131,131 @@ function DiskArrays.create_outputarray(::Nothing, a::SerializedArray, output_siz
67131
return similar(a, output_size)
68132
end
69133

134+
struct PermutedSerializedArray{T,N,P<:PermutedDimsArray{T,N}} <:
135+
AbstractSerializedArray{T,N}
136+
permuted_parent::P
137+
end
138+
Base.parent(a::PermutedSerializedArray) = parent(getfield(a, :permuted_parent))
139+
140+
perm(a::PermutedSerializedArray) = perm(a.permuted_parent)
141+
perm(::PermutedDimsArray{<:Any,<:Any,p}) where {p} = p
142+
143+
iperm(a::PermutedSerializedArray) = iperm(a.permuted_parent)
144+
iperm(::PermutedDimsArray{<:Any,<:Any,<:Any,ip}) where {ip} = ip
145+
146+
Base.axes(a::PermutedSerializedArray) = genperm(axes(parent(a)), perm(a))
147+
Base.size(a::PermutedSerializedArray) = length.(axes(a))
148+
149+
function PermutedSerializedArray(a::AbstractArray, perm)
150+
a′ = PermutedDimsArray(a, perm)
151+
return PermutedSerializedArray{eltype(a),ndims(a),typeof(a′)}(a′)
152+
end
153+
154+
function Base.permutedims(a::AbstractSerializedArray, perm)
155+
return PermutedSerializedArray(a, perm)
156+
end
157+
158+
function Base.similar(a::PermutedSerializedArray, elt::Type, dims::Tuple{Vararg{Int}})
159+
return similar(parent(a), elt, dims)
160+
end
161+
162+
function materialize(a::PermutedSerializedArray)
163+
return PermutedDimsArray(copy(parent(a)), perm(a))
164+
end
165+
function Base.copy(a::PermutedSerializedArray)
166+
return copy(materialize(a))
167+
end
168+
169+
haschunks(a::PermutedSerializedArray) = Unchunked()
170+
function DiskArrays.readblock!(a::PermutedSerializedArray, aout, i::OrdinalRange...)
171+
ip = iperm(a)
172+
# Permute the indices
173+
inew = genperm(i, ip)
174+
# Permute the dest block and read from the true parent
175+
DiskArrays.readblock!(parent(a), PermutedDimsArray(aout, ip), inew...)
176+
return nothing
177+
end
178+
function DiskArrays.writeblock!(a::PermutedSerializedArray, v, i::OrdinalRange...)
179+
ip = iperm(a)
180+
inew = genperm(i, ip)
181+
# Permute the dest block and write from the true parent
182+
DiskArrays.writeblock!(parent(a), PermutedDimsArray(v, ip), inew...)
183+
return nothing
184+
end
185+
186+
struct ReshapedSerializedArray{T,N,P<:AbstractArray{T},Axes} <: AbstractSerializedArray{T,N}
187+
parent::P
188+
axes::Axes
189+
end
190+
Base.parent(a::ReshapedSerializedArray) = getfield(a, :parent)
191+
Base.axes(a::ReshapedSerializedArray) = getfield(a, :axes)
192+
193+
function ReshapedSerializedArray(
194+
a::AbstractSerializedArray,
195+
ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
196+
)
197+
return ReshapedSerializedArray{eltype(a),length(ax),typeof(a),typeof(ax)}(a, ax)
198+
end
199+
function ReshapedSerializedArray(
200+
a::AbstractSerializedArray,
201+
shape::Tuple{
202+
Union{Integer,AbstractUnitRange{<:Integer}},
203+
Vararg{Union{Integer,AbstractUnitRange{<:Integer}}},
204+
},
205+
)
206+
return ReshapedSerializedArray(a, to_axis.(shape))
207+
end
208+
209+
Base.size(a::ReshapedSerializedArray) = length.(axes(a))
210+
211+
function Base.similar(a::ReshapedSerializedArray, elt::Type, dims::Tuple{Vararg{Int}})
212+
return similar(parent(a), elt, dims)
213+
end
214+
215+
function materialize(a::ReshapedSerializedArray)
216+
return reshape(materialize(parent(a)), axes(a))
217+
end
218+
function Base.copy(a::ReshapedSerializedArray)
219+
a′ = materialize(a)
220+
return a′ isa Base.ReshapedArray ? copy(a′) : a′
221+
end
222+
223+
# Special case for handling nested wrappers that aren't
224+
# friendly on GPU. Consider special cases of strded arrays
225+
# and handle with stride manipulations.
226+
function Base.copy(a::ReshapedSerializedArray{<:Any,<:Any,<:PermutedSerializedArray})
227+
a′ = reshape(copy(parent(a)), axes(a))
228+
return a′ isa Base.ReshapedArray ? copy(a′) : a′
229+
end
230+
231+
function Base.reshape(a::AbstractSerializedArray, dims::Tuple{Int,Vararg{Int}})
232+
return ReshapedSerializedArray(a, dims)
233+
end
234+
235+
DiskArrays.haschunks(a::ReshapedSerializedArray) = Unchunked()
236+
function DiskArrays.readblock!(
237+
a::ReshapedSerializedArray{<:Any,N}, aout, i::Vararg{AbstractUnitRange,N}
238+
) where {N}
239+
if i == axes(a)
240+
aout .= copy(a)
241+
return a
242+
end
243+
aout .= @view copy(a)[i...]
244+
return nothing
245+
end
246+
function DiskArrays.writeblock!(
247+
a::ReshapedSerializedArray{<:Any,N}, ain, i::Vararg{AbstractUnitRange,N}
248+
) where {N}
249+
if i == axes(a)
250+
serialize(file(a), ain)
251+
return a
252+
end
253+
a′ = copy(a)
254+
a′[i...] = ain
255+
serialize(file(a), a′)
256+
return nothing
257+
end
258+
70259
#
71260
# Broadcast
72261
#
@@ -89,7 +278,7 @@ function Base.BroadcastStyle(::DefaultArrayStyle{M}, ::SerializedArrayStyle{N})
89278
end
90279

91280
struct BroadcastSerializedArray{T,N,BC<:Broadcasted{<:SerializedArrayStyle{N}}} <:
92-
AbstractDiskArray{T,N}
281+
AbstractSerializedArray{T,N}
93282
broadcasted::BC
94283
end
95284
function BroadcastSerializedArray(

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3+
DiskArrays = "3c3547ce-8d99-4f5e-a174-61eb10b00ae3"
34
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
45
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
56
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"

test/test_adaptext.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using Adapt: adapt
2+
using JLArrays: JLArray
3+
using SerializedArrays: SerializedArray
4+
using StableRNGs: StableRNG
5+
using Test: @test, @testset
6+
using TestExtras: @constinferred
7+
8+
elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
9+
arrayts = (Array, JLArray)
10+
@testset "SerializedArraysAdaptExt (eltype=$elt, arraytype=$arrayt)" for elt in elts,
11+
arrayt in arrayts
12+
13+
rng = StableRNG(123)
14+
x = arrayt(randn(rng, elt, 4, 4))
15+
y = PermutedDimsArray(x, (2, 1))
16+
a = adapt(SerializedArray, x)
17+
@test a isa SerializedArray{elt,2,arrayt{elt,2}}
18+
b = adapt(SerializedArray, y)
19+
@test b isa
20+
PermutedDimsArray{elt,2,(2, 1),(2, 1),<:SerializedArray{elt,2,<:arrayt{elt,2}}}
21+
@test parent(b) == a
22+
end

test/test_basics.jl

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using GPUArraysCore: @allowscalar
22
using JLArrays: JLArray
3-
using SerializedArrays: SerializedArray
3+
using SerializedArrays: PermutedSerializedArray, ReshapedSerializedArray, SerializedArray
44
using StableRNGs: StableRNG
55
using Test: @test, @testset
66
using TestExtras: @constinferred
@@ -36,4 +36,64 @@ arrayts = (Array, JLArray)
3636
c = @constinferred(a * b)
3737
@test c == x * y
3838
@test c isa arrayt{elt,2}
39+
40+
rng = StableRNG(123)
41+
x = arrayt(randn(rng, elt, 4, 4))
42+
a = SerializedArray(x)
43+
b = similar(a)
44+
@test b isa arrayt{elt,2}
45+
@test size(b) == size(a) == size(x)
46+
47+
rng = StableRNG(123)
48+
x = arrayt(randn(rng, elt, 4, 4))
49+
a = permutedims(SerializedArray(x), (2, 1))
50+
@test a isa PermutedSerializedArray{elt,2}
51+
@test similar(a) isa arrayt{elt,2}
52+
@test copy(a) == permutedims(x, (2, 1))
53+
54+
rng = StableRNG(123)
55+
x = arrayt(randn(rng, elt, 4, 4))
56+
a = reshape(SerializedArray(x), 16)
57+
@test a isa ReshapedSerializedArray{elt,1}
58+
@test similar(a) isa arrayt{elt,1}
59+
@test copy(a) == reshape(x, 16)
60+
61+
rng = StableRNG(123)
62+
x = arrayt(randn(rng, elt, 4, 4))
63+
a = reshape(permutedims(SerializedArray(x), (2, 1)), 16)
64+
@test a isa ReshapedSerializedArray{elt,1,<:PermutedSerializedArray{elt,2}}
65+
@test similar(a) isa arrayt{elt,1}
66+
@test copy(a) == reshape(permutedims(x, (2, 1)), 16)
67+
68+
rng = StableRNG(123)
69+
x = arrayt(randn(rng, elt, 4, 4))
70+
a = SerializedArray(x)
71+
@test a == a
72+
@test x == a
73+
@test a == x
74+
75+
rng = StableRNG(123)
76+
x = arrayt(randn(rng, elt, 4, 4))
77+
y = arrayt(randn(rng, elt, 4, 4))
78+
a = SerializedArray(x)
79+
b = SerializedArray(y)
80+
copyto!(b, a)
81+
@test b == a
82+
@test b == x
83+
84+
rng = StableRNG(123)
85+
x = arrayt(randn(rng, elt, 4, 4))
86+
y = arrayt(randn(rng, elt, 4, 4))
87+
a = SerializedArray(x)
88+
b = SerializedArray(y)
89+
copyto!(b, x)
90+
@test b == a
91+
92+
rng = StableRNG(123)
93+
x = arrayt(randn(rng, elt, 4, 4))
94+
y = arrayt(randn(rng, elt, 4, 4))
95+
a = SerializedArray(x)
96+
copyto!(y, a)
97+
b = SerializedArray(y)
98+
@test b == a
3999
end

0 commit comments

Comments
 (0)