Skip to content

Commit e73b8a1

Browse files
committed
⚡️ Refactor SobolSanalysis to enforce calc_second_order consistency and improve error handling in sensitivity analysis tests
1 parent d914826 commit e73b8a1

2 files changed

Lines changed: 108 additions & 50 deletions

File tree

corrai/sensitivity.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def analyze(
176176
prefix=method,
177177
)
178178
else:
179-
if method or agg_method_kwarg:
179+
if agg_method_kwarg is not None or (method is not None and method != "mean"):
180180
warnings.warn(
181181
"'method' or 'agg_method_kwarg' was provided but Model is static."
182182
" Arguments will be ignored"
@@ -406,14 +406,37 @@ class SobolSanalysis(Sanalysis):
406406
sensitivity analysis following the Sobol method. Sampling of the parameter
407407
space is performed using the Saltelli scheme, which ensures efficient
408408
estimation of first-order, second-order, and total-order Sobol indices.
409+
410+
Parameters
411+
----------
412+
parameters : list of Parameter
413+
Parameters that define the sampling space.
414+
model : Model
415+
Model instance to be simulated.
416+
simulation_options : dict, optional
417+
Options passed to the model simulation.
418+
calc_second_order : bool, default=True
419+
Whether to compute second-order (interaction) Sobol indices.
420+
This value is fixed at construction and used consistently for
421+
both sampling and analysis. Setting it to False reduces the
422+
required sample size from N*(2D+2) to N*(D+2).
423+
424+
Notes
425+
-----
426+
Sobol analysis requires a single call to :meth:`add_sample`. Calling it
427+
more than once would concatenate independent Saltelli matrices, which
428+
invalidates the variance decomposition. A ``ValueError`` is raised on
429+
any subsequent call.
409430
"""
410431

411432
def __init__(
412433
self,
413434
parameters: list[Parameter],
414435
model: Model,
415436
simulation_options: dict = None,
437+
calc_second_order: bool = True,
416438
):
439+
self._calc_second_order = calc_second_order
417440
super().__init__(parameters, model, simulation_options)
418441

419442
def _set_sampler(
@@ -430,15 +453,19 @@ def add_sample(
430453
N: int,
431454
simulate: bool = True,
432455
n_cpu: int = 1,
433-
*,
434-
calc_second_order: bool = True,
435456
**sample_kwargs,
436457
):
458+
if len(self.sampler.sample) > 0:
459+
raise ValueError(
460+
"SobolSanalysis does not support incremental sampling. "
461+
f"The sample already contains {len(self.sampler.sample)} rows. "
462+
"Create a new SobolSanalysis instance to generate a new sample."
463+
)
437464
super().add_sample(
438465
N=N,
439466
simulate=simulate,
440467
n_cpu=n_cpu,
441-
calc_second_order=calc_second_order,
468+
calc_second_order=self._calc_second_order,
442469
**sample_kwargs,
443470
)
444471

@@ -449,7 +476,6 @@ def analyze(
449476
agg_method_kwarg: dict = None,
450477
reference_time_series: pd.Series = None,
451478
freq: str | pd.Timedelta | dt.timedelta = None,
452-
calc_second_order: bool = True,
453479
**analyse_kwargs,
454480
):
455481
return super().analyze(
@@ -458,7 +484,7 @@ def analyze(
458484
agg_method_kwarg=agg_method_kwarg,
459485
reference_time_series=reference_time_series,
460486
freq=freq,
461-
calc_second_order=calc_second_order,
487+
calc_second_order=self._calc_second_order,
462488
**analyse_kwargs,
463489
)
464490

@@ -468,7 +494,6 @@ def plot_bar(
468494
sensitivity_metric: str = "ST",
469495
method: str = "mean",
470496
reference_time_series: pd.Series = None,
471-
calc_second_order: bool = True,
472497
unit: str = "",
473498
agg_method_kwarg: dict = None,
474499
title: str = None,
@@ -485,7 +510,6 @@ def plot_bar(
485510
agg_method_kwarg=agg_method_kwarg,
486511
title=title,
487512
plot_kwargs=plot_kwargs,
488-
calc_second_order=calc_second_order,
489513
**analyse_kwargs,
490514
)
491515

@@ -498,7 +522,6 @@ def plot_dynamic_metric(
498522
reference_time_series: pd.Series = None,
499523
unit: str = "",
500524
agg_method_kwarg: dict = None,
501-
calc_second_order: bool = True,
502525
title: str = None,
503526
plot_kwargs: dict = None,
504527
):
@@ -511,7 +534,6 @@ def plot_dynamic_metric(
511534
unit=unit,
512535
agg_method_kwarg=agg_method_kwarg,
513536
reference_time_series=reference_time_series,
514-
calc_second_order=calc_second_order,
515537
stacked=True,
516538
title=title,
517539
plot_kwargs=plot_kwargs,
@@ -528,6 +550,11 @@ def plot_s2_matrix(
528550
plot_kwargs: dict = None,
529551
**analyse_kwargs,
530552
):
553+
if not self._calc_second_order:
554+
raise ValueError(
555+
"plot_s2_matrix() requires second-order indices. "
556+
"Set calc_second_order=True when creating SobolSanalysis."
557+
)
531558
return super().salib_plot_matrix(
532559
indicator=indicator,
533560
sensitivity_method_name="Sobol",

tests/test_sensitivity.py

Lines changed: 71 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import warnings
2+
13
import numpy as np
24
import pandas as pd
5+
import pytest
36

47
from corrai.base.parameter import Parameter
58
from corrai.base.model import IshigamiDynamic, Ishigami
@@ -29,55 +32,82 @@
2932

3033

3134
class TestSensitivity:
32-
def test_sanalysis_sobol_with_sobol_sampler(self):
33-
# sobol_analysis = SobolSanalysis(
34-
# parameters=PARAMETER_LIST,
35-
# model=IshigamiDynamic(),
36-
# simulation_options=SIMULATION_OPTIONS,
37-
# )
38-
#
39-
# sobol_analysis.add_sample(N=1000, n_cpu=1, calc_second_order=True, seed=42)
40-
# res = sobol_analysis.analyze("res", calc_second_order=True, seed=42)
41-
#
42-
# np.testing.assert_almost_equal(
43-
# res["mean_res"]["S1"],
44-
# np.array([0.33080399, 0.44206835, 0.00946747]),
45-
# )
46-
#
47-
# res = sobol_analysis.analyze("res", freq="h", calc_second_order=True, seed=42)
48-
# assert res.index.tolist() == [
49-
# pd.Timestamp("2009-01-01 00:00:00"),
50-
# pd.Timestamp("2009-01-01 01:00:00"),
51-
# pd.Timestamp("2009-01-01 02:00:00"),
52-
# pd.Timestamp("2009-01-01 03:00:00"),
53-
# pd.Timestamp("2009-01-01 04:00:00"),
54-
# pd.Timestamp("2009-01-01 05:00:00"),
55-
# ]
56-
#
57-
# sobol_analysis.plot_sample_hist(
58-
# "res", bins=10, reference_value=10, reference_label="ref"
59-
# )
60-
#
61-
# np.testing.assert_almost_equal(
62-
# res["2009-01-01 00:00:00"]["S1"],
63-
# np.array([0.33080399, 0.44206835, 0.00946747]),
64-
# decimal=3,
65-
# )
66-
#
35+
def test_sanalysis_sobol_static(self):
6736
sobol_analysis = SobolSanalysis(
6837
parameters=PARAMETER_LIST,
6938
model=Ishigami(),
39+
calc_second_order=True,
40+
)
41+
42+
sobol_analysis.add_sample(N=1000, n_cpu=1, seed=42)
43+
res = sobol_analysis.analyze("res", seed=42)
44+
45+
np.testing.assert_almost_equal(
46+
res["mean_res"]["S1"],
47+
np.array([0.33080399, 0.44206835, 0.00946747]),
48+
)
49+
50+
def test_sanalysis_sobol_dynamic(self):
51+
sobol_analysis = SobolSanalysis(
52+
parameters=PARAMETER_LIST,
53+
model=IshigamiDynamic(),
54+
simulation_options=SIMULATION_OPTIONS,
55+
calc_second_order=True,
7056
)
7157

72-
sobol_analysis.add_sample(N=1000, n_cpu=1, calc_second_order=True, seed=42)
73-
res = sobol_analysis.analyze("res", calc_second_order=True, seed=42)
58+
sobol_analysis.add_sample(N=1000, n_cpu=1, seed=42)
59+
res = sobol_analysis.analyze("res", seed=42)
7460

7561
np.testing.assert_almost_equal(
7662
res["mean_res"]["S1"],
7763
np.array([0.33080399, 0.44206835, 0.00946747]),
7864
)
7965

80-
assert True
66+
res_freq = sobol_analysis.analyze("res", freq="h", seed=42)
67+
assert res_freq.index.tolist() == [
68+
pd.Timestamp("2009-01-01 00:00:00"),
69+
pd.Timestamp("2009-01-01 01:00:00"),
70+
pd.Timestamp("2009-01-01 02:00:00"),
71+
pd.Timestamp("2009-01-01 03:00:00"),
72+
pd.Timestamp("2009-01-01 04:00:00"),
73+
pd.Timestamp("2009-01-01 05:00:00"),
74+
]
75+
np.testing.assert_almost_equal(
76+
res_freq["2009-01-01 00:00:00"]["S1"],
77+
np.array([0.33080399, 0.44206835, 0.00946747]),
78+
decimal=3,
79+
)
80+
81+
def test_sobol_incremental_sampling_raises(self):
82+
sobol_analysis = SobolSanalysis(
83+
parameters=PARAMETER_LIST,
84+
model=Ishigami(),
85+
)
86+
sobol_analysis.add_sample(N=64, n_cpu=1, seed=42)
87+
with pytest.raises(ValueError, match="does not support incremental sampling"):
88+
sobol_analysis.add_sample(N=64, n_cpu=1, seed=0)
89+
90+
def test_sobol_plot_s2_matrix_raises_when_no_second_order(self):
91+
sobol_analysis = SobolSanalysis(
92+
parameters=PARAMETER_LIST,
93+
model=Ishigami(),
94+
calc_second_order=False,
95+
)
96+
sobol_analysis.add_sample(N=64, n_cpu=1, seed=42)
97+
with pytest.raises(ValueError, match="requires second-order indices"):
98+
sobol_analysis.plot_s2_matrix()
99+
100+
def test_sobol_static_model_no_spurious_warning(self):
101+
sobol_analysis = SobolSanalysis(
102+
parameters=PARAMETER_LIST,
103+
model=Ishigami(),
104+
)
105+
sobol_analysis.add_sample(N=64, n_cpu=1, seed=42)
106+
with warnings.catch_warnings(record=True) as caught:
107+
warnings.simplefilter("always")
108+
sobol_analysis.analyze("res", seed=42)
109+
our_warnings = [w for w in caught if "sensitivity" in str(w.filename)]
110+
assert len(our_warnings) == 0, f"Unexpected warnings from sensitivity.py: {our_warnings}"
81111

82112
def test_sanalysis_morris(self):
83113
morris_analysis = MorrisSanalysis(
@@ -202,8 +232,9 @@ def test_sobol_s2_matrix(self):
202232
parameters=PARAMETER_LIST,
203233
model=IshigamiDynamic(),
204234
simulation_options=SIMULATION_OPTIONS,
235+
calc_second_order=True,
205236
)
206-
sobol_analysis.add_sample(N=2**2, n_cpu=1, calc_second_order=True)
237+
sobol_analysis.add_sample(N=2**2, n_cpu=1)
207238
fig_matrix = sobol_analysis.plot_s2_matrix()
208239
assert fig_matrix["layout"]["title"]["text"] == (
209240
"Sobol mean res " "- 2nd order interactions"
@@ -321,7 +352,7 @@ def test_sobol_plot_bar_plot_kwargs(self):
321352
parameters=PARAMETER_LIST,
322353
model=Ishigami(),
323354
)
324-
sobol_analysis.add_sample(N=2**4, n_cpu=1, calc_second_order=True, seed=42)
355+
sobol_analysis.add_sample(N=2**4, n_cpu=1, seed=42)
325356
fig = sobol_analysis.plot_bar(
326357
plot_kwargs={
327358
"title": "My Custom Title",

0 commit comments

Comments
 (0)