Skip to content

Commit 5ce30a7

Browse files
committed
Interface for solver supporting diff natively
1 parent 2b4c350 commit 5ce30a7

3 files changed

Lines changed: 970 additions & 1 deletion

File tree

src/diff_opt.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,40 @@ the differentiation information.
314314
"""
315315
struct DifferentiateTimeSec <: MOI.AbstractModelAttribute end
316316

317+
"""
318+
BackwardDifferentiate <: MOI.AbstractOptimizerAttribute
319+
320+
An `MOI.AbstractOptimizerAttribute` that triggers backward differentiation
321+
on the solver. If `MOI.supports(optimizer, DiffOpt.BackwardDifferentiate())`
322+
returns `true`, then the solver natively supports backward differentiation
323+
through the DiffOpt attribute interface, and DiffOpt will delegate
324+
differentiation directly to the solver instead of using its own
325+
differentiation backend.
326+
327+
Trigger the computation with:
328+
```julia
329+
MOI.set(optimizer, DiffOpt.BackwardDifferentiate(), nothing)
330+
```
331+
"""
332+
struct BackwardDifferentiate <: MOI.AbstractOptimizerAttribute end
333+
334+
"""
335+
ForwardDifferentiate <: MOI.AbstractOptimizerAttribute
336+
337+
An `MOI.AbstractOptimizerAttribute` that triggers forward differentiation
338+
on the solver. If `MOI.supports(optimizer, DiffOpt.ForwardDifferentiate())`
339+
returns `true`, then the solver natively supports forward differentiation
340+
through the DiffOpt attribute interface, and DiffOpt will delegate
341+
differentiation directly to the solver instead of using its own
342+
differentiation backend.
343+
344+
Trigger the computation with:
345+
```julia
346+
MOI.set(optimizer, DiffOpt.ForwardDifferentiate(), nothing)
347+
```
348+
"""
349+
struct ForwardDifferentiate <: MOI.AbstractOptimizerAttribute end
350+
317351
MOI.attribute_value_type(::DifferentiateTimeSec) = Float64
318352

319353
MOI.is_set_by_optimize(::DifferentiateTimeSec) = true

src/moi_wrapper.jl

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,22 @@ function reverse_differentiate!(model::Optimizer)
562562
"Set `DiffOpt.AllowObjectiveAndSolutionInput()` to `true` to silence this warning."
563563
end
564564
end
565+
if MOI.supports(model.optimizer, BackwardDifferentiate())
566+
# Solver natively supports backward differentiation.
567+
# Copy input_cache directly into model.optimizer and trigger differentiation.
568+
opt = model.optimizer
569+
for (vi, value) in model.input_cache.dx
570+
MOI.set(opt, ReverseVariablePrimal(), vi, value)
571+
end
572+
for (ci, value) in model.input_cache.dy
573+
MOI.set(opt, ReverseConstraintDual(), ci, value)
574+
end
575+
if !iszero(model.input_cache.dobj)
576+
MOI.set(opt, ReverseObjectiveSensitivity(), model.input_cache.dobj)
577+
end
578+
MOI.set(opt, BackwardDifferentiate(), nothing)
579+
return
580+
end
565581
diff = _diff(model)
566582
MOI.set(
567583
diff,
@@ -673,6 +689,34 @@ function forward_differentiate!(model::Optimizer)
673689
"Trying to compute the forward differentiation on a model with termination status $(st)",
674690
)
675691
end
692+
if MOI.supports(model.optimizer, ForwardDifferentiate())
693+
# Solver natively supports forward differentiation.
694+
# Copy input_cache directly into model.optimizer and trigger differentiation.
695+
opt = model.optimizer
696+
T = Float64
697+
for (ci, value) in model.input_cache.parameter_constraints
698+
MOI.set(opt, ForwardConstraintSet(), ci, MOI.Parameter(value))
699+
end
700+
if model.input_cache.objective !== nothing
701+
MOI.set(opt, ForwardObjectiveFunction(), model.input_cache.objective)
702+
end
703+
for (F, S) in MOI.Utilities.DoubleDicts.nonempty_outer_keys(
704+
model.input_cache.scalar_constraints,
705+
)
706+
for (index, value) in model.input_cache.scalar_constraints[F, S]
707+
MOI.set(opt, ForwardConstraintFunction(), index, value)
708+
end
709+
end
710+
for (F, S) in MOI.Utilities.DoubleDicts.nonempty_outer_keys(
711+
model.input_cache.vector_constraints,
712+
)
713+
for (index, value) in model.input_cache.vector_constraints[F, S]
714+
MOI.set(opt, ForwardConstraintFunction(), index, value)
715+
end
716+
end
717+
MOI.set(opt, ForwardDifferentiate(), nothing)
718+
return
719+
end
676720
diff = _diff(model)
677721
MOI.set(
678722
diff,
@@ -738,7 +782,9 @@ end
738782

739783
function empty_input_sensitivities!(model::Optimizer)
740784
empty!(model.input_cache)
741-
if model.diff !== nothing
785+
if _solver_supports_differentiate(model)
786+
empty_input_sensitivities!(model.optimizer)
787+
elseif model.diff !== nothing
742788
empty_input_sensitivities!(model.diff)
743789
end
744790
return
@@ -782,7 +828,16 @@ function _instantiate_diff(model::Optimizer, constructor)
782828
return model_bridged
783829
end
784830

831+
function _solver_supports_differentiate(model::Optimizer)
832+
return MOI.supports(model.optimizer, BackwardDifferentiate()) ||
833+
MOI.supports(model.optimizer, ForwardDifferentiate())
834+
end
835+
785836
function _diff(model::Optimizer)
837+
if _solver_supports_differentiate(model)
838+
model.index_map = MOI.Utilities.identity_index_map(model.optimizer)
839+
return model.optimizer
840+
end
786841
if model.diff === nothing
787842
_check_termination_status(model)
788843
model_constructor = MOI.get(model, ModelConstructor())
@@ -837,6 +892,13 @@ end
837892
# DiffOpt attributes redirected to `diff`
838893

839894
function _checked_diff(model::Optimizer, attr::MOI.AnyAttribute, call)
895+
if _solver_supports_differentiate(model)
896+
if model.index_map === nothing
897+
model.index_map =
898+
MOI.Utilities.identity_index_map(model.optimizer)
899+
end
900+
return model.optimizer
901+
end
840902
if model.diff === nothing
841903
error("Cannot get attribute `$attr`. First call `DiffOpt.$call`.")
842904
end
@@ -1125,6 +1187,9 @@ function MOI.set(
11251187
end
11261188

11271189
function MOI.get(model::Optimizer, attr::DifferentiateTimeSec)
1190+
if _solver_supports_differentiate(model)
1191+
return MOI.get(model.optimizer, attr)
1192+
end
11281193
return MOI.get(model.diff, attr)
11291194
end
11301195

0 commit comments

Comments
 (0)