diff --git a/src/structarray.jl b/src/structarray.jl index eba837b..64fa371 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -516,25 +516,30 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T end # broadcast -import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown, ArrayConflict +import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown, ArrayConflict, Style using Base.Broadcast: combine_styles struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end - -# Here we define the dimension tracking behavior of StructArrayStyle -function StructArrayStyle{S, M}(::Val{N}) where {S, M, N} - T = S <: AbstractArrayStyle{M} ? typeof(S(Val{N}())) : S - return StructArrayStyle{T, N}() +StructArrayStyle(::S) where S<:AbstractArrayStyle{N} where N = StructArrayStyle{S, N}() +StructArrayStyle(::S) where {S<:StructArrayStyle} = S() +StructArrayStyle(::S, ::Val{N}) where {S,N} = StructArrayStyle(S(Val(N))) +StructArrayStyle(::Val{N}) where {N} = StructArrayStyle{DefaultArrayStyle{N}, N}() +function StructArrayStyle(a::BroadcastStyle, b::BroadcastStyle) + # This is a hack so if we have an ArrayConflict it gets wrapped in StructArrayStyle + inner_style = Broadcast.result_style(a, b) + if inner_style isa Unknown + return Unknown() + else + return StructArrayStyle(inner_style) + end end -# StructArrayStyle is a wrapped style. -# Here we try our best to resolve style conflict. -function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S, N}) where {S, N, M} - N′ = M === Any || N === Any ? Any : max(M, N) - S′ = Broadcast.result_style(S(), b) - return S′ isa StructArrayStyle ? typeof(S′)(Val{N′}()) : StructArrayStyle{typeof(S′), N′}() -end -BroadcastStyle(::StructArrayStyle, ::DefaultArrayStyle) = Unknown() +BroadcastStyle(::StructArrayStyle, ::Unknown) = Unknown() +BroadcastStyle(::StructArrayStyle{A}, ::StructArrayStyle{B}) where {A, B} = StructArrayStyle(A(), B()) +BroadcastStyle(::StructArrayStyle{S}, b::AbstractArrayStyle) where {S} = StructArrayStyle(S(), b) +BroadcastStyle(::StructArrayStyle{S}, b::DefaultArrayStyle) where {S} = StructArrayStyle(S(), b) +BroadcastStyle(::StructArrayStyle{S}, b::Style{Tuple}) where {S} = StructArrayStyle(S(), b) + @inline combine_style_types(::Type{A}, args...) where {A<:AbstractArray} = combine_style_types(BroadcastStyle(A), args...) diff --git a/test/runtests.jl b/test/runtests.jl index 905bcfa..4ded9c3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1358,7 +1358,7 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS ares = map(a->a.re, as) aims = map(a->a.im, as) style = Broadcast.combine_styles(ares...) - @test Broadcast.combine_styles(as...) === StructArrayStyle{typeof(style),1}() + @test Broadcast.combine_styles(as...) isa StructArrayStyle{typeof(style)} if !(style in tested_style) push!(tested_style, style) if style isa Broadcast.ArrayStyle{MyArray3} @@ -1374,9 +1374,9 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS # test for dimensionality track s = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2)))) @test Base.broadcasted(+, s, s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}} - @test Base.broadcasted(+, s, 1:2) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}} - @test Base.broadcasted(+, s, reshape(1:2,1,2)) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{2}} - @test Base.broadcasted(+, reshape(1:2,1,1,2), s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}} + @test Base.broadcasted(+, s, 1:2) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle} + @test Base.broadcasted(+, s, reshape(1:2,1,2)) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle} + @test Base.broadcasted(+, reshape(1:2,1,1,2), s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle} @test Base.broadcasted(+, s, MyArray1(rand(2))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}} #parent_style @@ -1473,6 +1473,15 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS d = identity.(c) @test d isa SparseMatrixCSC end + + # Regression test: StructArray + SparseArray broadcasting should not + # error with ambiguity (DimensionalData.jl#1195) + @testset "StructArray and SparseArray broadcast" begin + sa = StructArray{ComplexF64}((rand(10), rand(10))) + sp = sprand(10, 0.5) + @test (sa .+ sp) == (collect(sa) .+ collect(sp)) + @test (sp .+ sa) == (collect(sp) .+ collect(sa)) + end end @testset "map" begin