Skip to content

Commit 3b2305d

Browse files
committed
rebase on main
1 parent 816e7f6 commit 3b2305d

1 file changed

Lines changed: 97 additions & 1 deletion

File tree

corrai/optimize.py

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pandas as pd
66
from pymoo.core.problem import ElementwiseProblem
77
from pymoo.core.variable import Binary, Choice, Integer, Real
8-
from scipy.optimize import differential_evolution, minimize_scalar, minimize
8+
from scipy.optimize import differential_evolution, minimize_scalar, minimize, curve_fit
99

1010
from corrai.base.math import METHODS
1111
from corrai.base.model import Model
@@ -897,3 +897,99 @@ def diff_evo_minimize(
897897
rng=rng,
898898
workers=workers,
899899
)
900+
901+
def curve_fit(
902+
self,
903+
indicator_config,
904+
simulation_options=None,
905+
simulation_kwargs=None,
906+
p0=None,
907+
bounds=None,
908+
**kwargs,
909+
):
910+
"""
911+
Use non-linear least squares to fit model parameters.
912+
913+
This method wraps :func:`scipy.optimize.curve_fit` to calibrate the
914+
parameters of the underlying model so that simulated outputs best match
915+
reference data.
916+
917+
Parameters
918+
----------
919+
indicator_config : list or tuple
920+
Configuration(s) describing the outputs to match. Each config should be:
921+
(col, func, reference), where:
922+
* col : str
923+
Column name in simulation results.
924+
* func : str or Callable
925+
Aggregation function (currently unused here but kept for consistency).
926+
* reference : pd.Series
927+
Reference data to fit against.
928+
Multiple configs can be provided as a list.
929+
simulation_options : dict, optional
930+
Options passed to the simulation routine.
931+
simulation_kwargs : dict, optional
932+
Additional keyword arguments passed to the simulation.
933+
p0 : array-like, optional
934+
Initial guess for the parameters. If None, uses the midpoint of intervals.
935+
bounds : 2-tuple of array-like, optional
936+
Lower and upper bounds on parameters. If None, derived from
937+
`self.model_evaluator.intervals`.
938+
**kwargs :
939+
Additional keyword arguments passed to :func:`scipy.optimize.curve_fit`.
940+
941+
Returns
942+
-------
943+
popt : array
944+
Optimal values for the parameters.
945+
pcov : 2-D array
946+
Estimated covariance of `popt`.
947+
948+
Notes
949+
-----
950+
- The function being fitted internally runs a full model simulation.
951+
- The independent variable is a dummy index, as only the output values matter.
952+
- Multiple indicators are concatenated into a single residual vector.
953+
"""
954+
955+
if bounds is None:
956+
bounds = list(zip(*self.model_evaluator.intervals))
957+
958+
if p0 is None:
959+
p0 = [np.mean(b) for b in self.model_evaluator.intervals]
960+
961+
configs = (
962+
indicator_config
963+
if isinstance(indicator_config, list)
964+
else [indicator_config]
965+
)
966+
967+
references = [cfg[2].values for cfg in configs]
968+
reference = np.concatenate(references)
969+
970+
def wrapped_func(x, *params):
971+
self.model_evaluator.scipy_obj_function(
972+
params,
973+
configs[0],
974+
simulation_options,
975+
simulation_kwargs,
976+
)
977+
978+
res = self.model_evaluator.model.simulate(
979+
simulation_options=simulation_options,
980+
**(simulation_kwargs or {}),
981+
)
982+
983+
outputs = [res[cfg[0]].values for cfg in configs]
984+
return np.concatenate(outputs)
985+
986+
x_dummy = np.arange(len(reference))
987+
988+
return curve_fit(
989+
wrapped_func,
990+
x_dummy,
991+
reference,
992+
p0=p0,
993+
bounds=bounds,
994+
**kwargs,
995+
)

0 commit comments

Comments
 (0)