-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathbroadcast.jl
More file actions
86 lines (75 loc) · 3.1 KB
/
broadcast.jl
File metadata and controls
86 lines (75 loc) · 3.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
using Base.Broadcast:
Broadcast, BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle, Broadcasted
using GPUArraysCore: @allowscalar
using MapBroadcast: Mapped
using DerivableInterfaces: DerivableInterfaces, @interface
"""
    AbstractBlockSparseArrayStyle{N,B<:AbstractArrayStyle{N}}

Supertype of the broadcast styles for `N`-dimensional block sparse arrays. The parameter
`B` is the broadcast style used for the individual blocks.
"""
abstract type AbstractBlockSparseArrayStyle{N,B<:AbstractArrayStyle{N}} <:
              AbstractArrayStyle{N} end
# Recover the broadcast style of the blocks from the style's type parameter `B`.
# This instantiates `B` with no arguments, so it assumes `B` has a zero-argument
# constructor (as `DefaultArrayStyle{N}` does).
function blockstyle(
  ::AbstractBlockSparseArrayStyle{N,B}
) where {N,B<:AbstractArrayStyle{N}}
  return B()
end
# Two block sparse styles combine by merging their block styles with Base's
# `result_style`, then wrapping the merged block style in a `BlockSparseArrayStyle`.
function Broadcast.BroadcastStyle(
  style1::AbstractBlockSparseArrayStyle, style2::AbstractBlockSparseArrayStyle
)
  block_styles = map(blockstyle, (style1, style2))
  merged = Broadcast.result_style(block_styles...)
  return BlockSparseArrayStyle(merged)
end
# Map a block sparse broadcast style type to the block sparse interface. The result wraps
# the interface associated with the broadcast style of the blocks, `B`.
function DerivableInterfaces.interface(
  ::Type{<:AbstractBlockSparseArrayStyle{N,B}}
) where {N,B<:AbstractArrayStyle{N}}
  # Qualified: this file's `using DerivableInterfaces` line imports only the module and
  # `@interface`, not the bare `interface` name.
  return BlockSparseArrayInterface(DerivableInterfaces.interface(B))
end
"""
    BlockSparseArrayStyle{N,B<:AbstractArrayStyle{N}}

Concrete broadcast style for `N`-dimensional block sparse arrays. It stores an instance of
the broadcast style of the blocks, of type `B`.
"""
struct BlockSparseArrayStyle{N,B<:AbstractArrayStyle{N}} <:
       AbstractBlockSparseArrayStyle{N,B}
  # Broadcast style of the individual blocks.
  blockstyle::B
end
# Build from a block style instance, taking `B` from its type.
function BlockSparseArrayStyle{N}(blockstyle::AbstractArrayStyle{N}) where {N}
  B = typeof(blockstyle)
  return BlockSparseArrayStyle{N,B}(blockstyle)
end
# Build from the type parameters alone, instantiating the block style as `B()`.
function BlockSparseArrayStyle{N,B}() where {N,B<:AbstractArrayStyle{N}}
  block_style = B()
  return BlockSparseArrayStyle{N,B}(block_style)
end
# Without an explicit block style, the blocks use `DefaultArrayStyle{N}`.
function BlockSparseArrayStyle{N}() where {N}
  default_block_style = DefaultArrayStyle{N}()
  return BlockSparseArrayStyle{N}(default_block_style)
end
# `Val(N)` constructors, part of Base's `AbstractArrayStyle` interface for changing the
# dimensionality of a style.
function BlockSparseArrayStyle(::Val{N}) where {N}
  return BlockSparseArrayStyle{N}()
end
# Only the old dimension `M` is known here, so the block style falls back to the default.
function BlockSparseArrayStyle{M}(::Val{N}) where {M,N}
  return BlockSparseArrayStyle{N}()
end
# With a known block style type `B`, rebuild that block style at dimension `N`.
function BlockSparseArrayStyle{M,B}(::Val{N}) where {M,B<:AbstractArrayStyle{M},N}
  resized_block_style = B(Val(N))
  return BlockSparseArrayStyle{N}(resized_block_style)
end
# A zero-dimensional dense style (e.g. from a scalar argument) leaves the block sparse
# style unchanged.
Broadcast.BroadcastStyle(style::BlockSparseArrayStyle, ::DefaultArrayStyle{0}) = style

# Any other dense style takes over: the block sparse side is treated as
# `DefaultArrayStyle{N}` and combined with the dense style using Base's rules.
function Broadcast.BroadcastStyle(
  ::BlockSparseArrayStyle{N}, dense::DefaultArrayStyle
) where {N}
  as_dense = DefaultArrayStyle{N}()
  return BroadcastStyle(as_dense, dense)
end

# Broadcasting together with tuples falls back to the dense style of dimension `N`.
function Broadcast.BroadcastStyle(
  ::BlockSparseArrayStyle{N}, ::Broadcast.Style{Tuple}
) where {N}
  return DefaultArrayStyle{N}()
end
# Allocate the output of a block sparse broadcast with element type `elt` and axes `ax`.
function Base.similar(bc::Broadcasted{<:BlockSparseArrayStyle}, elt::Type, ax)
  # TODO: Make this more generic, and make sure it handles GPU arrays properly.
  # `Mapped` rewrites the broadcast as a map, keeping only its `AbstractArray`
  # arguments; the output is allocated like the first of them.
  # NOTE(review): that first argument is not necessarily block sparse (a
  # zero-dimensional array combined with a block sparse array keeps this style);
  # confirm `similar` on it gives the intended output type.
  mapped = Mapped(bc)
  template = mapped.args[1]
  return similar(template, elt, ax)
end
# Handles zero-dimensional right-hand sides such as `dest .= value` or
# `dest .= value1 .+ value2`. The scalar result is written with `fill!`, which
# contains the logic that empties the storage when the value is zero.
function copyto_blocksparse!(dest::AbstractArray, bc::Broadcasted{<:AbstractArrayStyle{0}})
  flat_bc = Broadcast.flatten(bc)
  # Calling `getindex` with no indices (i.e. `arg[]`) unwraps each zero-dimensional
  # argument; `@allowscalar` permits this scalar access.
  scalar = @allowscalar flat_bc.f(map(getindex, flat_bc.args)...)
  return @interface BlockSparseArrayInterface() fill!(dest, scalar)
end
# Broadcasting implementation.
# TODO: Delete this in favor of the `DerivableInterfaces` version.
#
# Evaluates the broadcast `bc` into `dest` as a `map!` over the broadcast's array
# arguments, using the interface chosen for `dest` and `bc`. Returns `dest`.
function copyto_blocksparse!(dest::AbstractArray, bc::Broadcasted)
  # Convert the broadcast to a map: `Mapped` flattens `bc` and keeps only its
  # `AbstractArray` arguments.
  m = Mapped(bc)
  # `interface` is qualified because this file's `using DerivableInterfaces` line does
  # not import the bare name.
  @interface DerivableInterfaces.interface(dest, bc) map!(m.f, dest, m.args...)
  return dest
end
# Copy a block sparse broadcast into `dest` (same dimensionality `N`) through
# `copyto_blocksparse!`, returning `dest`.
#
# The style parameter is matched with `<:`. `BlockSparseArrayStyle{N}` alone is a
# `UnionAll` over the block style `B`. Because type parameters are invariant,
# `Broadcasted{BlockSparseArrayStyle{N}}` never matches a `Broadcasted` whose style is a
# concrete `BlockSparseArrayStyle{N,B}`, which left this method unreachable.
# NOTE(review): now that the method is reachable, check for method ambiguities with any
# `copyto!` methods defined for block sparse destinations elsewhere in the package
# (e.g. with `Test.detect_ambiguities`).
function Base.copyto!(
  dest::AbstractArray{<:Any,N}, bc::Broadcasted{<:BlockSparseArrayStyle{N}}
) where {N}
  copyto_blocksparse!(dest, bc)
  return dest
end