Skip to content

Commit 2bcebfb

Browse files
Merge pull request SciML#543 from JoshuaLampert/fix-similar
Fix `similar` and `fill!` for mixed nested `VectorOfArray`
2 parents 9750141 + dee3a69 commit 2bcebfb

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

src/vector_of_array.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1404,7 +1404,11 @@ function Base.similar(
14041404
end
14051405

14061406
@inline function Base.similar(VA::VectorOfArray, ::Type{T} = eltype(VA)) where {T}
1407-
return VectorOfArray(similar.(VA.u, T))
1407+
if eltype(VA.u) <: Union{AbstractArray, AbstractVectorOfArray}
1408+
return VectorOfArray(similar.(VA.u, T))
1409+
else
1410+
return VectorOfArray(similar(VA.u, T))
1411+
end
14081412
end
14091413

14101414
@inline function Base.similar(VA::VectorOfArray, dims::N) where {N <: Number}
@@ -1420,7 +1424,7 @@ end
14201424
# For DiffEqArray it ignores ts and fills only u
14211425
function Base.fill!(VA::AbstractVectorOfArray, x)
14221426
for i in 1:length(VA.u)
1423-
if VA[:, i] isa AbstractArray
1427+
if VA[:, i] isa Union{AbstractArray, AbstractVectorOfArray}
14241428
if ArrayInterface.ismutable(VA.u[i]) || VA.u[i] isa AbstractVectorOfArray
14251429
fill!(VA[:, i], x)
14261430
else

test/utils_test.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,23 @@ end
141141
@test u1.u[2] isa SVector
142142
end
143143

144+
@testset "VectorOfArray similar with nested scalar leaves" begin
145+
a = VectorOfArray([ones(2), VectorOfArray([1.0, 1.0])])
146+
b = similar(a, Float64)
147+
@test b isa typeof(a)
148+
@test b.u[1] isa Vector{Float64}
149+
@test b.u[2] isa typeof(a.u[2])
150+
@test b.u[2].u isa Vector{Float64}
151+
@test length(b.u[2].u) == 2
152+
end
153+
154+
@testset "recursivefill! with nested union partitions" begin
155+
a = VectorOfArray([ones(2), VectorOfArray([1.0, 1.0])])
156+
recursivefill!(a, true)
157+
@test a.u[1] == ones(2)
158+
@test a.u[2].u == ones(2)
159+
end
160+
144161
# Test recursivefill! with immutable StaticArrays (issue #461)
145162
@testset "recursivefill! with immutable StaticArrays (issue #461)" begin
146163
# Test with only immutable SVectors

0 commit comments

Comments
 (0)