@@ -2,57 +2,55 @@ using Base.Broadcast: BroadcastStyle, Broadcasted
22using GPUArraysCore: @allowscalar
33using MapBroadcast: Mapped
44
5- module Broadcast
6- using Base. Broadcast: AbstractArrayStyle
7- abstract type AbstractBlockSparseArrayStyle{N, B <: AbstractArrayStyle{N} } < :
8- AbstractArrayStyle{N} end
9- struct BlockSparseArrayStyle{N, B <: AbstractArrayStyle{N} } < :
10- AbstractBlockSparseArrayStyle{N, B}
11- blockstyle:: B
12- end
13- function BlockSparseArrayStyle {N} (blockstyle:: AbstractArrayStyle{N} ) where {N}
14- return BlockSparseArrayStyle {N, typeof(blockstyle)} (blockstyle)
15- end
16- function BlockSparseArrayStyle {N, B} () where {N, B <: AbstractArrayStyle{N} }
17- return BlockSparseArrayStyle {N, B} (B ())
18- end
19- function BlockSparseArrayStyle {N} () where {N}
20- return BlockSparseArrayStyle {N} (Base. Broadcast. DefaultArrayStyle {N} ())
21- end
22- BlockSparseArrayStyle (:: Val{N} ) where {N} = BlockSparseArrayStyle {N} ()
23- BlockSparseArrayStyle {M} (:: Val{N} ) where {M, N} = BlockSparseArrayStyle {N} ()
24- function BlockSparseArrayStyle {M, B} (:: Val{N} ) where {M, B <: AbstractArrayStyle{M} , N}
25- return BlockSparseArrayStyle {N} (B (Val (N)))
26- end
5+ using Base. Broadcast: AbstractArrayStyle
6+ abstract type AbstractBlockSparseArrayStyle{N, B <: AbstractArrayStyle{N} } < :
7+ AbstractArrayStyle{N} end
8+ struct BlockSparseArrayStyle{N, B <: AbstractArrayStyle{N} } < :
9+ AbstractBlockSparseArrayStyle{N, B}
10+ blockstyle:: B
11+ end
12+ function BlockSparseArrayStyle {N} (blockstyle:: AbstractArrayStyle{N} ) where {N}
13+ return BlockSparseArrayStyle {N, typeof(blockstyle)} (blockstyle)
14+ end
15+ function BlockSparseArrayStyle {N, B} () where {N, B <: AbstractArrayStyle{N} }
16+ return BlockSparseArrayStyle {N, B} (B ())
17+ end
18+ function BlockSparseArrayStyle {N} () where {N}
19+ return BlockSparseArrayStyle {N} (Base. Broadcast. DefaultArrayStyle {N} ())
20+ end
21+ BlockSparseArrayStyle (:: Val{N} ) where {N} = BlockSparseArrayStyle {N} ()
22+ BlockSparseArrayStyle {M} (:: Val{N} ) where {M, N} = BlockSparseArrayStyle {N} ()
23+ function BlockSparseArrayStyle {M, B} (:: Val{N} ) where {M, B <: AbstractArrayStyle{M} , N}
24+ return BlockSparseArrayStyle {N} (B (Val (N)))
2725end
2826
2927function blockstyle (
30- :: Broadcast. AbstractBlockSparseArrayStyle{N, B} ,
28+ :: AbstractBlockSparseArrayStyle{N, B} ,
3129 ) where {N, B <: Base.Broadcast.AbstractArrayStyle{N} }
3230 return B ()
3331end
3432
3533function Base. Broadcast. BroadcastStyle (
36- style1:: Broadcast. AbstractBlockSparseArrayStyle ,
37- style2:: Broadcast. AbstractBlockSparseArrayStyle ,
34+ style1:: AbstractBlockSparseArrayStyle ,
35+ style2:: AbstractBlockSparseArrayStyle ,
3836 )
3937 style = Base. Broadcast. result_style (blockstyle (style1), blockstyle (style2))
40- return Broadcast . BlockSparseArrayStyle (style)
38+ return BlockSparseArrayStyle (style)
4139end
4240
43- Base. Broadcast. BroadcastStyle (a:: Broadcast. BlockSparseArrayStyle , :: Base.Broadcast.DefaultArrayStyle{0} ) = a
41+ Base. Broadcast. BroadcastStyle (a:: BlockSparseArrayStyle , :: Base.Broadcast.DefaultArrayStyle{0} ) = a
4442function Base. Broadcast. BroadcastStyle (
45- :: Broadcast. BlockSparseArrayStyle{N} , a:: Base.Broadcast.DefaultArrayStyle
43+ :: BlockSparseArrayStyle{N} , a:: Base.Broadcast.DefaultArrayStyle
4644 ) where {N}
4745 return Base. Broadcast. BroadcastStyle (Base. Broadcast. DefaultArrayStyle {N} (), a)
4846end
4947function Base. Broadcast. BroadcastStyle (
50- :: Broadcast. BlockSparseArrayStyle{N} , :: Base.Broadcast.Style{Tuple}
48+ :: BlockSparseArrayStyle{N} , :: Base.Broadcast.Style{Tuple}
5149 ) where {N}
5250 return Base. Broadcast. DefaultArrayStyle {N} ()
5351end
5452
55- function Base. similar (bc:: Broadcasted{<:Broadcast. BlockSparseArrayStyle} , elt:: Type , ax)
53+ function Base. similar (bc:: Broadcasted{<:BlockSparseArrayStyle} , elt:: Type , ax)
5654 # Find the first array in the broadcast expression.
5755 # TODO : Make this more generic, base it off sure this handles GPU arrays properly.
5856 bc′ = Base. Broadcast. flatten (bc)
8482
8583# Broadcasting implementation
8684function Base. copyto! (
87- dest:: AbstractArray{<:Any, N} , bc:: Broadcasted{Broadcast. BlockSparseArrayStyle{N}}
85+ dest:: AbstractArray{<:Any, N} , bc:: Broadcasted{BlockSparseArrayStyle{N}}
8886 ) where {N}
8987 return copyto!_blocksparse (dest, bc)
9088end
0 commit comments