Skip to content

Commit 0aa6da5

Browse files
Merge pull request SciML#520 from ChrisRackauckas-Claude/fix-issue-496-gpu-zeromatrix
Fix zeromatrix to preserve GPU array types for ArrayPartition
2 parents adf6fb5 + fbfcc7f commit 0aa6da5

2 files changed

Lines changed: 9 additions & 2 deletions

File tree

src/array_partition.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,11 @@ end
548548
## Linear Algebra
549549

550550
function ArrayInterface.zeromatrix(A::ArrayPartition)
551-
x = reduce(vcat, vec.(A.x))
551+
# Use foldl with explicit init to preserve array type (important for GPU arrays)
552+
# Starting with vec of first element ensures the result type matches the input
553+
vecs = vec.(A.x)
554+
rest = Base.tail(vecs)
555+
x = isempty(rest) ? vecs[1] : foldl(vcat, rest; init = vecs[1])
552556
return x .* x' .* false
553557
end
554558

src/named_array_partition.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,10 @@ end
174174
#Overwrite ArrayInterface zeromatrix to work with NamedArrayPartitions & implicit solvers within OrdinaryDiffEq
175175
function ArrayInterface.zeromatrix(A::NamedArrayPartition)
176176
B = ArrayPartition(A)
177-
x = reduce(vcat, vec.(B.x))
177+
# Use foldl with explicit init to preserve array type (important for GPU arrays)
178+
vecs = vec.(B.x)
179+
rest = Base.tail(vecs)
180+
x = isempty(rest) ? vecs[1] : foldl(vcat, rest; init = vecs[1])
178181
return x .* x' .* false
179182
end
180183

0 commit comments

Comments
 (0)