Skip to content

Commit cbb83da

Browse files
committed
Delete Broadcast module
1 parent b03ea5d commit cbb83da

3 files changed

Lines changed: 35 additions & 37 deletions

File tree

src/abstractblocksparsearray/broadcast.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using BlockArrays: AbstractBlockedUnitRange, BlockSlice
22
using Base.Broadcast: BroadcastStyle
33

44
function Base.Broadcast.BroadcastStyle(arraytype::Type{<:AnyAbstractBlockSparseArray})
5-
return Broadcast.BlockSparseArrayStyle(BroadcastStyle(blocktype(arraytype)))
5+
return BlockSparseArrayStyle(BroadcastStyle(blocktype(arraytype)))
66
end
77

88
# Fix ambiguity error with `BlockArrays`.
@@ -16,7 +16,7 @@ function Base.Broadcast.BroadcastStyle(
1616
},
1717
},
1818
)
19-
return Broadcast.BlockSparseArrayStyle{ndims(arraytype)}()
19+
return BlockSparseArrayStyle{ndims(arraytype)}()
2020
end
2121
function Base.Broadcast.BroadcastStyle(
2222
arraytype::Type{
@@ -32,7 +32,7 @@ function Base.Broadcast.BroadcastStyle(
3232
},
3333
},
3434
)
35-
return Broadcast.BlockSparseArrayStyle{ndims(arraytype)}()
35+
return BlockSparseArrayStyle{ndims(arraytype)}()
3636
end
3737
function Base.Broadcast.BroadcastStyle(
3838
arraytype::Type{
@@ -44,7 +44,7 @@ function Base.Broadcast.BroadcastStyle(
4444
},
4545
},
4646
)
47-
return Broadcast.BlockSparseArrayStyle{ndims(arraytype)}()
47+
return BlockSparseArrayStyle{ndims(arraytype)}()
4848
end
4949

5050
# These catch cases that aren't caught by the standard
@@ -59,7 +59,7 @@ function Base.copyto!(
5959
return copyto!_blocksparse(dest, bc)
6060
end
6161
function Base.copyto!(
62-
dest::AnyAbstractBlockSparseArray{<:Any, N}, bc::Broadcasted{<:Broadcast.BlockSparseArrayStyle{N}}
62+
dest::AnyAbstractBlockSparseArray{<:Any, N}, bc::Broadcasted{<:BlockSparseArrayStyle{N}}
6363
) where {N}
6464
return copyto!_blocksparse(dest, bc)
6565
end

src/blocksparsearrayinterface/broadcast.jl

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,57 +2,55 @@ using Base.Broadcast: BroadcastStyle, Broadcasted
22
using GPUArraysCore: @allowscalar
33
using MapBroadcast: Mapped
44

5-
module Broadcast
6-
using Base.Broadcast: AbstractArrayStyle
7-
abstract type AbstractBlockSparseArrayStyle{N, B <: AbstractArrayStyle{N}} <:
8-
AbstractArrayStyle{N} end
9-
struct BlockSparseArrayStyle{N, B <: AbstractArrayStyle{N}} <:
10-
AbstractBlockSparseArrayStyle{N, B}
11-
blockstyle::B
12-
end
13-
function BlockSparseArrayStyle{N}(blockstyle::AbstractArrayStyle{N}) where {N}
14-
return BlockSparseArrayStyle{N, typeof(blockstyle)}(blockstyle)
15-
end
16-
function BlockSparseArrayStyle{N, B}() where {N, B <: AbstractArrayStyle{N}}
17-
return BlockSparseArrayStyle{N, B}(B())
18-
end
19-
function BlockSparseArrayStyle{N}() where {N}
20-
return BlockSparseArrayStyle{N}(Base.Broadcast.DefaultArrayStyle{N}())
21-
end
22-
BlockSparseArrayStyle(::Val{N}) where {N} = BlockSparseArrayStyle{N}()
23-
BlockSparseArrayStyle{M}(::Val{N}) where {M, N} = BlockSparseArrayStyle{N}()
24-
function BlockSparseArrayStyle{M, B}(::Val{N}) where {M, B <: AbstractArrayStyle{M}, N}
25-
return BlockSparseArrayStyle{N}(B(Val(N)))
26-
end
5+
using Base.Broadcast: AbstractArrayStyle
6+
abstract type AbstractBlockSparseArrayStyle{N, B <: AbstractArrayStyle{N}} <:
7+
AbstractArrayStyle{N} end
8+
struct BlockSparseArrayStyle{N, B <: AbstractArrayStyle{N}} <:
9+
AbstractBlockSparseArrayStyle{N, B}
10+
blockstyle::B
11+
end
12+
function BlockSparseArrayStyle{N}(blockstyle::AbstractArrayStyle{N}) where {N}
13+
return BlockSparseArrayStyle{N, typeof(blockstyle)}(blockstyle)
14+
end
15+
function BlockSparseArrayStyle{N, B}() where {N, B <: AbstractArrayStyle{N}}
16+
return BlockSparseArrayStyle{N, B}(B())
17+
end
18+
function BlockSparseArrayStyle{N}() where {N}
19+
return BlockSparseArrayStyle{N}(Base.Broadcast.DefaultArrayStyle{N}())
20+
end
21+
BlockSparseArrayStyle(::Val{N}) where {N} = BlockSparseArrayStyle{N}()
22+
BlockSparseArrayStyle{M}(::Val{N}) where {M, N} = BlockSparseArrayStyle{N}()
23+
function BlockSparseArrayStyle{M, B}(::Val{N}) where {M, B <: AbstractArrayStyle{M}, N}
24+
return BlockSparseArrayStyle{N}(B(Val(N)))
2725
end
2826

2927
function blockstyle(
30-
::Broadcast.AbstractBlockSparseArrayStyle{N, B},
28+
::AbstractBlockSparseArrayStyle{N, B},
3129
) where {N, B <: Base.Broadcast.AbstractArrayStyle{N}}
3230
return B()
3331
end
3432

3533
function Base.Broadcast.BroadcastStyle(
36-
style1::Broadcast.AbstractBlockSparseArrayStyle,
37-
style2::Broadcast.AbstractBlockSparseArrayStyle,
34+
style1::AbstractBlockSparseArrayStyle,
35+
style2::AbstractBlockSparseArrayStyle,
3836
)
3937
style = Base.Broadcast.result_style(blockstyle(style1), blockstyle(style2))
40-
return Broadcast.BlockSparseArrayStyle(style)
38+
return BlockSparseArrayStyle(style)
4139
end
4240

43-
Base.Broadcast.BroadcastStyle(a::Broadcast.BlockSparseArrayStyle, ::Base.Broadcast.DefaultArrayStyle{0}) = a
41+
Base.Broadcast.BroadcastStyle(a::BlockSparseArrayStyle, ::Base.Broadcast.DefaultArrayStyle{0}) = a
4442
function Base.Broadcast.BroadcastStyle(
45-
::Broadcast.BlockSparseArrayStyle{N}, a::Base.Broadcast.DefaultArrayStyle
43+
::BlockSparseArrayStyle{N}, a::Base.Broadcast.DefaultArrayStyle
4644
) where {N}
4745
return Base.Broadcast.BroadcastStyle(Base.Broadcast.DefaultArrayStyle{N}(), a)
4846
end
4947
function Base.Broadcast.BroadcastStyle(
50-
::Broadcast.BlockSparseArrayStyle{N}, ::Base.Broadcast.Style{Tuple}
48+
::BlockSparseArrayStyle{N}, ::Base.Broadcast.Style{Tuple}
5149
) where {N}
5250
return Base.Broadcast.DefaultArrayStyle{N}()
5351
end
5452

55-
function Base.similar(bc::Broadcasted{<:Broadcast.BlockSparseArrayStyle}, elt::Type, ax)
53+
function Base.similar(bc::Broadcasted{<:BlockSparseArrayStyle}, elt::Type, ax)
5654
# Find the first array in the broadcast expression.
5755
# TODO: Make this more generic, base it off sure this handles GPU arrays properly.
5856
bc′ = Base.Broadcast.flatten(bc)
@@ -84,7 +82,7 @@ end
8482

8583
# Broadcasting implementation
8684
function Base.copyto!(
87-
dest::AbstractArray{<:Any, N}, bc::Broadcasted{Broadcast.BlockSparseArrayStyle{N}}
85+
dest::AbstractArray{<:Any, N}, bc::Broadcasted{BlockSparseArrayStyle{N}}
8886
) where {N}
8987
return copyto!_blocksparse(dest, bc)
9088
end

src/blocksparsearrayinterface/cat.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using BlockArrays: blocks
22
using FunctionImplementations.Concatenate: Concatenated, cat!
33

44
function Base.copyto!(
5-
dest::AbstractArray, concat::Concatenated{<:Broadcast.BlockSparseArrayStyle}
5+
dest::AbstractArray, concat::Concatenated{<:BlockSparseArrayStyle}
66
)
77
# TODO: This assumes the destination blocking is commensurate with
88
# the blocking of the sources, for example because it was constructed

0 commit comments

Comments
 (0)