@@ -562,6 +562,22 @@ function reverse_differentiate!(model::Optimizer)
562562 " Set `DiffOpt.AllowObjectiveAndSolutionInput()` to `true` to silence this warning."
563563 end
564564 end
565+ if MOI. supports (model. optimizer, BackwardDifferentiate ())
566+ # Solver natively supports backward differentiation.
567+ # Copy input_cache directly into model.optimizer and trigger differentiation.
568+ opt = model. optimizer
569+ for (vi, value) in model. input_cache. dx
570+ MOI. set (opt, ReverseVariablePrimal (), vi, value)
571+ end
572+ for (ci, value) in model. input_cache. dy
573+ MOI. set (opt, ReverseConstraintDual (), ci, value)
574+ end
575+ if ! iszero (model. input_cache. dobj)
576+ MOI. set (opt, ReverseObjectiveSensitivity (), model. input_cache. dobj)
577+ end
578+ MOI. set (opt, BackwardDifferentiate (), nothing )
579+ return
580+ end
565581 diff = _diff (model)
566582 MOI. set (
567583 diff,
@@ -673,6 +689,34 @@ function forward_differentiate!(model::Optimizer)
673689 " Trying to compute the forward differentiation on a model with termination status $(st) " ,
674690 )
675691 end
692+ if MOI. supports (model. optimizer, ForwardDifferentiate ())
693+ # Solver natively supports forward differentiation.
694+ # Copy input_cache directly into model.optimizer and trigger differentiation.
695+ opt = model. optimizer
696+ T = Float64
697+ for (ci, value) in model. input_cache. parameter_constraints
698+ MOI. set (opt, ForwardConstraintSet (), ci, MOI. Parameter (value))
699+ end
700+ if model. input_cache. objective != = nothing
701+ MOI. set (opt, ForwardObjectiveFunction (), model. input_cache. objective)
702+ end
703+ for (F, S) in MOI. Utilities. DoubleDicts. nonempty_outer_keys (
704+ model. input_cache. scalar_constraints,
705+ )
706+ for (index, value) in model. input_cache. scalar_constraints[F, S]
707+ MOI. set (opt, ForwardConstraintFunction (), index, value)
708+ end
709+ end
710+ for (F, S) in MOI. Utilities. DoubleDicts. nonempty_outer_keys (
711+ model. input_cache. vector_constraints,
712+ )
713+ for (index, value) in model. input_cache. vector_constraints[F, S]
714+ MOI. set (opt, ForwardConstraintFunction (), index, value)
715+ end
716+ end
717+ MOI. set (opt, ForwardDifferentiate (), nothing )
718+ return
719+ end
676720 diff = _diff (model)
677721 MOI. set (
678722 diff,
738782
739783function empty_input_sensitivities! (model:: Optimizer )
740784 empty! (model. input_cache)
741- if model. diff != = nothing
785+ if _solver_supports_differentiate (model)
786+ empty_input_sensitivities! (model. optimizer)
787+ elseif model. diff != = nothing
742788 empty_input_sensitivities! (model. diff)
743789 end
744790 return
@@ -782,7 +828,16 @@ function _instantiate_diff(model::Optimizer, constructor)
782828 return model_bridged
783829end
784830
831+ function _solver_supports_differentiate (model:: Optimizer )
832+ return MOI. supports (model. optimizer, BackwardDifferentiate ()) ||
833+ MOI. supports (model. optimizer, ForwardDifferentiate ())
834+ end
835+
785836function _diff (model:: Optimizer )
837+ if _solver_supports_differentiate (model)
838+ model. index_map = MOI. Utilities. identity_index_map (model. optimizer)
839+ return model. optimizer
840+ end
786841 if model. diff === nothing
787842 _check_termination_status (model)
788843 model_constructor = MOI. get (model, ModelConstructor ())
837892# DiffOpt attributes redirected to `diff`
838893
839894function _checked_diff (model:: Optimizer , attr:: MOI.AnyAttribute , call)
895+ if _solver_supports_differentiate (model)
896+ if model. index_map === nothing
897+ model. index_map =
898+ MOI. Utilities. identity_index_map (model. optimizer)
899+ end
900+ return model. optimizer
901+ end
840902 if model. diff === nothing
841903 error (" Cannot get attribute `$attr `. First call `DiffOpt.$call `." )
842904 end
@@ -1125,6 +1187,9 @@ function MOI.set(
11251187end
11261188
11271189function MOI. get (model:: Optimizer , attr:: DifferentiateTimeSec )
1190+ if _solver_supports_differentiate (model)
1191+ return MOI. get (model. optimizer, attr)
1192+ end
11281193 return MOI. get (model. diff, attr)
11291194end
11301195
0 commit comments