Skip to content

Commit b46e204

Browse files
sshin23claude
andcommitted
Fix append! for matrix bounds in batch models
Add _reshape_to_match to correctly handle array bounds regardless of input shape. For non-batch (Vector target), matrices are vecd. For batch (Matrix target), matrices are reshaped to match trailing dims. This fixes the DimensionMismatch when passing matrix lvar/uvar to batch models, while preserving the existing behavior for non-batch models that receive matrix-shaped comprehension results. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent f13f998 commit b46e204

1 file changed

Lines changed: 12 additions & 2 deletions

File tree

src/nlp.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,17 @@ function _expand_to_shape(col::AbstractVector{T}, trailing::Tuple) where {T}
680680
return repeat(reshape(col, :, ntuple(_ -> 1, length(trailing))...), 1, trailing...)
681681
end
682682

683+
# Reshape an array to have the given trailing dimensions.
684+
# For vectors, delegates to _expand_to_shape. For matrices/higher-dim arrays,
685+
# reshapes using total elements / trailing to infer the first dimension.
686+
_reshape_to_match(arr::AbstractVector, trailing::Tuple{}) = arr
687+
_reshape_to_match(arr::AbstractVector, trailing::Tuple) = _expand_to_shape(arr, trailing)
688+
_reshape_to_match(arr::AbstractArray, trailing::Tuple{}) = vec(arr)
689+
function _reshape_to_match(arr::AbstractArray, trailing::Tuple)
690+
n = length(arr) ÷ prod(trailing)
691+
return reshape(arr, n, trailing...)
692+
end
693+
683694
function append!(backend, a, b::Number, lb)
684695
lb == 0 && return a
685696
new_part = fill(eltype(a)(b), lb, _trailing_dims(a)...)
@@ -689,8 +700,7 @@ end
689700
function append!(backend, a, b::AbstractArray, lb)
690701
lb == 0 && return a
691702
arr = convert_array(b, backend)
692-
col = vec(arr)
693-
return cat(a, _expand_to_shape(col, _trailing_dims(a)); dims = 1)
703+
return cat(a, _reshape_to_match(arr, _trailing_dims(a)); dims = 1)
694704
end
695705

696706
function append!(backend, a, b::Base.Generator, lb)

0 commit comments

Comments
 (0)