Skip to content

Commit ccd86be

Browse files
authored
Merge pull request #132 from BuildingEnergySimulationTools/112-create-a-plot-for-identified-parameter-value
112 create a plot for identified parameter value
2 parents 7c6ee38 + e0ec205 commit ccd86be

3 files changed

Lines changed: 470 additions & 3 deletions

File tree

corrai/optimize.py

Lines changed: 336 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
import pandas as pd
6+
import plotly.graph_objects as go
67
from pymoo.core.problem import ElementwiseProblem
78
from pymoo.core.variable import Binary, Choice, Integer, Real
89
from scipy.optimize import differential_evolution, minimize_scalar, minimize, curve_fit
@@ -279,6 +280,37 @@ def scipy_obj_function(self, x: np.ndarray, *args) -> float:
279280
def scipy_scalar_obj_function(self, x: float, *args):
280281
return self.scipy_obj_function(np.array([x]), *args)
281282

283+
def plot_parameter_forest(
284+
self,
285+
optimal_values: "dict[str, float] | list[float] | pd.Series",
286+
mode: str = "normalized",
287+
title: str = None,
288+
**plot_kwargs,
289+
) -> go.Figure:
290+
"""
291+
Forest plot of this evaluator's parameters with bounds and optimal values.
292+
293+
Delegates to the module-level :func:`plot_parameter_forest`.
294+
See that function for full documentation.
295+
296+
Parameters
297+
----------
298+
optimal_values : dict, list, or pd.Series
299+
Optimal value per parameter. When a list, order must match
300+
``self.parameters``.
301+
mode : {"normalized", "absolute", "relative"}, default "normalized"
302+
title : str, optional
303+
**plot_kwargs
304+
Forwarded to ``fig.update_layout`` / ``fig.update_traces``.
305+
"""
306+
return plot_parameter_forest(
307+
self.parameters,
308+
optimal_values,
309+
mode=mode,
310+
title=title,
311+
**plot_kwargs,
312+
)
313+
282314

283315
class PymooModelEvaluator(ModelEvaluator):
284316
"""
@@ -993,3 +1025,307 @@ def wrapped_func(x, *params):
9931025
bounds=bounds,
9941026
**kwargs,
9951027
)
1028+
1029+
1030+
_FOREST_MODES = ["normalized", "absolute", "relative"]
1031+
1032+
1033+
def _apply_figure_kwargs(fig: go.Figure, **kwargs) -> None:
1034+
for key, val in kwargs.items():
1035+
try:
1036+
fig.update_layout(**{key: val})
1037+
except ValueError:
1038+
fig.update_traces(**{key: val})
1039+
1040+
1041+
def _forest_label(value: float, relabs: str, mode: str) -> str:
1042+
"""Format a bound or optimal value for forest plot annotation."""
1043+
if mode == "normalized":
1044+
return ""
1045+
if mode == "relative" and relabs == "Relative":
1046+
return f"{value * 100:.4g}%"
1047+
return f"{value:.4g}"
1048+
1049+
1050+
def plot_parameter_forest(
1051+
parameters: list[Parameter],
1052+
optimal_values: dict[str, float] | list[float] | pd.Series,
1053+
mode: str = "normalized",
1054+
title: str = None,
1055+
template: str = "plotly_white",
1056+
**plot_kwargs,
1057+
) -> go.Figure:
1058+
"""
1059+
Forest plot of optimization parameters — parameters on the X-axis, normalized
1060+
values on the Y-axis.
1061+
1062+
Each parameter is drawn as a vertical bar spanning [0, 1] (normalized to its
1063+
own bounds) with a diamond marker at the optimal value. Parameters with
1064+
different units are thus comparable on the same scale.
1065+
1066+
How bounds are labelled depends on ``mode``:
1067+
1068+
* ``"normalized"`` — no value annotations; Y-axis ticks read 0 % … 100 %.
1069+
* ``"absolute"`` — actual lower, upper, and optimal values are shown as
1070+
text on each bar.
1071+
* ``"relative"`` — parameters with ``relabs="Relative"`` are annotated in
1072+
percent (e.g. ``interval=(0.2, 1.5)`` → ``"20 %"`` / ``"150 %"``);
1073+
parameters with ``relabs="Absolute"`` fall back to actual values.
1074+
1075+
Parameters without an ``interval`` (e.g. ``Choice``) are silently skipped.
1076+
Hover is disabled on all traces.
1077+
1078+
Parameters
1079+
----------
1080+
parameters : list of Parameter
1081+
Parameters defining the search space.
1082+
optimal_values : dict, list, or pd.Series
1083+
Optimal value per parameter after optimisation. When a list or array,
1084+
order must match ``parameters``.
1085+
mode : {"normalized", "absolute", "relative"}, default "normalized"
1086+
Annotation style (see above).
1087+
title : str, optional
1088+
Plot title.
1089+
**plot_kwargs
1090+
Forwarded to ``fig.update_layout`` or ``fig.update_traces``.
1091+
1092+
Returns
1093+
-------
1094+
plotly.graph_objects.Figure
1095+
1096+
Examples
1097+
--------
1098+
>>> from corrai.base.parameter import Parameter
1099+
>>> from corrai.optimize import plot_parameter_forest
1100+
>>> params = [
1101+
... Parameter("conductivity", interval=(0.03, 0.06), model_property="x"),
1102+
... Parameter("thickness", interval=(0.05, 0.30), model_property="y"),
1103+
... Parameter("temp_setpoint", interval=(18.0, 24.0), model_property="z"),
1104+
... ]
1105+
>>> fig = plot_parameter_forest(
1106+
... params,
1107+
... {"conductivity": 0.04, "thickness": 0.12, "temp_setpoint": 21.0},
1108+
... mode="absolute",
1109+
... )
1110+
"""
1111+
if mode not in _FOREST_MODES:
1112+
raise ValueError(f"mode must be one of {_FOREST_MODES}, got {mode!r}")
1113+
1114+
# Include continuous (interval) and categorical (values/Choice) params; skip Binary
1115+
params = [p for p in parameters if p.interval is not None or p.values is not None]
1116+
if not params:
1117+
raise ValueError("No parameters with interval bounds found.")
1118+
1119+
all_param_names = {p.name for p in params}
1120+
if isinstance(optimal_values, (list, np.ndarray)):
1121+
opt_dict = {
1122+
p.name: v
1123+
for p, v in zip(parameters, optimal_values)
1124+
if p.interval is not None or p.values is not None
1125+
}
1126+
elif isinstance(optimal_values, pd.Series):
1127+
opt_dict = {k: v for k, v in optimal_values.items() if k in all_param_names}
1128+
else:
1129+
opt_dict = {k: v for k, v in optimal_values.items() if k in all_param_names}
1130+
1131+
missing = [p.name for p in params if p.name not in opt_dict]
1132+
if missing:
1133+
raise ValueError(f"Missing optimal values for parameters: {missing}")
1134+
1135+
names = [p.name for p in params]
1136+
interval_params = [p for p in params if p.interval is not None]
1137+
choice_params = [p for p in params if p.values is not None]
1138+
1139+
# --- Normalized optimal positions
1140+
opt_norms: dict[str, float] = {}
1141+
for p in params:
1142+
if p.interval is not None:
1143+
lo, hi = p.interval
1144+
v = float(opt_dict[p.name])
1145+
opt_norms[p.name] = (v - lo) / (hi - lo) if hi != lo else 0.5
1146+
else:
1147+
n = len(p.values)
1148+
positions = [i / (n - 1) for i in range(n)] if n > 1 else [0.5]
1149+
try:
1150+
idx = list(p.values).index(opt_dict[p.name])
1151+
except ValueError:
1152+
raise ValueError(
1153+
f"Optimal value {opt_dict[p.name]!r} not among choices "
1154+
f"{p.values} for parameter {p.name!r}"
1155+
)
1156+
opt_norms[p.name] = positions[idx]
1157+
1158+
# --- Labels
1159+
annotate = mode != "normalized"
1160+
lower_texts = {
1161+
p.name: _forest_label(p.interval[0], p.relabs, mode) for p in interval_params
1162+
}
1163+
upper_texts = {
1164+
p.name: _forest_label(p.interval[1], p.relabs, mode) for p in interval_params
1165+
}
1166+
# Optimal text: mode-aware for interval; empty for choice (tick labels already mark each position)
1167+
all_opt_texts = {
1168+
p.name: (
1169+
_forest_label(float(opt_dict[p.name]), p.relabs, mode)
1170+
if p.interval is not None
1171+
else ""
1172+
)
1173+
for p in params
1174+
}
1175+
annotate_opt = annotate
1176+
1177+
_bar_color = "darkblue"
1178+
fig = go.Figure()
1179+
1180+
# Trace: vertical lines for all params
1181+
x_lines: list[str | None] = []
1182+
y_lines: list[float | None] = []
1183+
for name in names:
1184+
x_lines.extend([name, name, None])
1185+
y_lines.extend([0.0, 1.0, None])
1186+
fig.add_trace(
1187+
go.Scatter(
1188+
x=x_lines,
1189+
y=y_lines,
1190+
mode="lines",
1191+
line=dict(color=_bar_color, width=2),
1192+
showlegend=False,
1193+
hoverinfo="skip",
1194+
)
1195+
)
1196+
1197+
# Trace: interval lower bound ticks (y=0, text below)
1198+
if interval_params:
1199+
inames = [p.name for p in interval_params]
1200+
fig.add_trace(
1201+
go.Scatter(
1202+
x=inames,
1203+
y=[0.0] * len(inames),
1204+
mode="markers+text" if annotate else "markers",
1205+
marker=dict(
1206+
symbol="line-ew-open",
1207+
size=14,
1208+
color=_bar_color,
1209+
line=dict(width=2, color=_bar_color),
1210+
),
1211+
text=[lower_texts[n] for n in inames],
1212+
textposition="bottom center",
1213+
showlegend=False,
1214+
hoverinfo="skip",
1215+
)
1216+
)
1217+
1218+
# Trace: interval upper bound ticks (y=1, text above)
1219+
if interval_params:
1220+
inames = [p.name for p in interval_params]
1221+
fig.add_trace(
1222+
go.Scatter(
1223+
x=inames,
1224+
y=[1.0] * len(inames),
1225+
mode="markers+text" if annotate else "markers",
1226+
marker=dict(
1227+
symbol="line-ew-open",
1228+
size=14,
1229+
color=_bar_color,
1230+
line=dict(width=2, color=_bar_color),
1231+
),
1232+
text=[upper_texts[n] for n in inames],
1233+
textposition="top center",
1234+
showlegend=False,
1235+
hoverinfo="skip",
1236+
)
1237+
)
1238+
1239+
# Trace: choice tick marks — one entry per choice value, always labelled
1240+
if choice_params:
1241+
cx: list[str] = []
1242+
cy: list[float] = []
1243+
ctexts: list[str] = []
1244+
ctextpositions: list[str] = []
1245+
for p in choice_params:
1246+
n = len(p.values)
1247+
positions = [i / (n - 1) for i in range(n)] if n > 1 else [0.5]
1248+
for i, (val, pos) in enumerate(zip(p.values, positions)):
1249+
cx.append(p.name)
1250+
cy.append(pos)
1251+
ctexts.append(str(val))
1252+
if n == 1:
1253+
ctextpositions.append("top center")
1254+
elif i == 0:
1255+
ctextpositions.append("bottom center")
1256+
elif i == n - 1:
1257+
ctextpositions.append("top center")
1258+
else:
1259+
ctextpositions.append("middle right")
1260+
fig.add_trace(
1261+
go.Scatter(
1262+
x=cx,
1263+
y=cy,
1264+
mode="markers+text",
1265+
marker=dict(
1266+
symbol="line-ew-open",
1267+
size=14,
1268+
color=_bar_color,
1269+
line=dict(width=2, color=_bar_color),
1270+
),
1271+
text=ctexts,
1272+
textposition=ctextpositions,
1273+
showlegend=False,
1274+
hoverinfo="skip",
1275+
)
1276+
)
1277+
1278+
# Trace: optimal diamonds — always last
1279+
fig.add_trace(
1280+
go.Scatter(
1281+
x=names,
1282+
y=[opt_norms[n] for n in names],
1283+
mode="markers+text" if annotate_opt else "markers",
1284+
name="Optimal",
1285+
marker=dict(
1286+
symbol="diamond",
1287+
size=13,
1288+
color="orange",
1289+
line=dict(width=1.5, color="darkorange"),
1290+
),
1291+
text=[all_opt_texts[n] for n in names],
1292+
textposition="middle right",
1293+
showlegend=True,
1294+
hoverinfo="skip",
1295+
)
1296+
)
1297+
1298+
# Lower text is "bottom center" → needs a bit of space below 0
1299+
y_range = [-0.05, 1.1] if mode == "normalized" else [-0.18, 1.25]
1300+
if mode == "normalized":
1301+
y_tickvals = [0.0, 0.25, 0.5, 0.75, 1.0]
1302+
y_ticktext = ["Lower (0%)", "25%", "50%", "75%", "Upper (100%)"]
1303+
title_y = "Normalized position with bounds"
1304+
else:
1305+
y_tickvals = [0.0, 1.0]
1306+
y_ticktext = ["Lower bound", "Upper bound"]
1307+
title_y = "Position with bounds"
1308+
1309+
b_margin = 100 if len(names) > 5 else 70
1310+
1311+
fig.update_layout(
1312+
title=title,
1313+
xaxis=dict(
1314+
tickangle=-30 if len(names) > 5 else 0,
1315+
),
1316+
yaxis=dict(
1317+
title=title_y,
1318+
range=y_range,
1319+
tickvals=y_tickvals,
1320+
ticktext=y_ticktext,
1321+
showgrid=True,
1322+
zeroline=False,
1323+
),
1324+
template=template,
1325+
legend=dict(orientation="h", yanchor="bottom", y=1.0, xanchor="right", x=1),
1326+
autosize=True,
1327+
margin=dict(l=70, r=30, t=40, b=b_margin),
1328+
)
1329+
1330+
_apply_figure_kwargs(fig, **plot_kwargs)
1331+
return fig

0 commit comments

Comments
 (0)