Skip to content

Commit 41740c3

Browse files
authored
More efficient cubic functions (#233)
* More efficient cubic functions * fix int64 * add test
1 parent 6f3bd15 commit 41740c3

File tree

6 files changed

+365
-46
lines changed

6 files changed

+365
-46
lines changed

src/MOI_wrapper.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ function MOI.supports(
600600
# needing to support names for both quadratic and affine constraints.
601601
# TODO:
602602
# switch to only check support name for the case of linear
603-
# is a solver does not support quadratic constraints it will fain in add_
603+
# is a solver does not support quadratic constraints it will fail in add_
604604
if MOI.supports_constraint(model.optimizer, F, S)
605605
return MOI.supports(model.optimizer, attr, MOI.ConstraintIndex{F,S}) &&
606606
MOI.supports(model.optimizer, attr, MOI.ConstraintIndex{G,S})
@@ -1910,7 +1910,7 @@ end
19101910
#
19111911

19121912
function MOI.optimize!(model::Optimizer)
1913-
if !isempty(model.updated_parameters)
1913+
if any(!isnan, values(model.updated_parameters))
19141914
MOI.Utilities.final_touch(model, nothing)
19151915
update_parameters!(model)
19161916
end

src/cubic_objective.jl

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -136,40 +136,62 @@ function _try_incremental_cubic_update!(
136136
# Get the current objective function type from the inner optimizer
137137
F = MOI.get(model.optimizer, MOI.ObjectiveFunctionType())
138138

139-
# Compute full new values (not deltas) for robustness
140-
# The delta was used to detect changes; now apply full new coefficients
141-
new_quad_terms = _parametric_quadratic_terms(model, pf)
142-
new_affine_terms = _parametric_affine_terms(model, pf)
143-
new_constant = _parametric_constant(model, pf)
144-
145-
# Apply quadratic coefficient changes
139+
# Apply quadratic coefficient changes.
140+
# For each changed (var1, var2) pair, recompute its new coefficient from the
141+
# base vv (pf.quadratic_data) data plus current pvv contributions
142+
# (avoids full copy + full iteration).
146143
# MOI convention:
147144
# - Off-diagonal (v1 != v2): coefficient C means C*v1*v2 (use as-is)
148145
# - Diagonal (v1 == v2): coefficient C means (C/2)*v1^2 (multiply by 2)
149-
for ((var1, var2), _) in delta_quadratic
150-
new_coef = new_quad_terms[(var1, var2)]
151-
# Apply MOI coefficient convention
152-
moi_coef = var1 == var2 ? new_coef * 2 : new_coef
146+
for (var1, var2) in keys(delta_quadratic)
147+
new_coef = get(pf.quadratic_data, (var1, var2), zero(T))
148+
for term in pf.pvv
149+
p = term.index_1
150+
first_is_greater = term.index_2.value > term.index_3.value
151+
v1 = ifelse(first_is_greater, term.index_3, term.index_2)
152+
v2 = ifelse(first_is_greater, term.index_2, term.index_3)
153+
if (v1, v2) == (var1, var2)
154+
new_coef +=
155+
term.coefficient * _effective_param_value(model, p_idx(p))
156+
end
157+
end
158+
moi_coef = new_coef * ifelse(var1 == var2, 2, 1)
153159
MOI.modify(
154160
model.optimizer,
155161
MOI.ObjectiveFunction{F}(),
156162
MOI.ScalarQuadraticCoefficientChange(var1, var2, moi_coef),
157163
)
158164
end
159165

160-
# Apply affine coefficient changes (use full new coefficient)
161-
for (var, _) in delta_affine
162-
new_coef = new_affine_terms[var]
166+
# Apply affine coefficient changes.
167+
# For each changed variable, recompute its new coefficient from the base
168+
# affine_data plus current pv and ppv contributions.
169+
for var in keys(delta_affine)
170+
new_coef = get(pf.affine_data, var, zero(T))
171+
for term in pf.pv
172+
if term.variable_2 == var
173+
new_coef +=
174+
term.coefficient *
175+
_effective_param_value(model, p_idx(term.variable_1))
176+
end
177+
end
178+
for term in pf.ppv
179+
if term.index_3 == var
180+
p1_val = _effective_param_value(model, p_idx(term.index_1))
181+
p2_val = _effective_param_value(model, p_idx(term.index_2))
182+
new_coef += term.coefficient * p1_val * p2_val
183+
end
184+
end
163185
MOI.modify(
164186
model.optimizer,
165187
MOI.ObjectiveFunction{F}(),
166188
MOI.ScalarCoefficientChange(var, new_coef),
167189
)
168190
end
169191

170-
# Apply constant change
192+
# Apply constant change using the tracked current_constant (no full recompute).
171193
if !iszero(delta_constant)
172-
pf.current_constant = new_constant
194+
pf.current_constant += delta_constant
173195
MOI.modify(
174196
model.optimizer,
175197
MOI.ObjectiveFunction{F}(),

src/cubic_parser.jl

Lines changed: 83 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -295,24 +295,92 @@ function _expand_power(args, ::Type{T}) where {T}
295295
return result
296296
end
297297

298+
"""
299+
_sort3(a, b, c) -> (a, b, c) sorted ascending
300+
301+
Sort three integers without heap allocation (in-place bubble sort).
302+
"""
303+
function _sort3(a::Real, b::Real, c::Real)
304+
if a > b
305+
a, b = b, a
306+
end
307+
if b > c
308+
b, c = c, b
309+
end
310+
if a > b
311+
a, b = b, a
312+
end
313+
return a, b, c
314+
end
315+
316+
"""
317+
_monomial_key(m::_Monomial)::NTuple{4,Int64}
318+
319+
Compute a canonical hash key for a monomial: (degree, sorted_val1, sorted_val2, sorted_val3).
320+
Uses integer tuple instead of a sorted Vector for faster hashing.
321+
"""
322+
function _monomial_key(m::_Monomial)
323+
n = length(m.variables)
324+
if n == 0
325+
return (Int64(0), Int64(0), Int64(0), Int64(0))
326+
elseif n == 1
327+
a = m.variables[1].value
328+
return (Int64(1), Int64(a), Int64(0), Int64(0))
329+
elseif n == 2
330+
a, b = m.variables[1].value, m.variables[2].value
331+
lo, hi = a <= b ? (a, b) : (b, a)
332+
return (Int64(2), Int64(lo), Int64(hi), Int64(0))
333+
else # n >= 3; degree > 3 is rejected at classification stage
334+
a, b, c = _sort3(
335+
m.variables[1].value,
336+
m.variables[2].value,
337+
m.variables[3].value,
338+
)
339+
return (Int64(n), Int64(a), Int64(b), Int64(c))
340+
end
341+
end
342+
343+
"""
344+
_monomial_vars(key::NTuple{4,Int64})::Vector{MOI.VariableIndex}
345+
346+
Given a monomial key, reconstruct the list of variables.
347+
"""
348+
function _monomial_vars(key::NTuple{4,Int64})
349+
degree = key[1]
350+
if degree == 0
351+
return MOI.VariableIndex[]
352+
elseif degree == 1
353+
return [MOI.VariableIndex(key[2])]
354+
elseif degree == 2
355+
return [MOI.VariableIndex(key[2]), MOI.VariableIndex(key[3])]
356+
else # degree == 3
357+
return [
358+
MOI.VariableIndex(key[2]),
359+
MOI.VariableIndex(key[3]),
360+
MOI.VariableIndex(key[4]),
361+
]
362+
end
363+
end
364+
298365
"""
299366
_combine_like_monomials(monomials::Vector{_Monomial{T}}) where {T}
300367
301368
Combine like monomials (same variables, regardless of order).
369+
Assumes all monomials have degree ≤ 3.
302370
"""
303371
function _combine_like_monomials(monomials::Vector{_Monomial{T}}) where {T}
304-
# Use a dict keyed by sorted variable tuple
305-
combined = Dict{Vector{MOI.VariableIndex},T}()
372+
# Key: NTuple{4,Int64} (degree + up to 3 sorted variable indices).
373+
combined = Dict{NTuple{4,Int64},T}()
306374

307375
for m in monomials
308-
# Sort variables for canonical key
309-
key = sort(m.variables, by = v -> v.value)
376+
key = _monomial_key(m)
310377
combined[key] = get(combined, key, zero(T)) + m.coefficient
311378
end
312379

313380
result = _Monomial{T}[]
314-
for (vars, coef) in combined
381+
for (key, coef) in combined
315382
if !iszero(coef)
383+
vars = _monomial_vars(key)
316384
push!(result, _Monomial{T}(coef, vars))
317385
end
318386
end
@@ -341,7 +409,7 @@ function _classify_monomial(m::_Monomial)
341409
else
342410
return :pp
343411
end
344-
elseif degree == 3
412+
else # degree == 3 (degree > 3 rejected early in _parse_cubic_expression)
345413
if num_params == 0
346414
return :vvv # Invalid - no parameter
347415
elseif num_params == 1
@@ -351,8 +419,6 @@ function _classify_monomial(m::_Monomial)
351419
else
352420
return :ppp
353421
end
354-
else
355-
return :invalid # Degree > 3
356422
end
357423
end
358424

@@ -377,12 +443,17 @@ function _parse_cubic_expression(
377443
return nothing
378444
end
379445

446+
# Reject any monomial with degree > 3 before combining
447+
for m in monomials
448+
if _monomial_degree(m) > 3
449+
return nothing
450+
end
451+
end
452+
380453
# Combine like terms
381454
monomials = _combine_like_monomials(monomials)
382455

383456
# Classify and collect terms
384-
cubic_terms = _ScalarCubicTerm{T}[]
385-
386457
cubic_ppp = _ScalarCubicTerm{T}[]
387458
cubic_ppv = _ScalarCubicTerm{T}[]
388459
cubic_pvv = _ScalarCubicTerm{T}[]
@@ -399,8 +470,8 @@ function _parse_cubic_expression(
399470
for m in monomials
400471
classification = _classify_monomial(m)
401472

402-
if classification == :invalid || classification == :vvv
403-
return nothing # Invalid degree or no parameter in cubic
473+
if classification == :vvv
474+
return nothing # No parameter in cubic term
404475
elseif classification == :constant
405476
constant += m.coefficient
406477
elseif classification == :v

src/cubic_types.jl

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,27 @@ function _normalize_cubic_indices(
4040
idx2::MOI.VariableIndex,
4141
idx3::MOI.VariableIndex,
4242
)
43-
params = MOI.VariableIndex[]
44-
vars = MOI.VariableIndex[]
45-
for idx in (idx1, idx2, idx3)
46-
if _is_parameter(idx)
47-
push!(params, idx)
48-
else
49-
push!(vars, idx)
50-
end
43+
p1 = _is_parameter(idx1)
44+
p2 = _is_parameter(idx2)
45+
p3 = _is_parameter(idx3)
46+
# Place parameters before variables, preserving relative order within each group.
47+
if p1 && p2 && p3
48+
return idx1, idx2, idx3 # ppp — already ordered
49+
elseif p1 && p2 # p p v
50+
return idx1, idx2, idx3
51+
elseif p1 && p3 # p v p → p p v
52+
return idx1, idx3, idx2
53+
elseif p2 && p3 # v p p → p p v
54+
return idx2, idx3, idx1
55+
elseif p1 # p v v — already ordered
56+
return idx1, idx2, idx3
57+
elseif p2 # v p v → p v v
58+
return idx2, idx1, idx3
59+
elseif p3 # v v p → p v v
60+
return idx3, idx1, idx2
61+
else # v v v — no parameter (caller validates)
62+
return idx1, idx2, idx3
5163
end
52-
all_indices = vcat(params, vars)
53-
return all_indices[1], all_indices[2], all_indices[3]
5464
end
5565

5666
"""

src/parametric_cubic_function.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -143,12 +143,9 @@ cubic_parameter_parameter_parameter_terms(f::ParametricCubicFunction) = f.ppp
143143
144144
Get the effective parameter value: updated value if available, otherwise current value.
145145
"""
146-
function _effective_param_value(model, pi::ParameterIndex)
147-
if haskey(model.updated_parameters, pi) &&
148-
!isnan(model.updated_parameters[pi])
149-
return model.updated_parameters[pi]
150-
end
151-
return model.parameters[pi]
146+
function _effective_param_value(model, p::ParameterIndex)
147+
val = model.updated_parameters[p]
148+
return isnan(val) ? model.parameters[p] : val
152149
end
153150

154151
"""

0 commit comments

Comments
 (0)