forked from JuliaArrays/StructArrays.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathStructArraysStaticArraysExt.jl
More file actions
119 lines (104 loc) · 4.64 KB
/
StructArraysStaticArraysExt.jl
File metadata and controls
119 lines (104 loc) · 4.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
module StructArraysStaticArraysExt
using StructArrays
using StaticArrays: StaticArray, FieldArray, tuple_prod, SVector, MVector, SOneTo
"""
    StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}

Return the `staticschema` of a `StaticArray` element type, which is the schema of the
underlying `Tuple`: an `NTuple` of `T` whose length is `tuple_prod(S)`.

```julia
julia> StructArrays.staticschema(SVector{2, Float64})
Tuple{Float64, Float64}
```

`StaticArrays.FieldArray` is the one exception to this rule. A `FieldArray` is based on a
struct, so `staticschema(<:FieldArray)` returns the `staticschema` of the struct that
subtypes `FieldArray`.
"""
@generated function StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
    # Element count, evaluated once when the method body is generated.
    n = tuple_prod(S)
    return quote
        Base.@_inline_meta
        return NTuple{$n, T}
    end
end
# Rebuild a `StaticArray` of type `T` from its components, passed to `T` as a single tuple.
function StructArrays.createinstance(::Type{T}, args...) where {T<:StaticArray}
    return T(args)
end
# The components of a `StaticArray` are its elements, addressed by integer position.
StructArrays.component(s::StaticArray, i::Integer) = s[i]
# Allow `x`/`y`/`z`/`w` names for the first four components of a StructArray whose
# elements are `SVector`s or `MVector`s. Any other key throws an `ArgumentError`.
function StructArrays.component(s::StructArray{<:Union{SVector,MVector}}, key::Symbol)
    pos = findfirst(==(key), (:x, :y, :z, :w))
    pos === nothing && throw(ArgumentError("invalid key $key"))
    return StructArrays.component(s, pos)
end
# A `FieldArray` is backed by a struct, so it uses StructArrays' generic struct-based
# schema instead of the `NTuple` schema used for other `StaticArray`s.
@inline StructArrays.staticschema(T::Type{<:FieldArray}) = StructArrays.staticschema_generic(T)
# A `FieldArray` is backed by a struct, so its components are read with `getfield`
# (by field name or by position) rather than with `getindex`.
StructArrays.component(s::FieldArray, i) = getfield(s, i)
# Disambiguation: `component(::StaticArray, ::Integer)` also matches a `FieldArray` with an
# integer index, and neither method is more specific than the other. Without this method,
# `component(::FieldArray, ::Integer)` throws an ambiguity `MethodError`.
StructArrays.component(s::FieldArray, i::Integer) = getfield(s, i)
# Construct a `FieldArray` through StructArrays' generic (struct-based) instance builder.
function StructArrays.createinstance(T::Type{<:FieldArray}, args...)
    return StructArrays.createinstance_generic(T, args...)
end
# disambiguation: an explicit `similar` method for size tuples that mix integers,
# `Base.OneTo` and static `SOneTo` axes, forwarding to StructArrays' internal `_similar`.
# NOTE(review): presumably resolves an ambiguity with StaticArrays' `SOneTo`-based
# `similar` methods; confirm.
function Base.similar(
    s::StructArray,
    S::Type,
    sz::Tuple{Union{Integer, Base.OneTo, SOneTo}, Vararg{Union{Integer, Base.OneTo, SOneTo}}},
)
    return StructArrays._similar(s, S, sz)
end
# Reshape every component array to the static axes `d`, then rewrap the results with the
# original element type `T`.
# NOTE(review): likely also a disambiguation against StaticArrays' `SOneTo` reshape; confirm.
function Base.reshape(s::StructArray{T}, d::Tuple{SOneTo, Vararg{SOneTo}}) where {T}
    reshaped = map(col -> reshape(col, d), StructArrays.components(s))
    return StructArray{T}(reshaped)
end
# Broadcast overload
using StaticArrays: StaticArrayStyle, similar_type, Size, SOneTo
using StaticArrays: broadcast_flatten, broadcast_sizes, first_statictype
using StructArrays: isnonemptystructtype
using Base.Broadcast: Broadcasted, _broadcast_getindex
# `StaticArrayStyle` defines no `similar`, so static broadcasts are intercepted by
# overloading `try_struct_copy` instead. The nested broadcast is flattened into a single
# function plus arguments, and the result is computed by `_broadcast`.
@inline function StructArrays.try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M}
    flat = broadcast_flatten(bc)
    f, args = flat.f, flat.args
    argsizes = broadcast_sizes(args...)
    ax = axes(bc)
    if !(ax isa Tuple{Vararg{SOneTo}})
        error("Dimension is not static. Please file a bug at `StaticArrays.jl`.")
    end
    return _broadcast(f, Size(map(length, ax)), argsizes, args...)
end
# Callable helper used by `_broadcast`. It stores the per-entry broadcast results in
# `elements`, and when called with a field index `i` it assembles a static array (typed
# from `SA` via `similar_type`) holding the `i`-th field of every stored value.
struct Similar_ith{SA, E<:Tuple}
    elements::E
    function Similar_ith{SA}(elements::Tuple) where {SA}
        return new{SA, typeof(elements)}(elements)
    end
end
# Build the `i`-th component array: gather field `i` of every stored element into a tuple,
# then wrap it in `SA` re-typed to that field's type.
function (s::Similar_ith{SA})(i::Int) where {SA}
    elems = s.elements
    fields_i = ntuple(j -> getfield(elems[j], i), Val(length(elems)))
    component_type = similar_type(SA, fieldtype(eltype(SA), i))
    return @inbounds component_type(fields_i)
end
# Evaluate a flattened static broadcast of `f` over `a` with output size `newsize`.
# If the resulting element type is a struct type with at least one field, the result is a
# `StructArray` whose components are static arrays of each field. Otherwise it is a
# plain static array built with `similar_type` on the first static-array argument.
@inline function _broadcast(f, sz::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where {newsize}
    first_staticarray = first_statictype(a...)
    if prod(newsize) == 0
        # Nothing to evaluate, so infer the eltype instead
        # (following StaticBroadcast defined in StaticArrays.jl).
        elements = ()
        ET = Base.promote_op(f, map(eltype, a)...)
    else
        elements = __broadcast(f, sz, s, a...)
        ET = eltype(elements)
    end
    SA = similar_type(first_staticarray, ET, sz)
    if !isnonemptystructtype(ET)
        @inbounds return SA(elements)
    end
    arrs = ntuple(Similar_ith{SA}(elements), Val(fieldcount(ET)))
    return StructArray{ET}(arrs)
end
# The `__broadcast` kernel is copied from `StaticArrays.jl`.
# see https://github.com/JuliaArrays/StaticArrays.jl/blob/master/src/broadcast.jl
#
# At generation time, visit each output CartesianIndex in column-major order and emit one
# call `f(...)`. Its arguments are the element-access expressions built by
# `broadcast_getindex` for each input. The generated body returns all results as a tuple.
@generated function __broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
    argsizes = [sz.parameters[1] for sz ∈ s.parameters]
    calls = Expr[]
    for ci in CartesianIndices(newsize)
        argexprs = [broadcast_getindex(sz, k, ci) for (k, sz) in enumerate(argsizes)]
        push!(calls, :(f($(argexprs...))))
    end
    return quote
        Base.@_inline_meta
        return tuple($(calls...))
    end
end
# Argument `i` has an empty static size (presumably a scalar-like, non-static argument;
# see StaticArrays' `broadcast_sizes`). Emit code that defers indexing to Base's
# `_broadcast_getindex` at run time.
function broadcast_getindex(::Tuple{}, i::Int, I::CartesianIndex)
    return :(_broadcast_getindex(a[$i], $I))
end
# Argument `i` has static size `oldsize`. At generation time, map the output position
# `newindex` onto the argument's linear index using Base's broadcasting index rules, then
# emit a constant-index access `a[i][k]`.
function broadcast_getindex(oldsize::Tuple, i::Int, newindex::CartesianIndex)
    k = _broadcast_getindex(LinearIndices(oldsize), newindex)
    return :(a[$i][$k])
end
end