Skip to content

Commit 64b5a80

Browse files
committed
Fix cat
1 parent 22c9257 commit 64b5a80

2 files changed

Lines changed: 5 additions & 3 deletions

File tree

src/blocksparsearrayinterface/broadcast.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,11 @@ function Base.Broadcast.BroadcastStyle(
5353
end
5454

5555
function Base.similar(bc::Broadcasted{<:Broadcast.BlockSparseArrayStyle}, elt::Type, ax)
56+
# Find the first array in the broadcast expression.
5657
# TODO: Make this more generic, base it off sure this handles GPU arrays properly.
57-
m = Mapped(bc)
58-
return similar(first(m.args), elt, ax)
58+
bc′ = Base.Broadcast.flatten(bc)
59+
arg = bc′.args[findfirst(arg -> arg isa AbstractArray, bc′.args)]
60+
return similar(arg, elt, ax)
5961
end
6062

6163
# Catches cases like `dest .= value` or `dest .= value1 .+ value2`.

test/test_basics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ arrayts = (Array, JLArray)
426426
@test a1' * a2 Array(a1)' * Array(a2)
427427
@test dot(a1, a2) a1' * a2
428428
end
429-
false && @testset "cat" begin
429+
@testset "cat" begin
430430
a1 = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
431431
a1[Block(2, 1)] = dev(randn(elt, size(@view(a1[Block(2, 1)]))))
432432
a2 = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))

0 commit comments

Comments
 (0)