Skip to content

Commit 27040e2

Browse files
Merge pull request #510 from ChrisRackauckas-Claude/fix-raggedend-broadcasting
Make RaggedEnd and RaggedRange broadcast as scalars
2 parents aff3a71 + 3219570 commit 27040e2

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

src/vector_of_array.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,10 @@ Base.:+(re::RaggedEnd, n::Integer) = RaggedEnd(re.dim, re.offset + Int(n))
446446
Base.:-(re::RaggedEnd, n::Integer) = RaggedEnd(re.dim, re.offset - Int(n))
447447
Base.:+(n::Integer, re::RaggedEnd) = re + n
448448

449+
# Make RaggedEnd and RaggedRange broadcast as scalars to avoid
450+
# issues with collect/length in broadcasting contexts (e.g., SymbolicIndexingInterface)
451+
Base.broadcastable(x::RaggedEnd) = Ref(x)
452+
449453
struct RaggedRange
450454
dim::Int
451455
start::Int
@@ -460,6 +464,7 @@ end
460464
function Base.:(:)(start::Integer, step::Integer, stop::RaggedEnd)
461465
RaggedRange(stop.dim, Int(start), Int(step), stop.offset)
462466
end
467+
Base.broadcastable(x::RaggedRange) = Ref(x)
463468

464469
@inline function _is_ragged_dim(VA::AbstractVectorOfArray, d::Integer)
465470
length(VA.u) <= 1 && return false

test/basic_indexing.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,20 @@ ragged2 = VectorOfArray([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0], [7.0, 8.0, 9.0]])
196196
@test ragged2[1:(end - 1), 2] == [5.0]
197197
@test ragged2[1:(end - 1), 3] == [7.0, 8.0]
198198

199+
# Test that RaggedEnd and RaggedRange broadcast as scalars
200+
# (fixes issue with SymbolicIndexingInterface where broadcasting over RaggedEnd would fail)
201+
ragged_idx = lastindex(ragged, 1)
202+
@test ragged_idx isa RecursiveArrayTools.RaggedEnd
203+
@test Base.broadcastable(ragged_idx) isa Ref
204+
# Broadcasting with RaggedEnd should not error
205+
@test identity.(ragged_idx) === ragged_idx
206+
207+
ragged_range_idx = 1:lastindex(ragged, 1)
208+
@test ragged_range_idx isa RecursiveArrayTools.RaggedRange
209+
@test Base.broadcastable(ragged_range_idx) isa Ref
210+
# Broadcasting with RaggedRange should not error
211+
@test identity.(ragged_range_idx) === ragged_range_idx
212+
199213
# Broadcasting of heterogeneous arrays (issue #454)
200214
u = VectorOfArray([[1.0], [2.0, 3.0]])
201215
@test length(view(u, :, 1)) == 1
@@ -220,8 +234,8 @@ u[1, [1, 3], 2] .= [7.0, 9.0]
220234

221235
# 3D inner arrays (tensors) with ragged third dimension
222236
u = VectorOfArray([zeros(2, 1, n) for n in (2, 3)])
223-
@test size(view(u,:,:,:,1)) == (2, 1, 2)
224-
@test size(view(u,:,:,:,2)) == (2, 1, 3)
237+
@test size(view(u, :, :, :, 1)) == (2, 1, 2)
238+
@test size(view(u, :, :, :, 2)) == (2, 1, 3)
225239
# assign into a slice of the second inner array using last index Int
226240
u[2, 1, :, 2] .= [7.0, 8.0, 9.0]
227241
@test vec(u.u[2][2, 1, :]) == [7.0, 8.0, 9.0]

0 commit comments

Comments
 (0)