Skip to content

Commit 019aa55

Browse files
committed
Fix
1 parent 39445fb commit 019aa55

1 file changed

Lines changed: 15 additions & 11 deletions

File tree

src/batch.jl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -402,14 +402,17 @@ function hess_coord!(
402402
nh = m.nnzh_per
403403
perm = m.hess_perm
404404

405-
if allequal(bobj_weight)
405+
# Move perm to device for GPU-compatible gather indexing
406+
perm_dev = similar(m.hess_buffer, Int, length(perm))
407+
copyto!(perm_dev, perm)
408+
409+
bobj_weight_cpu = Array(bobj_weight)
410+
if allequal(bobj_weight_cpu)
406411
# Common case: uniform obj_weight → single fused call + permute
407-
w = bobj_weight[1]
412+
w = bobj_weight_cpu[1]
408413
hess_coord!(m.model, x_flat, y_flat, m.hess_buffer; obj_weight = w)
409414
bhvals_flat = vec(bhvals)
410-
for i in eachindex(perm)
411-
bhvals_flat[i] = m.hess_buffer[perm[i]]
412-
end
415+
bhvals_flat .= m.hess_buffer[perm_dev]
413416
else
414417
# Varying weights: 2-pass approach
415418
# Pass 1: objective hessian only (y=0, obj_weight=1)
@@ -421,13 +424,14 @@ function hess_coord!(
421424
# Pass 2: constraint hessian only (obj_weight=0)
422425
hess_coord!(m.model, x_flat, y_flat, m.hess_buffer; obj_weight = zero(eltype(x_flat)))
423426

424-
# Combine per scenario
427+
# Build per-element weight vector: element i belongs to scenario (i-1)÷nh+1
428+
w_cpu = [bobj_weight_cpu[(i - 1) ÷ nh + 1] for i in 1:length(perm)]
429+
w_dev = similar(bobj_weight, length(perm))
430+
copyto!(w_dev, w_cpu)
431+
432+
# Combine per scenario (vectorized for GPU)
425433
bhvals_flat = vec(bhvals)
426-
for i in eachindex(perm)
427-
s = (i - 1) ÷ nh + 1
428-
bhvals_flat[i] =
429-
bobj_weight[s] * hess_obj[perm[i]] + m.hess_buffer[perm[i]]
430-
end
434+
bhvals_flat .= w_dev .* hess_obj[perm_dev] .+ m.hess_buffer[perm_dev]
431435
end
432436
return bhvals
433437
end

0 commit comments

Comments
 (0)