Skip to content

Commit cc28b24

Browse files
committed
subarrays
1 parent 3631941 commit cc28b24

2 files changed

Lines changed: 67 additions & 6 deletions

File tree

src/SerializedArrays.jl

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,7 @@ Base.size(a::SerializedArray) = length.(axes(a))
100100
to_axis(r::AbstractUnitRange) = r
101101
to_axis(d::Integer) = Base.OneTo(d)
102102

103-
#
104-
# DiskArrays
105-
#
106-
103+
# DiskArrays interface
107104
DiskArrays.haschunks(::SerializedArray) = Unchunked()
108105
function DiskArrays.readblock!(
109106
a::SerializedArray{<:Any,N}, aout, i::Vararg{AbstractUnitRange,N}
@@ -137,6 +134,8 @@ struct PermutedSerializedArray{T,N,P<:PermutedDimsArray{T,N}} <:
137134
end
138135
Base.parent(a::PermutedSerializedArray) = parent(getfield(a, :permuted_parent))
139136

137+
file(a::PermutedSerializedArray) = file(parent(a))
138+
140139
perm(a::PermutedSerializedArray) = perm(a.permuted_parent)
141140
perm(::PermutedDimsArray{<:Any,<:Any,p}) where {p} = p
142141

@@ -190,6 +189,8 @@ end
190189
Base.parent(a::ReshapedSerializedArray) = getfield(a, :parent)
191190
Base.axes(a::ReshapedSerializedArray) = getfield(a, :axes)
192191

192+
file(a::ReshapedSerializedArray) = file(parent(a))
193+
193194
function ReshapedSerializedArray(
194195
a::AbstractSerializedArray,
195196
ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
@@ -256,6 +257,50 @@ function DiskArrays.writeblock!(
256257
return nothing
257258
end
258259

260+
struct SubSerializedArray{T,N,P,I,L} <: AbstractSerializedArray{T,N}
261+
sub_parent::SubArray{T,N,P,I,L}
262+
end
263+
264+
file(a::SubSerializedArray) = file(parent(a))
265+
266+
# Base methods
267+
function Base.view(a::SerializedArray, i...)
268+
return SubSerializedArray(SubArray(a, Base.to_indices(a, i)))
269+
end
270+
Base.view(a::SubSerializedArray, i...) = SubSerializedArray(view(a.sub_parent, i...))
271+
Base.view(a::SubSerializedArray, i::CartesianIndices) = view(a, i.indices...)
272+
Base.size(a::SubSerializedArray) = size(a.sub_parent)
273+
Base.axes(a::SubSerializedArray) = axes(a.sub_parent)
274+
Base.parent(a::SubSerializedArray) = parent(a.sub_parent)
275+
Base.parentindices(a::SubSerializedArray) = parentindices(a.sub_parent)
276+
277+
function materialize(a::SubSerializedArray)
278+
return view(copy(parent(a)), parentindices(a)...)
279+
end
280+
function Base.copy(a::SubSerializedArray)
281+
return copy(materialize(a))
282+
end
283+
284+
DiskArrays.haschunks(a::SubSerializedArray) = Unchunked()
285+
function DiskArrays.readblock!(a::SubSerializedArray, aout, i::OrdinalRange...)
286+
if i == axes(a)
287+
aout .= copy(a)
288+
end
289+
aout[i...] = copy(view(a, i...))
290+
return nothing
291+
end
292+
function DiskArrays.writeblock!(a::SubSerializedArray, ain, i::OrdinalRange...)
293+
if i == axes(a)
294+
serialize(file(a), ain)
295+
return a
296+
end
297+
a_parent = copy(parent(a))
298+
pinds = parentindices(view(a.sub_parent, i...))
299+
a_parent[pinds...] = ain
300+
serialize(file(a), a_parent)
301+
return nothing
302+
end
303+
259304
#
260305
# Broadcast
261306
#
@@ -264,7 +309,9 @@ using Base.Broadcast:
264309
BroadcastStyle, Broadcasted, DefaultArrayStyle, combine_styles, flatten
265310

266311
struct SerializedArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end
267-
Base.BroadcastStyle(arrayt::Type{<:SerializedArray}) = SerializedArrayStyle{ndims(arrayt)}()
312+
function Base.BroadcastStyle(arrayt::Type{<:AbstractSerializedArray})
313+
SerializedArrayStyle{ndims(arrayt)}()
314+
end
268315
function Base.BroadcastStyle(
269316
::SerializedArrayStyle{N}, ::SerializedArrayStyle{M}
270317
) where {N,M}

test/test_basics.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using GPUArraysCore: @allowscalar
22
using JLArrays: JLArray
3-
using SerializedArrays: PermutedSerializedArray, ReshapedSerializedArray, SerializedArray
3+
using SerializedArrays:
4+
PermutedSerializedArray, ReshapedSerializedArray, SerializedArray, SubSerializedArray
45
using StableRNGs: StableRNG
56
using Test: @test, @testset
67
using TestExtras: @constinferred
@@ -96,4 +97,17 @@ arrayts = (Array, JLArray)
9697
copyto!(y, a)
9798
b = SerializedArray(y)
9899
@test b == a
100+
101+
rng = StableRNG(123)
102+
x = arrayt(randn(rng, elt, 4, 4))
103+
y = @view x[2:3, 2:3]
104+
a = SerializedArray(a)
105+
b = @view a[2:3, 2:3]
106+
@test b isa SubSerializedArray{elt,2}
107+
c = 2b
108+
@test 2y == copy(c)
109+
@allowscalar begin
110+
b[1, 1] = 2
111+
@test @constinferred(b[1, 1]) == 2
112+
end
99113
end

0 commit comments

Comments
 (0)