Skip to content

Commit 7ff26e5

Browse files
adrian-priorclaude
andcommitted
Add per-version AG class routing test
Verifies, per Philipp's review request, that AutoTabPFN selects the AutoGluon TabPFN model class corresponding to model_version: - V2 -> RealTabPFNv2Model - V2_5 -> RealTabPFNv25Model Stubs TabularPredictor so the test runs in <0.1s without a real fit; just captures the hyperparameters dict and asserts the keying class. Catches any future refactor that accidentally hardcodes one class for both versions. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 40def8d commit 7ff26e5

1 file changed

Lines changed: 56 additions & 0 deletions

File tree

tests/test_post_hoc_ensembles.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,3 +258,59 @@ def test_ignore_pretraining_limits_allows_large_dataset(
258258
# AssertionError: ag.max_rows=10000 but...
259259
with pytest.raises(RuntimeError):
260260
model_no_flag.fit(X, y)
261+
262+
@pytest.mark.parametrize(
263+
("model_version", "expected_class_name"),
264+
[
265+
(ModelVersion.V2, "RealTabPFNv2Model"),
266+
(ModelVersion.V2_5, "RealTabPFNv25Model"),
267+
],
268+
)
269+
def test_routes_to_per_version_autogluon_class(
270+
self,
271+
monkeypatch: pytest.MonkeyPatch,
272+
model_version: ModelVersion,
273+
expected_class_name: str,
274+
):
275+
"""``model_version`` should select the AutoGluon TabPFN model class
276+
whose per-version max_rows/max_features/max_classes limits match: V2
277+
-> RealTabPFNv2Model, V2_5 -> RealTabPFNv25Model. We stub
278+
``TabularPredictor`` so the test does not actually fit anything; it
279+
just captures the hyperparameters dict and asserts the keying class.
280+
"""
281+
captured: dict[str, object] = {}
282+
283+
class StubPredictor:
284+
def __init__(self, *args, **kwargs):
285+
pass
286+
287+
def fit(self, *args, **kwargs):
288+
captured["hyperparameters"] = kwargs.get("hyperparameters")
289+
return self
290+
291+
def features(self):
292+
return ["a", "b"]
293+
294+
# AutoTabPFN imports `TabularPredictor` lazily inside fit(), so we
295+
# patch the attribute on the source module that the import reads from.
296+
import autogluon.tabular
297+
298+
monkeypatch.setattr(autogluon.tabular, "TabularPredictor", StubPredictor)
299+
300+
X = pd.DataFrame(np.random.randn(40, 2), columns=["a", "b"])
301+
y = pd.Series([0, 1] * 20)
302+
303+
clf = AutoTabPFNClassifier(
304+
model_version=model_version,
305+
n_ensemble_models=1,
306+
max_time=1,
307+
)
308+
clf.fit(X, y)
309+
310+
hps = captured["hyperparameters"]
311+
assert (
312+
hps is not None
313+
), "TabularPredictor.fit was not called with hyperparameters"
314+
classes = list(hps.keys())
315+
assert len(classes) == 1
316+
assert classes[0].__name__ == expected_class_name

0 commit comments

Comments
 (0)