Skip to content

Commit f733dd0

Browse files
authored
[ENH] add conditional test skips to estimator specific tests (#607)
adds conditional test skips to estimator specific tests to manage test time, see #603. Uses `scikit-base` to check for changes in estimators.
1 parent d2d21da commit f733dd0

10 files changed

Lines changed: 77 additions & 0 deletions

.github/workflows/testing.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ jobs:
2828
steps:
2929
- uses: actions/checkout@v6
3030

31+
- run: git remote set-branches origin 'main'
32+
33+
- run: git fetch --depth 1
34+
3135
- name: Install uv
3236
uses: astral-sh/setup-uv@v7
3337
with:

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ dependencies = [
5656
"einops>=0.6.0,<0.8.0",
5757
"fsspec>=2022.5.0,<2024.4.0; python_version == '3.8'",
5858
"rich",
59+
"scikit-base",
5960
]
6061

6162

tests/test_autoint.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
#!/usr/bin/env python
22
"""Tests for `pytorch_tabular` package."""
33

4+
from skbase.utils.git_diff import _is_module_changed
45
import pytest
56

67
from pytorch_tabular import TabularModel
78
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
89
from pytorch_tabular.models import AutoIntConfig
910

1011

12+
@pytest.mark.skipif(
13+
not _is_module_changed("pytorch_tabular.models.autoint"),
14+
reason="run test only if autoint module is changed",
15+
)
1116
@pytest.mark.parametrize("multi_target", [True, False])
1217
@pytest.mark.parametrize(
1318
"continuous_cols", [["AveRooms", "AveBedrms", "Population", "AveOccup", "Latitude", "Longitude"]]
@@ -78,6 +83,10 @@ def test_regression(
7883
assert pred_df.shape[0] == test.shape[0]
7984

8085

86+
@pytest.mark.skipif(
87+
not _is_module_changed("pytorch_tabular.models.autoint"),
88+
reason="run test only if autoint module is changed",
89+
)
8190
@pytest.mark.parametrize("multi_target", [False, True])
8291
@pytest.mark.parametrize(
8392
"continuous_cols",

tests/test_danet.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python
22
"""Tests for `pytorch_tabular` package."""
33

4+
from skbase.utils.git_diff import _is_module_changed
45
import pytest
56

67
from pytorch_tabular import TabularModel
@@ -10,6 +11,10 @@
1011
# from pytorch_tabular.categorical_encoders import CategoricalEmbeddingTransformer
1112

1213

14+
@pytest.mark.skipif(
15+
not _is_module_changed("pytorch_tabular.models.danet"),
16+
reason="run test only if danet module is changed",
17+
)
1318
@pytest.mark.parametrize("multi_target", [True, False])
1419
@pytest.mark.parametrize(
1520
"continuous_cols",
@@ -80,6 +85,10 @@ def test_regression(
8085
assert pred_df.shape[0] == test.shape[0]
8186

8287

88+
@pytest.mark.skipif(
89+
not _is_module_changed("pytorch_tabular.models.danet"),
90+
reason="run test only if danet module is changed",
91+
)
8392
@pytest.mark.parametrize("multi_target", [False, True])
8493
@pytest.mark.parametrize(
8594
"continuous_cols",

tests/test_ft_transformer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python
22
"""Tests for `pytorch_tabular` package."""
33

4+
from skbase.utils.git_diff import _is_module_changed
45
import pytest
56

67
from pytorch_tabular import TabularModel
@@ -9,6 +10,10 @@
910
from pytorch_tabular.models import FTTransformerConfig
1011

1112

13+
@pytest.mark.skipif(
14+
not _is_module_changed("pytorch_tabular.models.ft_transformer"),
15+
reason="run test only if ft_transformer module is changed",
16+
)
1217
@pytest.mark.parametrize("multi_target", [True, False])
1318
@pytest.mark.parametrize(
1419
"continuous_cols",
@@ -86,6 +91,10 @@ def test_regression(
8691
assert pred_df.shape[0] == test.shape[0]
8792

8893

94+
@pytest.mark.skipif(
95+
not _is_module_changed("pytorch_tabular.models.ft_transformer"),
96+
reason="run test only if ft_transformer module is changed",
97+
)
8998
@pytest.mark.parametrize("multi_target", [False, True])
9099
@pytest.mark.parametrize(
91100
"continuous_cols",

tests/test_gandalf.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python
22
"""Tests for `pytorch_tabular` package."""
33

4+
from skbase.utils.git_diff import _is_module_changed
45
import pytest
56

67
from pytorch_tabular import TabularModel
@@ -10,6 +11,10 @@
1011
# from pytorch_tabular.categorical_encoders import CategoricalEmbeddingTransformer
1112

1213

14+
@pytest.mark.skipif(
15+
not _is_module_changed("pytorch_tabular.models.gandalf"),
16+
reason="run test only if gandalf module is changed",
17+
)
1318
@pytest.mark.parametrize("multi_target", [True, False])
1419
@pytest.mark.parametrize(
1520
"continuous_cols",
@@ -80,6 +85,10 @@ def test_regression(
8085
assert pred_df.shape[0] == test.shape[0]
8186

8287

88+
@pytest.mark.skipif(
89+
not _is_module_changed("pytorch_tabular.models.gandalf"),
90+
reason="run test only if gandalf module is changed",
91+
)
8392
@pytest.mark.parametrize("multi_target", [False, True])
8493
@pytest.mark.parametrize(
8594
"continuous_cols",

tests/test_gate.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python
22
"""Tests for `pytorch_tabular` package."""
33

4+
from skbase.utils.git_diff import _is_module_changed
45
import pytest
56

67
from pytorch_tabular import TabularModel
@@ -10,6 +11,10 @@
1011
# from pytorch_tabular.categorical_encoders import CategoricalEmbeddingTransformer
1112

1213

14+
@pytest.mark.skipif(
15+
not _is_module_changed("pytorch_tabular.models.gate"),
16+
reason="run test only if gate module is changed",
17+
)
1318
@pytest.mark.parametrize("multi_target", [True, False])
1419
@pytest.mark.parametrize(
1520
"continuous_cols",
@@ -85,6 +90,10 @@ def test_regression(
8590
assert pred_df.shape[0] == test.shape[0]
8691

8792

93+
@pytest.mark.skipif(
94+
not _is_module_changed("pytorch_tabular.models.gate"),
95+
reason="run test only if gate module is changed",
96+
)
8897
@pytest.mark.parametrize("multi_target", [False, True])
8998
@pytest.mark.parametrize(
9099
"continuous_cols",

tests/test_model_stacking.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
from skbase.utils.git_diff import _is_module_changed
23
import pytest
34
import torch
45
from sklearn.preprocessing import PowerTransformer
@@ -50,6 +51,10 @@ def get_model_configs(task):
5051
return [model_config(task) for model_config in all_model_configs]
5152

5253

54+
@pytest.mark.skipif(
55+
not _is_module_changed("pytorch_tabular.models.stacking"),
56+
reason="run test only if stacking module is changed",
57+
)
5358
@pytest.mark.parametrize("multi_target", [True, False])
5459
@pytest.mark.parametrize(
5560
"continuous_cols",
@@ -163,6 +168,10 @@ def test_regression(
163168
assert pred_df.shape[0] == test.shape[0]
164169

165170

171+
@pytest.mark.skipif(
172+
not _is_module_changed("pytorch_tabular.models.stacking"),
173+
reason="run test only if stacking module is changed",
174+
)
166175
@pytest.mark.parametrize("multi_target", [False, True])
167176
@pytest.mark.parametrize(
168177
"continuous_cols",

tests/test_tabnet.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
#!/usr/bin/env python
22
"""Tests for `pytorch_tabular` package."""
33

4+
from skbase.utils.git_diff import _is_module_changed
45
import pytest
56

67
from pytorch_tabular import TabularModel
78
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
89
from pytorch_tabular.models import TabNetModelConfig
910

1011

12+
@pytest.mark.skipif(
13+
not _is_module_changed("pytorch_tabular.models.tabnet"),
14+
reason="run test only if tabnet module is changed",
15+
)
1116
@pytest.mark.parametrize("multi_target", [True, False])
1217
@pytest.mark.parametrize(
1318
"continuous_cols",
@@ -78,6 +83,10 @@ def test_regression(
7883
assert pred_df.shape[0] == test.shape[0]
7984

8085

86+
@pytest.mark.skipif(
87+
not _is_module_changed("pytorch_tabular.models.tabnet"),
88+
reason="run test only if tabnet module is changed",
89+
)
8190
@pytest.mark.parametrize("multi_target", [False, True])
8291
@pytest.mark.parametrize(
8392
"continuous_cols",

tests/test_tabtransformer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python
22
"""Tests for `pytorch_tabular` package."""
33

4+
from skbase.utils.git_diff import _is_module_changed
45
import pytest
56

67
from pytorch_tabular import TabularModel
@@ -9,6 +10,10 @@
910
from pytorch_tabular.models import TabTransformerConfig
1011

1112

13+
@pytest.mark.skipif(
14+
not _is_module_changed("pytorch_tabular.models.tab_transformer"),
15+
reason="run test only if tab_transformer module is changed",
16+
)
1217
@pytest.mark.parametrize("multi_target", [True, False])
1318
@pytest.mark.parametrize(
1419
"continuous_cols",
@@ -84,6 +89,10 @@ def test_regression(
8489
assert pred_df.shape[0] == test.shape[0]
8590

8691

92+
@pytest.mark.skipif(
93+
not _is_module_changed("pytorch_tabular.models.tab_transformer"),
94+
reason="run test only if tab_transformer module is changed",
95+
)
8796
@pytest.mark.parametrize("multi_target", [False, True])
8897
@pytest.mark.parametrize(
8998
"continuous_cols",

0 commit comments

Comments
 (0)