Skip to content

Commit 5bffae7

Browse files
authored
Preserve types in reduce(vcat) and reduce(hcat) (#347)
* make reduce(vcat) as type-preserving as vcat()
* tighten eltype on vcat+
* minor cleanup
* cleaner error & error tests
1 parent a78aa11 commit 5bffae7

2 files changed

Lines changed: 68 additions & 2 deletions

File tree

src/structarray.jl

Lines changed: 20 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -452,18 +452,36 @@ function Base.sizehint!(s::StructArray, i::Integer)
452452
return s
453453
end
454454

455+
# Element type of a concatenation result when the promoted eltype `T` is a `Tup`:
# it is recomputed from the concatenated columns with `eltypes` instead of using `T`.
function _cateltype(::Type{T}, newcols::Tup) where {T<:Tup}
    return eltypes(newcols)
end
456+
# Generic method: return the promoted eltype `T` unchanged; `newcols` is accepted
# but not inspected.
function _cateltype(::Type{T}, newcols::Tup) where {T}
    return T
end
457+
458+
# Shared implementation of `reduce(op, A)` for a vector of StructArrays.
#
# Each column is concatenated on its own with `reduce(op, ...)`, and a new
# StructArray is built from the concatenated columns. The element type is chosen
# by `_cateltype` from the promoted eltype of the inputs and the new columns.
#
# Throws an `ArgumentError` if the inputs do not all have the same column keys.
function _reducecat_structarray(op, A::AbstractVector{<:StructArray})
    if isempty(A)
        # NOTE(review): presumably this surfaces Base's standard empty-reduction
        # error rather than returning a value -- confirm.
        return Base.mapreduce_empty(eltype, promote_type, eltype(A))
    end

    percols = map(components, A)
    template = first(percols)
    colkeys = keys(template)
    if !all(c -> keys(c) == colkeys, percols)
        throw(ArgumentError("StructArray columns must have matching keys."))
    end

    # For every key, gather that column from each input and concatenate them.
    joined = map(colkeys) do key
        pieces = map(c -> getindex(c, key), percols)
        return reduce(op, pieces)
    end

    # Rebuild the column container from the type of the first input's columns.
    # NOTE(review): `strip_params` is defined elsewhere; presumably it drops the
    # type parameters so the constructor can re-derive them -- confirm.
    wrapped = strip_params(typeof(template))(joined)
    promoted = mapreduce(eltype, promote_type, A)
    return StructArray{_cateltype(promoted, wrapped)}(wrapped)
end
468+
455469
# Generate `Base.cat`, `Base.hcat` and `Base.vcat` methods that accept one or more
# StructArrays. (Diff view: lines ending in `-` mark the two statements that were
# removed, lines ending in `+` mark the three statements that replace them.)
for op in [:cat, :hcat, :vcat]
456470
curried_op = Symbol(:curried, op)
457471
@eval begin
458472
function Base.$op(arg::StructArray, others::StructArray...; kwargs...)
459473
# Local helper that forwards this call's keyword arguments to `$op`.
$curried_op(A...) = $op(A...; kwargs...)
460474
args = (arg, others...)
461-
T = mapreduce(eltype, promote_type, args)
462-
StructArray{T}(map($curried_op, map(components, args)...))
475+
# Concatenate the inputs' components with the helper; the result is bound to a
# name so it can be used both for choosing the eltype and for construction.
newcols = map($curried_op, map(components, args)...)
476+
# The eltype is chosen by `_cateltype` from the promoted eltype and the new columns.
T = _cateltype(mapreduce(eltype, promote_type, args), newcols)
477+
StructArray{T}(newcols)
463478
end
464479
end
465480
end
466481

482+
# `reduce(vcat, A)` for a vector of StructArrays, handled by `_reducecat_structarray`.
function Base.reduce(::typeof(vcat), A::AbstractVector{<:StructArray})
    return _reducecat_structarray(vcat, A)
end
483+
# `reduce(hcat, A)` for a vector of StructArrays, handled by `_reducecat_structarray`.
function Base.reduce(::typeof(hcat), A::AbstractVector{<:StructArray})
    return _reducecat_structarray(hcat, A)
end
484+
467485
# Copy a StructArray by applying `copy` to each of its components and rebuilding
# a StructArray with the same element type `T`.
function Base.copy(s::StructArray{T}) where {T}
    copied = map(copy, components(s))
    return StructArray{T}(copied)
end
468486

469487
for type in (

test/runtests.jl

Lines changed: 48 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -693,6 +693,54 @@ end
693693
@test @inferred(vcat(t3)) == t3
694694
@inferred vcat(t3, t3)
695695
@inferred vcat(t3, collect(t3))
696+
a = StructArray(y = Union{Missing, Int}[missing])
697+
b = StructArray(y = [3])
698+
c = StructArray(y = Union{Missing, Int}[4])
699+
vcatted = vcat(a, b, c)
700+
@test eltype(vcatted) === NamedTuple{(:y,), Tuple{Union{Missing, Int}}}
701+
reduced_vcat = reduce(vcat, [a, b, c])
702+
@test eltype(reduced_vcat) === eltype(vcatted)
703+
@test isequal(reduced_vcat, vcatted)
704+
@test reduced_vcat.y isa Vector{Union{Missing, Int}}
705+
hcatted = hcat(reshape(a, 1, 1), reshape(b, 1, 1), reshape(c, 1, 1))
706+
@test eltype(hcatted) === NamedTuple{(:y,), Tuple{Union{Missing, Int}}}
707+
reduced_hcat = reduce(hcat, [reshape(a, 1, 1), reshape(b, 1, 1), reshape(c, 1, 1)])
708+
@test eltype(reduced_hcat) === eltype(hcatted)
709+
@test isequal(reduced_hcat, hcatted)
710+
@test reduced_hcat.y isa Matrix{Union{Missing, Int}}
711+
712+
# Two-field parametric struct used by the concatenation tests that follow as a
# custom StructArray element type whose second type parameter differs between inputs.
struct CatTestType{A, B}
713+
a::A
714+
b::B
715+
end
716+
custom_a = StructArray{CatTestType{Int, Missing}}((a = [1], b = Missing[missing]))
717+
custom_b = StructArray{CatTestType{Int, Int}}((a = [2], b = [3]))
718+
custom_vcat = vcat(custom_a, custom_b, custom_a)
719+
@test custom_vcat == CatTestType{Int}[CatTestType(1, missing), CatTestType(2, 3), CatTestType(1, missing)]
720+
@test custom_vcat.b isa Vector{Union{Missing, Int}}
721+
reduced_custom_vcat = reduce(vcat, [custom_a, custom_b, custom_a])
722+
@test isequal(reduced_custom_vcat, custom_vcat)
723+
@test eltype(reduced_custom_vcat) === eltype(custom_vcat) === CatTestType{Int}
724+
@test reduced_custom_vcat.b isa Vector{Union{Missing, Int}}
725+
726+
# error behavior is consistent between reduce(vcat) and vcat(), and is generally reasonable
727+
mismatched_names_a = StructArray(a = [1], b = [2])
728+
mismatched_names_b = StructArray(x = [3], y = [4])
729+
@test_throws ArgumentError vcat(mismatched_names_a, mismatched_names_b)
730+
@test_throws ArgumentError reduce(vcat, [mismatched_names_a, mismatched_names_b])
731+
mixed_rowtype_a = StructArray(re = [1.0], im = [2.0])
732+
mixed_rowtype_b = StructArray(ComplexF64[3 + 4im])
733+
@test_throws ArgumentError vcat(mixed_rowtype_a, mixed_rowtype_b)
734+
@test_throws ArgumentError reduce(vcat, [mixed_rowtype_a, mixed_rowtype_b])
735+
different_names_a = StructArray(a = [1])
736+
different_names_b = StructArray(x = [2], y = [3], z = [4])
737+
@test_throws ArgumentError vcat(different_names_a, different_names_b)
738+
@test_throws ArgumentError reduce(vcat, [different_names_a, different_names_b])
739+
different_lengths_a = StructArray(([1], [2], [3]))
740+
different_lengths_b = StructArray(([4], [5]))
741+
@test_throws ArgumentError reduce(vcat, [different_lengths_a, different_lengths_b])
742+
@test_throws ArgumentError reduce(hcat, [reshape(different_lengths_a, 1, 1), reshape(different_lengths_b, 1, 1)])
743+
696744
# Check that `cat(dims=1)` doesn't commit type piracy (#254)
697745
# We only test that this works, the return value is immaterial
698746
@test cat(dims=1) == vcat()

0 commit comments

Comments
 (0)