|
3 | 3 |
|
4 | 4 | import numpy as np |
5 | 5 | import pandas as pd |
| 6 | +import plotly.graph_objects as go |
6 | 7 | from pymoo.core.problem import ElementwiseProblem |
7 | 8 | from pymoo.core.variable import Binary, Choice, Integer, Real |
8 | 9 | from scipy.optimize import differential_evolution, minimize_scalar, minimize, curve_fit |
@@ -279,6 +280,37 @@ def scipy_obj_function(self, x: np.ndarray, *args) -> float: |
279 | 280 | def scipy_scalar_obj_function(self, x: float, *args): |
280 | 281 | return self.scipy_obj_function(np.array([x]), *args) |
281 | 282 |
|
| 283 | + def plot_parameter_forest( |
| 284 | + self, |
| 285 | + optimal_values: "dict[str, float] | list[float] | pd.Series", |
| 286 | + mode: str = "normalized", |
| 287 | + title: str = None, |
| 288 | + **plot_kwargs, |
| 289 | + ) -> go.Figure: |
| 290 | + """ |
| 291 | + Forest plot of this evaluator's parameters with bounds and optimal values. |
| 292 | +
|
| 293 | + Delegates to the module-level :func:`plot_parameter_forest`. |
| 294 | + See that function for full documentation. |
| 295 | +
|
| 296 | + Parameters |
| 297 | + ---------- |
| 298 | + optimal_values : dict, list, or pd.Series |
| 299 | + Optimal value per parameter. When a list, order must match |
| 300 | + ``self.parameters``. |
| 301 | + mode : {"normalized", "absolute", "relative"}, default "normalized" |
| 302 | + title : str, optional |
| 303 | + **plot_kwargs |
| 304 | + Forwarded to ``fig.update_layout`` / ``fig.update_traces``. |
| 305 | + """ |
| 306 | + return plot_parameter_forest( |
| 307 | + self.parameters, |
| 308 | + optimal_values, |
| 309 | + mode=mode, |
| 310 | + title=title, |
| 311 | + **plot_kwargs, |
| 312 | + ) |
| 313 | + |
282 | 314 |
|
283 | 315 | class PymooModelEvaluator(ModelEvaluator): |
284 | 316 | """ |
@@ -993,3 +1025,307 @@ def wrapped_func(x, *params): |
993 | 1025 | bounds=bounds, |
994 | 1026 | **kwargs, |
995 | 1027 | ) |
| 1028 | + |
| 1029 | + |
| 1030 | +_FOREST_MODES = ["normalized", "absolute", "relative"] |
| 1031 | + |
| 1032 | + |
| 1033 | +def _apply_figure_kwargs(fig: go.Figure, **kwargs) -> None: |
| 1034 | + for key, val in kwargs.items(): |
| 1035 | + try: |
| 1036 | + fig.update_layout(**{key: val}) |
| 1037 | + except ValueError: |
| 1038 | + fig.update_traces(**{key: val}) |
| 1039 | + |
| 1040 | + |
| 1041 | +def _forest_label(value: float, relabs: str, mode: str) -> str: |
| 1042 | + """Format a bound or optimal value for forest plot annotation.""" |
| 1043 | + if mode == "normalized": |
| 1044 | + return "" |
| 1045 | + if mode == "relative" and relabs == "Relative": |
| 1046 | + return f"{value * 100:.4g}%" |
| 1047 | + return f"{value:.4g}" |
| 1048 | + |
| 1049 | + |
| 1050 | +def plot_parameter_forest( |
| 1051 | + parameters: list[Parameter], |
| 1052 | + optimal_values: dict[str, float] | list[float] | pd.Series, |
| 1053 | + mode: str = "normalized", |
| 1054 | + title: str = None, |
| 1055 | + template: str = "plotly_white", |
| 1056 | + **plot_kwargs, |
| 1057 | +) -> go.Figure: |
| 1058 | + """ |
| 1059 | + Forest plot of optimization parameters — parameters on the X-axis, normalized |
| 1060 | + values on the Y-axis. |
| 1061 | +
|
| 1062 | + Each parameter is drawn as a vertical bar spanning [0, 1] (normalized to its |
| 1063 | + own bounds) with a diamond marker at the optimal value. Parameters with |
| 1064 | + different units are thus comparable on the same scale. |
| 1065 | +
|
| 1066 | + How bounds are labelled depends on ``mode``: |
| 1067 | +
|
| 1068 | + * ``"normalized"`` — no value annotations; Y-axis ticks read 0 % … 100 %. |
| 1069 | + * ``"absolute"`` — actual lower, upper, and optimal values are shown as |
| 1070 | + text on each bar. |
| 1071 | + * ``"relative"`` — parameters with ``relabs="Relative"`` are annotated in |
| 1072 | + percent (e.g. ``interval=(0.2, 1.5)`` → ``"20 %"`` / ``"150 %"``); |
| 1073 | + parameters with ``relabs="Absolute"`` fall back to actual values. |
| 1074 | +
|
| 1075 | + Parameters without an ``interval`` (e.g. ``Choice``) are silently skipped. |
| 1076 | + Hover is disabled on all traces. |
| 1077 | +
|
| 1078 | + Parameters |
| 1079 | + ---------- |
| 1080 | + parameters : list of Parameter |
| 1081 | + Parameters defining the search space. |
| 1082 | + optimal_values : dict, list, or pd.Series |
| 1083 | + Optimal value per parameter after optimisation. When a list or array, |
| 1084 | + order must match ``parameters``. |
| 1085 | + mode : {"normalized", "absolute", "relative"}, default "normalized" |
| 1086 | + Annotation style (see above). |
| 1087 | + title : str, optional |
| 1088 | + Plot title. |
| 1089 | + **plot_kwargs |
| 1090 | + Forwarded to ``fig.update_layout`` or ``fig.update_traces``. |
| 1091 | +
|
| 1092 | + Returns |
| 1093 | + ------- |
| 1094 | + plotly.graph_objects.Figure |
| 1095 | +
|
| 1096 | + Examples |
| 1097 | + -------- |
| 1098 | + >>> from corrai.base.parameter import Parameter |
| 1099 | + >>> from corrai.optimize import plot_parameter_forest |
| 1100 | + >>> params = [ |
| 1101 | + ... Parameter("conductivity", interval=(0.03, 0.06), model_property="x"), |
| 1102 | + ... Parameter("thickness", interval=(0.05, 0.30), model_property="y"), |
| 1103 | + ... Parameter("temp_setpoint", interval=(18.0, 24.0), model_property="z"), |
| 1104 | + ... ] |
| 1105 | + >>> fig = plot_parameter_forest( |
| 1106 | + ... params, |
| 1107 | + ... {"conductivity": 0.04, "thickness": 0.12, "temp_setpoint": 21.0}, |
| 1108 | + ... mode="absolute", |
| 1109 | + ... ) |
| 1110 | + """ |
| 1111 | + if mode not in _FOREST_MODES: |
| 1112 | + raise ValueError(f"mode must be one of {_FOREST_MODES}, got {mode!r}") |
| 1113 | + |
| 1114 | + # Include continuous (interval) and categorical (values/Choice) params; skip Binary |
| 1115 | + params = [p for p in parameters if p.interval is not None or p.values is not None] |
| 1116 | + if not params: |
| 1117 | + raise ValueError("No parameters with interval bounds found.") |
| 1118 | + |
| 1119 | + all_param_names = {p.name for p in params} |
| 1120 | + if isinstance(optimal_values, (list, np.ndarray)): |
| 1121 | + opt_dict = { |
| 1122 | + p.name: v |
| 1123 | + for p, v in zip(parameters, optimal_values) |
| 1124 | + if p.interval is not None or p.values is not None |
| 1125 | + } |
| 1126 | + elif isinstance(optimal_values, pd.Series): |
| 1127 | + opt_dict = {k: v for k, v in optimal_values.items() if k in all_param_names} |
| 1128 | + else: |
| 1129 | + opt_dict = {k: v for k, v in optimal_values.items() if k in all_param_names} |
| 1130 | + |
| 1131 | + missing = [p.name for p in params if p.name not in opt_dict] |
| 1132 | + if missing: |
| 1133 | + raise ValueError(f"Missing optimal values for parameters: {missing}") |
| 1134 | + |
| 1135 | + names = [p.name for p in params] |
| 1136 | + interval_params = [p for p in params if p.interval is not None] |
| 1137 | + choice_params = [p for p in params if p.values is not None] |
| 1138 | + |
| 1139 | + # --- Normalized optimal positions |
| 1140 | + opt_norms: dict[str, float] = {} |
| 1141 | + for p in params: |
| 1142 | + if p.interval is not None: |
| 1143 | + lo, hi = p.interval |
| 1144 | + v = float(opt_dict[p.name]) |
| 1145 | + opt_norms[p.name] = (v - lo) / (hi - lo) if hi != lo else 0.5 |
| 1146 | + else: |
| 1147 | + n = len(p.values) |
| 1148 | + positions = [i / (n - 1) for i in range(n)] if n > 1 else [0.5] |
| 1149 | + try: |
| 1150 | + idx = list(p.values).index(opt_dict[p.name]) |
| 1151 | + except ValueError: |
| 1152 | + raise ValueError( |
| 1153 | + f"Optimal value {opt_dict[p.name]!r} not among choices " |
| 1154 | + f"{p.values} for parameter {p.name!r}" |
| 1155 | + ) |
| 1156 | + opt_norms[p.name] = positions[idx] |
| 1157 | + |
| 1158 | + # --- Labels |
| 1159 | + annotate = mode != "normalized" |
| 1160 | + lower_texts = { |
| 1161 | + p.name: _forest_label(p.interval[0], p.relabs, mode) for p in interval_params |
| 1162 | + } |
| 1163 | + upper_texts = { |
| 1164 | + p.name: _forest_label(p.interval[1], p.relabs, mode) for p in interval_params |
| 1165 | + } |
| 1166 | + # Optimal text: mode-aware for interval; empty for choice (tick labels already mark each position) |
| 1167 | + all_opt_texts = { |
| 1168 | + p.name: ( |
| 1169 | + _forest_label(float(opt_dict[p.name]), p.relabs, mode) |
| 1170 | + if p.interval is not None |
| 1171 | + else "" |
| 1172 | + ) |
| 1173 | + for p in params |
| 1174 | + } |
| 1175 | + annotate_opt = annotate |
| 1176 | + |
| 1177 | + _bar_color = "darkblue" |
| 1178 | + fig = go.Figure() |
| 1179 | + |
| 1180 | + # Trace: vertical lines for all params |
| 1181 | + x_lines: list[str | None] = [] |
| 1182 | + y_lines: list[float | None] = [] |
| 1183 | + for name in names: |
| 1184 | + x_lines.extend([name, name, None]) |
| 1185 | + y_lines.extend([0.0, 1.0, None]) |
| 1186 | + fig.add_trace( |
| 1187 | + go.Scatter( |
| 1188 | + x=x_lines, |
| 1189 | + y=y_lines, |
| 1190 | + mode="lines", |
| 1191 | + line=dict(color=_bar_color, width=2), |
| 1192 | + showlegend=False, |
| 1193 | + hoverinfo="skip", |
| 1194 | + ) |
| 1195 | + ) |
| 1196 | + |
| 1197 | + # Trace: interval lower bound ticks (y=0, text below) |
| 1198 | + if interval_params: |
| 1199 | + inames = [p.name for p in interval_params] |
| 1200 | + fig.add_trace( |
| 1201 | + go.Scatter( |
| 1202 | + x=inames, |
| 1203 | + y=[0.0] * len(inames), |
| 1204 | + mode="markers+text" if annotate else "markers", |
| 1205 | + marker=dict( |
| 1206 | + symbol="line-ew-open", |
| 1207 | + size=14, |
| 1208 | + color=_bar_color, |
| 1209 | + line=dict(width=2, color=_bar_color), |
| 1210 | + ), |
| 1211 | + text=[lower_texts[n] for n in inames], |
| 1212 | + textposition="bottom center", |
| 1213 | + showlegend=False, |
| 1214 | + hoverinfo="skip", |
| 1215 | + ) |
| 1216 | + ) |
| 1217 | + |
| 1218 | + # Trace: interval upper bound ticks (y=1, text above) |
| 1219 | + if interval_params: |
| 1220 | + inames = [p.name for p in interval_params] |
| 1221 | + fig.add_trace( |
| 1222 | + go.Scatter( |
| 1223 | + x=inames, |
| 1224 | + y=[1.0] * len(inames), |
| 1225 | + mode="markers+text" if annotate else "markers", |
| 1226 | + marker=dict( |
| 1227 | + symbol="line-ew-open", |
| 1228 | + size=14, |
| 1229 | + color=_bar_color, |
| 1230 | + line=dict(width=2, color=_bar_color), |
| 1231 | + ), |
| 1232 | + text=[upper_texts[n] for n in inames], |
| 1233 | + textposition="top center", |
| 1234 | + showlegend=False, |
| 1235 | + hoverinfo="skip", |
| 1236 | + ) |
| 1237 | + ) |
| 1238 | + |
| 1239 | + # Trace: choice tick marks — one entry per choice value, always labelled |
| 1240 | + if choice_params: |
| 1241 | + cx: list[str] = [] |
| 1242 | + cy: list[float] = [] |
| 1243 | + ctexts: list[str] = [] |
| 1244 | + ctextpositions: list[str] = [] |
| 1245 | + for p in choice_params: |
| 1246 | + n = len(p.values) |
| 1247 | + positions = [i / (n - 1) for i in range(n)] if n > 1 else [0.5] |
| 1248 | + for i, (val, pos) in enumerate(zip(p.values, positions)): |
| 1249 | + cx.append(p.name) |
| 1250 | + cy.append(pos) |
| 1251 | + ctexts.append(str(val)) |
| 1252 | + if n == 1: |
| 1253 | + ctextpositions.append("top center") |
| 1254 | + elif i == 0: |
| 1255 | + ctextpositions.append("bottom center") |
| 1256 | + elif i == n - 1: |
| 1257 | + ctextpositions.append("top center") |
| 1258 | + else: |
| 1259 | + ctextpositions.append("middle right") |
| 1260 | + fig.add_trace( |
| 1261 | + go.Scatter( |
| 1262 | + x=cx, |
| 1263 | + y=cy, |
| 1264 | + mode="markers+text", |
| 1265 | + marker=dict( |
| 1266 | + symbol="line-ew-open", |
| 1267 | + size=14, |
| 1268 | + color=_bar_color, |
| 1269 | + line=dict(width=2, color=_bar_color), |
| 1270 | + ), |
| 1271 | + text=ctexts, |
| 1272 | + textposition=ctextpositions, |
| 1273 | + showlegend=False, |
| 1274 | + hoverinfo="skip", |
| 1275 | + ) |
| 1276 | + ) |
| 1277 | + |
| 1278 | + # Trace: optimal diamonds — always last |
| 1279 | + fig.add_trace( |
| 1280 | + go.Scatter( |
| 1281 | + x=names, |
| 1282 | + y=[opt_norms[n] for n in names], |
| 1283 | + mode="markers+text" if annotate_opt else "markers", |
| 1284 | + name="Optimal", |
| 1285 | + marker=dict( |
| 1286 | + symbol="diamond", |
| 1287 | + size=13, |
| 1288 | + color="orange", |
| 1289 | + line=dict(width=1.5, color="darkorange"), |
| 1290 | + ), |
| 1291 | + text=[all_opt_texts[n] for n in names], |
| 1292 | + textposition="middle right", |
| 1293 | + showlegend=True, |
| 1294 | + hoverinfo="skip", |
| 1295 | + ) |
| 1296 | + ) |
| 1297 | + |
| 1298 | + # Lower text is "bottom center" → needs a bit of space below 0 |
| 1299 | + y_range = [-0.05, 1.1] if mode == "normalized" else [-0.18, 1.25] |
| 1300 | + if mode == "normalized": |
| 1301 | + y_tickvals = [0.0, 0.25, 0.5, 0.75, 1.0] |
| 1302 | + y_ticktext = ["Lower (0%)", "25%", "50%", "75%", "Upper (100%)"] |
| 1303 | + title_y = "Normalized position with bounds" |
| 1304 | + else: |
| 1305 | + y_tickvals = [0.0, 1.0] |
| 1306 | + y_ticktext = ["Lower bound", "Upper bound"] |
| 1307 | + title_y = "Position with bounds" |
| 1308 | + |
| 1309 | + b_margin = 100 if len(names) > 5 else 70 |
| 1310 | + |
| 1311 | + fig.update_layout( |
| 1312 | + title=title, |
| 1313 | + xaxis=dict( |
| 1314 | + tickangle=-30 if len(names) > 5 else 0, |
| 1315 | + ), |
| 1316 | + yaxis=dict( |
| 1317 | + title=title_y, |
| 1318 | + range=y_range, |
| 1319 | + tickvals=y_tickvals, |
| 1320 | + ticktext=y_ticktext, |
| 1321 | + showgrid=True, |
| 1322 | + zeroline=False, |
| 1323 | + ), |
| 1324 | + template=template, |
| 1325 | + legend=dict(orientation="h", yanchor="bottom", y=1.0, xanchor="right", x=1), |
| 1326 | + autosize=True, |
| 1327 | + margin=dict(l=70, r=30, t=40, b=b_margin), |
| 1328 | + ) |
| 1329 | + |
| 1330 | + _apply_figure_kwargs(fig, **plot_kwargs) |
| 1331 | + return fig |
0 commit comments