Skip to content

Commit fbfcc7f

Browse files
committed
Fix zeromatrix to preserve GPU array types for ArrayPartition
The previous implementation used `reduce(vcat, vec.(A.x))` which could cause type conversion issues with GPU arrays, leading to scalar indexing errors when using implicit ODE solvers with ArrayPartition of CuArrays. The fix uses `foldl` with an explicit `init` value from the first element of the tuple, ensuring the result array type matches the input type. This preserves GPU array types (CuArray, MtlArray, etc.) when building the zero matrix. Fixes #496 Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent adf6fb5 commit fbfcc7f

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

src/array_partition.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,11 @@ end
548548
## Linear Algebra
549549

550550
function ArrayInterface.zeromatrix(A::ArrayPartition)
551-
x = reduce(vcat, vec.(A.x))
551+
# Use foldl with explicit init to preserve array type (important for GPU arrays)
552+
# Starting with vec of first element ensures the result type matches the input
553+
vecs = vec.(A.x)
554+
rest = Base.tail(vecs)
555+
x = isempty(rest) ? vecs[1] : foldl(vcat, rest; init = vecs[1])
552556
return x .* x' .* false
553557
end
554558

src/named_array_partition.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,10 @@ end
174174
#Overwrite ArrayInterface zeromatrix to work with NamedArrayPartitions & implicit solvers within OrdinaryDiffEq
175175
function ArrayInterface.zeromatrix(A::NamedArrayPartition)
176176
B = ArrayPartition(A)
177-
x = reduce(vcat, vec.(B.x))
177+
# Use foldl with explicit init to preserve array type (important for GPU arrays)
178+
vecs = vec.(B.x)
179+
rest = Base.tail(vecs)
180+
x = isempty(rest) ? vecs[1] : foldl(vcat, rest; init = vecs[1])
178181
return x .* x' .* false
179182
end
180183

0 commit comments

Comments
 (0)