|
18 | 18 | RosenFiveParamDynamic, |
19 | 19 | ) |
20 | 20 | from corrai.base.parameter import Parameter |
| 21 | +import plotly.graph_objects as go |
| 22 | + |
21 | 23 | from corrai.optimize import ( |
22 | 24 | MixedProblem, |
23 | 25 | ModelEvaluator, |
24 | 26 | PymooModelEvaluator, |
25 | 27 | RealContinuousProblem, |
26 | 28 | SciOptimizer, |
27 | 29 | check_duplicate_params, |
| 30 | + plot_parameter_forest, |
28 | 31 | ) |
29 | 32 |
|
30 | 33 | PACKAGE_DIR = Path(__file__).parent / "TestLib" |
@@ -399,3 +402,110 @@ def test_curve_fit_simple(self): |
399 | 402 | ) |
400 | 403 |
|
401 | 404 | assert np.isclose(popt[0], 2.0, atol=1e-2) |
| 405 | + |
| 406 | + |
| 407 | +FOREST_PARAMS = [ |
| 408 | + Parameter("conductivity", interval=(0.03, 0.06), model_property="a"), |
| 409 | + Parameter("thickness", interval=(0.05, 0.30), model_property="b"), |
| 410 | + Parameter("temp_setpoint", interval=(18.0, 24.0), model_property="c"), |
| 411 | +] |
| 412 | +_OPT_DICT = {"conductivity": 0.04, "thickness": 0.12, "temp_setpoint": 21.0} |
| 413 | + |
| 414 | + |
| 415 | +class TestPlotParameterForest: |
| 416 | + def test_structure_and_normalization(self): |
| 417 | + fig = plot_parameter_forest(FOREST_PARAMS, _OPT_DICT, title="My Title") |
| 418 | + assert isinstance(fig, go.Figure) |
| 419 | + assert ( |
| 420 | + len(fig.data) == 4 |
| 421 | + ) # lines + lower ticks + upper ticks + optimal diamonds |
| 422 | + assert list(fig.data[3].x) == ["conductivity", "thickness", "temp_setpoint"] |
| 423 | + assert np.isclose( |
| 424 | + fig.data[3].y[0], 1 / 3, atol=1e-6 |
| 425 | + ) # conductivity: (0.04-0.03)/(0.06-0.03) |
| 426 | + assert all( |
| 427 | + np.isclose(v, 0.0) |
| 428 | + for v in plot_parameter_forest(FOREST_PARAMS, [0.03, 0.05, 18.0]).data[3].y |
| 429 | + ) |
| 430 | + assert all( |
| 431 | + np.isclose(v, 1.0) |
| 432 | + for v in plot_parameter_forest(FOREST_PARAMS, [0.06, 0.30, 24.0]).data[3].y |
| 433 | + ) |
| 434 | + assert fig.layout.title.text == "My Title" |
| 435 | + |
| 436 | + def test_input_types_and_choice_shown(self): |
| 437 | + assert isinstance( |
| 438 | + plot_parameter_forest(FOREST_PARAMS, [0.04, 0.12, 21.0]), go.Figure |
| 439 | + ) |
| 440 | + assert isinstance( |
| 441 | + plot_parameter_forest(FOREST_PARAMS, pd.Series(_OPT_DICT)), go.Figure |
| 442 | + ) |
| 443 | + params_with_choice = FOREST_PARAMS + [ |
| 444 | + Parameter( |
| 445 | + "algo", values=("A", "B", "C", "D"), ptype="Choice", model_property="d" |
| 446 | + ) |
| 447 | + ] |
| 448 | + opt_with_choice = {**_OPT_DICT, "algo": "A"} |
| 449 | + fig = plot_parameter_forest(params_with_choice, opt_with_choice) |
| 450 | + assert len(fig.data[-1].x) == 4 # 3 interval + 1 choice param all shown |
| 451 | + assert len(fig.data) == 5 # lines + lower + upper + choice_ticks + optimal |
| 452 | + |
| 453 | + def test_modes(self): |
| 454 | + # normalized: no text on any trace |
| 455 | + fig_norm = plot_parameter_forest(FOREST_PARAMS, _OPT_DICT, mode="normalized") |
| 456 | + assert fig_norm.data[3].mode == "markers" |
| 457 | + assert not any(t for t in (fig_norm.data[3].text or [])) |
| 458 | + |
| 459 | + # absolute: actual values annotated on lower, upper, and optimal traces |
| 460 | + fig_abs = plot_parameter_forest(FOREST_PARAMS, _OPT_DICT, mode="absolute") |
| 461 | + fig_abs.show() |
| 462 | + assert fig_abs.data[3].mode == "markers+text" |
| 463 | + assert fig_abs.data[1].text[0] == "0.03" # conductivity lower bound |
| 464 | + assert fig_abs.data[2].text[0] == "0.06" # conductivity upper bound |
| 465 | + assert fig_abs.data[3].text[0] == "0.04" # conductivity optimal |
| 466 | + |
| 467 | + # relative: Relative params shown as %, Absolute params fall back to actual values |
| 468 | + rel_params = [ |
| 469 | + Parameter( |
| 470 | + "mult", interval=(0.2, 1.5), relabs="Relative", model_property="x" |
| 471 | + ) |
| 472 | + ] |
| 473 | + fig_rel = plot_parameter_forest(rel_params, {"mult": 0.8}, mode="relative") |
| 474 | + assert fig_rel.data[1].text[0] == "20%" |
| 475 | + assert fig_rel.data[2].text[0] == "150%" |
| 476 | + assert fig_rel.data[3].text[0] == "80%" |
| 477 | + fig_abs_fallback = plot_parameter_forest( |
| 478 | + FOREST_PARAMS, _OPT_DICT, mode="relative" |
| 479 | + ) |
| 480 | + assert fig_abs_fallback.data[1].text[0] == "0.03" |
| 481 | + |
| 482 | + def test_layout_and_style(self): |
| 483 | + fig = plot_parameter_forest(FOREST_PARAMS, _OPT_DICT) |
| 484 | + assert fig.data[0].line.color == "darkblue" |
| 485 | + assert fig.layout.legend.y >= 0 # legend at top |
| 486 | + assert fig.layout.autosize is True |
| 487 | + assert fig.layout.width is None |
| 488 | + |
| 489 | + def test_errors(self): |
| 490 | + with pytest.raises(ValueError, match="mode must be one of"): |
| 491 | + plot_parameter_forest(FOREST_PARAMS, _OPT_DICT, mode="bad") |
| 492 | + with pytest.raises(ValueError, match="Missing optimal values"): |
| 493 | + plot_parameter_forest(FOREST_PARAMS, {"conductivity": 0.04}) |
| 494 | + binary_only = [Parameter("bin", ptype="Binary", model_property="d")] |
| 495 | + with pytest.raises(ValueError, match="No parameters with interval bounds"): |
| 496 | + plot_parameter_forest(binary_only, {"bin": True}) |
| 497 | + params_with_choice = FOREST_PARAMS + [ |
| 498 | + Parameter("algo", values=("A", "B"), ptype="Choice", model_property="d") |
| 499 | + ] |
| 500 | + with pytest.raises(ValueError, match="not among choices"): |
| 501 | + plot_parameter_forest(params_with_choice, {**_OPT_DICT, "algo": "C"}) |
| 502 | + |
| 503 | + def test_evaluator_method(self): |
| 504 | + ev = ModelEvaluator(FOREST_PARAMS, X2()) |
| 505 | + fig = ev.plot_parameter_forest(_OPT_DICT, mode="absolute") |
| 506 | + fig.show() |
| 507 | + assert isinstance(fig, go.Figure) |
| 508 | + assert fig.data[3].text[0] == "0.04" |
| 509 | + pymoo_ev = PymooModelEvaluator(FOREST_PARAMS, X2()) |
| 510 | + fig2 = pymoo_ev.plot_parameter_forest([0.04, 0.12, 21.0]) |
| 511 | + assert list(fig2.data[3].x) == ["conductivity", "thickness", "temp_setpoint"] |
0 commit comments