Skip to content

Commit b29e088

Browse files
authored
[MNT] isolate captum soft dependency in tests (#613)
isolates the `captum` soft dependency in tests
1 parent 52968eb commit b29e088

1 file changed

Lines changed: 9 additions & 0 deletions

File tree

tests/test_common.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytest
99
import torch
1010
from scipy.stats import uniform
11+
from skbase.utils.dependencies import _check_soft_dependencies
1112
from sklearn.metrics import accuracy_score, r2_score
1213
from sklearn.model_selection import KFold
1314

@@ -542,6 +543,10 @@ def _test_captum(
542543
assert exp.shape[1] == tabular_model.model.hparams.continuous_dim + tabular_model.model.hparams.categorical_dim
543544

544545

546+
@pytest.mark.skipif(
547+
not _check_soft_dependencies("captum", severity="none"),
548+
reason="skip captum integration test if captum is not installed",
549+
)
545550
@pytest.mark.parametrize("model_config_class", MODEL_CONFIG_CAPTUM_TEST)
546551
@pytest.mark.parametrize("continuous_cols", [list(DATASET_CONTINUOUS_COLUMNS)])
547552
@pytest.mark.parametrize("categorical_cols", [["HouseAgeBin"], []])
@@ -592,6 +597,10 @@ def test_captum_integration_regression(
592597
)
593598

594599

600+
@pytest.mark.skipif(
601+
not _check_soft_dependencies("captum", severity="none"),
602+
reason="skip captum integration test if captum is not installed",
603+
)
595604
@pytest.mark.parametrize("model_config_class", MODEL_CONFIG_CAPTUM_TEST)
596605
@pytest.mark.parametrize(
597606
"continuous_cols",

0 commit comments

Comments
 (0)