Skip to content

Commit 1733fde

Browse files
committed
✅ tests of plot_parameter_forest
1 parent 12d7730 commit 1733fde

1 file changed

Lines changed: 110 additions & 0 deletions

File tree

tests/test_optimize.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,16 @@
1818
RosenFiveParamDynamic,
1919
)
2020
from corrai.base.parameter import Parameter
21+
import plotly.graph_objects as go
22+
2123
from corrai.optimize import (
2224
MixedProblem,
2325
ModelEvaluator,
2426
PymooModelEvaluator,
2527
RealContinuousProblem,
2628
SciOptimizer,
2729
check_duplicate_params,
30+
plot_parameter_forest,
2831
)
2932

3033
PACKAGE_DIR = Path(__file__).parent / "TestLib"
@@ -399,3 +402,110 @@ def test_curve_fit_simple(self):
399402
)
400403

401404
assert np.isclose(popt[0], 2.0, atol=1e-2)
405+
406+
407+
FOREST_PARAMS = [
408+
Parameter("conductivity", interval=(0.03, 0.06), model_property="a"),
409+
Parameter("thickness", interval=(0.05, 0.30), model_property="b"),
410+
Parameter("temp_setpoint", interval=(18.0, 24.0), model_property="c"),
411+
]
412+
_OPT_DICT = {"conductivity": 0.04, "thickness": 0.12, "temp_setpoint": 21.0}
413+
414+
415+
class TestPlotParameterForest:
416+
def test_structure_and_normalization(self):
417+
fig = plot_parameter_forest(FOREST_PARAMS, _OPT_DICT, title="My Title")
418+
assert isinstance(fig, go.Figure)
419+
assert (
420+
len(fig.data) == 4
421+
) # lines + lower ticks + upper ticks + optimal diamonds
422+
assert list(fig.data[3].x) == ["conductivity", "thickness", "temp_setpoint"]
423+
assert np.isclose(
424+
fig.data[3].y[0], 1 / 3, atol=1e-6
425+
) # conductivity: (0.04-0.03)/(0.06-0.03)
426+
assert all(
427+
np.isclose(v, 0.0)
428+
for v in plot_parameter_forest(FOREST_PARAMS, [0.03, 0.05, 18.0]).data[3].y
429+
)
430+
assert all(
431+
np.isclose(v, 1.0)
432+
for v in plot_parameter_forest(FOREST_PARAMS, [0.06, 0.30, 24.0]).data[3].y
433+
)
434+
assert fig.layout.title.text == "My Title"
435+
436+
def test_input_types_and_choice_shown(self):
437+
assert isinstance(
438+
plot_parameter_forest(FOREST_PARAMS, [0.04, 0.12, 21.0]), go.Figure
439+
)
440+
assert isinstance(
441+
plot_parameter_forest(FOREST_PARAMS, pd.Series(_OPT_DICT)), go.Figure
442+
)
443+
params_with_choice = FOREST_PARAMS + [
444+
Parameter(
445+
"algo", values=("A", "B", "C", "D"), ptype="Choice", model_property="d"
446+
)
447+
]
448+
opt_with_choice = {**_OPT_DICT, "algo": "A"}
449+
fig = plot_parameter_forest(params_with_choice, opt_with_choice)
450+
assert len(fig.data[-1].x) == 4 # 3 interval + 1 choice param all shown
451+
assert len(fig.data) == 5 # lines + lower + upper + choice_ticks + optimal
452+
453+
def test_modes(self):
454+
# normalized: no text on any trace
455+
fig_norm = plot_parameter_forest(FOREST_PARAMS, _OPT_DICT, mode="normalized")
456+
assert fig_norm.data[3].mode == "markers"
457+
assert not any(t for t in (fig_norm.data[3].text or []))
458+
459+
# absolute: actual values annotated on lower, upper, and optimal traces
460+
fig_abs = plot_parameter_forest(FOREST_PARAMS, _OPT_DICT, mode="absolute")
461+
fig_abs.show()
462+
assert fig_abs.data[3].mode == "markers+text"
463+
assert fig_abs.data[1].text[0] == "0.03" # conductivity lower bound
464+
assert fig_abs.data[2].text[0] == "0.06" # conductivity upper bound
465+
assert fig_abs.data[3].text[0] == "0.04" # conductivity optimal
466+
467+
# relative: Relative params shown as %, Absolute params fall back to actual values
468+
rel_params = [
469+
Parameter(
470+
"mult", interval=(0.2, 1.5), relabs="Relative", model_property="x"
471+
)
472+
]
473+
fig_rel = plot_parameter_forest(rel_params, {"mult": 0.8}, mode="relative")
474+
assert fig_rel.data[1].text[0] == "20%"
475+
assert fig_rel.data[2].text[0] == "150%"
476+
assert fig_rel.data[3].text[0] == "80%"
477+
fig_abs_fallback = plot_parameter_forest(
478+
FOREST_PARAMS, _OPT_DICT, mode="relative"
479+
)
480+
assert fig_abs_fallback.data[1].text[0] == "0.03"
481+
482+
def test_layout_and_style(self):
483+
fig = plot_parameter_forest(FOREST_PARAMS, _OPT_DICT)
484+
assert fig.data[0].line.color == "darkblue"
485+
assert fig.layout.legend.y >= 0 # legend at top
486+
assert fig.layout.autosize is True
487+
assert fig.layout.width is None
488+
489+
def test_errors(self):
490+
with pytest.raises(ValueError, match="mode must be one of"):
491+
plot_parameter_forest(FOREST_PARAMS, _OPT_DICT, mode="bad")
492+
with pytest.raises(ValueError, match="Missing optimal values"):
493+
plot_parameter_forest(FOREST_PARAMS, {"conductivity": 0.04})
494+
binary_only = [Parameter("bin", ptype="Binary", model_property="d")]
495+
with pytest.raises(ValueError, match="No parameters with interval bounds"):
496+
plot_parameter_forest(binary_only, {"bin": True})
497+
params_with_choice = FOREST_PARAMS + [
498+
Parameter("algo", values=("A", "B"), ptype="Choice", model_property="d")
499+
]
500+
with pytest.raises(ValueError, match="not among choices"):
501+
plot_parameter_forest(params_with_choice, {**_OPT_DICT, "algo": "C"})
502+
503+
def test_evaluator_method(self):
504+
ev = ModelEvaluator(FOREST_PARAMS, X2())
505+
fig = ev.plot_parameter_forest(_OPT_DICT, mode="absolute")
506+
fig.show()
507+
assert isinstance(fig, go.Figure)
508+
assert fig.data[3].text[0] == "0.04"
509+
pymoo_ev = PymooModelEvaluator(FOREST_PARAMS, X2())
510+
fig2 = pymoo_ev.plot_parameter_forest([0.04, 0.12, 21.0])
511+
assert list(fig2.data[3].x) == ["conductivity", "thickness", "temp_setpoint"]

0 commit comments

Comments
 (0)