Skip to content

Commit 89c6a34

Browse files
authored
Bump AutoGluon to 1.5 and route per-version to RealTabPFNv2/v25 classes (#274)
AutoGluon 1.5 ships dedicated `RealTabPFNv2Model` and `RealTabPFNv25Model` classes with per-version `max_rows` / `max_features` / `max_classes` baked in. Previously AutoTabPFN always used the single `TabPFNV2Model` class (v2-era 10k / 500 / 10 limits) and worked around that with a `ag_args_fit = {max_rows: None, ...}` neutralization (PR #272), delegating gating to TabPFN's own inference_config. We can now select the right AG class per `model_version` and drop the neutralization entirely.
1 parent a69f681 commit 89c6a34

20 files changed

Lines changed: 351 additions & 1770 deletions

File tree

.github/workflows/pull_request.yml

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,13 @@ jobs:
3434
matrix:
3535
include:
3636
- os: ubuntu-latest
37-
python-version: "3.9"
37+
python-version: "3.10"
3838
dependency-set: minimum
39-
- os: macos-15-intel # We need x86 as ARM is python>= 3.11 only.
40-
# https://github.com/actions/setup-python/issues/855
41-
python-version: "3.9"
39+
- os: macos-latest
40+
python-version: "3.10"
4241
dependency-set: minimum
4342
- os: windows-latest
44-
python-version: "3.9"
43+
python-version: "3.10"
4544
dependency-set: minimum
4645
- os: ubuntu-latest
4746
python-version: "3.13"
@@ -64,7 +63,6 @@ jobs:
6463
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
6564
with:
6665
python-version: ${{ matrix.python-version }}
67-
architecture: x64
6866

6967
- name: Install uv
7068
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0

pyproject.toml

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ dependencies = [
1010
"pandas>=1.4.0",
1111
"scikit-learn>=1.6.0",
1212
"scipy>=1.11.1",
13-
"tabpfn>=6.0.5",
14-
"tabpfn-common-utils[telemetry-interactive]>=0.2.0",
13+
"tabpfn>=6.3.2",
14+
"tabpfn-common-utils[telemetry-interactive]>=0.2.13",
1515
]
1616

17-
requires-python = ">=3.9"
17+
requires-python = ">=3.10"
1818
authors = [
1919
{ name = "Noah Hollmann", email = "noah.hollmann@charite.de" },
2020
{ name = "Leo Grinsztajn" },
@@ -43,7 +43,6 @@ classifiers = [
4343
'Operating System :: Unix',
4444
'Operating System :: MacOS',
4545
'Programming Language :: Python :: 3',
46-
'Programming Language :: Python :: 3.9',
4746
'Programming Language :: Python :: 3.10',
4847
'Programming Language :: Python :: 3.11',
4948
'Programming Language :: Python :: 3.12',
@@ -64,7 +63,7 @@ interpretability = [
6463
post_hoc_ensembles = [
6564
"llvmlite",
6665
"hyperopt>=0.2.7",
67-
"autogluon.tabular==1.4.0"
66+
"autogluon.tabular>=1.5.0,<1.6"
6867
]
6968
hpo = [
7069
"hyperopt>=0.2.7",
@@ -89,7 +88,7 @@ all = [
8988
"hyperopt>=0.2.7",
9089
# https://discuss.python.org/t/pkg-resources-removal-how-to-go-from-there/106079
9190
"setuptools>=67.0.0,<82",
92-
"autogluon.tabular==1.4.0",
91+
"autogluon.tabular>=1.5.0,<1.6",
9392
# scikit-survival not included; install manually if you need SurvivalTabPFN.
9493
]
9594

@@ -117,7 +116,7 @@ addopts = "--durations=10 -vv"
117116

118117
# https://github.com/astral-sh/ruff
119118
[tool.ruff]
120-
target-version = "py39"
119+
target-version = "py310"
121120
line-length = 88
122121
output-format = "full"
123122
src = ["src", "tests", "examples"]
@@ -329,7 +328,7 @@ convention = "google"
329328
max-args = 10 # Changed from default of 5
330329

331330
[tool.mypy]
332-
python_version = "3.9"
331+
python_version = "3.10"
333332
packages = ["src/tabpfn_extensions", "tests"]
334333

335334
show_error_codes = true
@@ -368,7 +367,7 @@ ignore_missing_imports = true
368367
[tool.pyright]
369368
include = ["src", "tests"]
370369

371-
pythonVersion = "3.9"
370+
pythonVersion = "3.10"
372371
typeCheckingMode = "strict"
373372

374373
strictListInference = true

src/tabpfn_extensions/hpo/search_space.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -143,15 +143,24 @@ def get_param_grid_hyperopt(
143143
if model_dir is None:
144144
model_dir = get_cache_dir()
145145

146-
if task_type == "multiclass" and model_version == ModelVersion.V2:
147-
model_source = ModelSource.get_classifier_v2()
148-
elif task_type == "multiclass" and model_version == ModelVersion.V2_5:
149-
model_source = ModelSource.get_classifier_v2_5()
150-
elif task_type == "regression":
151-
if model_version == ModelVersion.V2:
152-
model_source = ModelSource.get_regressor_v2()
153-
elif model_version == ModelVersion.V2_5:
154-
model_source = ModelSource.get_regressor_v2_5()
146+
# Resolve the model source for the (task_type, model_version) pair. Any
147+
# combination we don't explicitly support below raises before we try to
148+
# use `model_source`.
149+
model_source_lookup = {
150+
("multiclass", ModelVersion.V2): ModelSource.get_classifier_v2,
151+
("multiclass", ModelVersion.V2_5): ModelSource.get_classifier_v2_5,
152+
("regression", ModelVersion.V2): ModelSource.get_regressor_v2,
153+
("regression", ModelVersion.V2_5): ModelSource.get_regressor_v2_5,
154+
}
155+
try:
156+
model_source = model_source_lookup[(task_type, model_version)]()
157+
except KeyError as err:
158+
raise NotImplementedError(
159+
f"No hpo search space is defined for task type {task_type!r} and "
160+
f"model version {model_version!r}."
161+
) from err
162+
163+
if task_type == "regression":
155164
search_space["inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS"] = hp.choice(
156165
"REGRESSION_Y_PREPROCESS_TRANSFORMS",
157166
[
@@ -162,12 +171,6 @@ def get_param_grid_hyperopt(
162171
],
163172
)
164173

165-
else:
166-
raise NotImplementedError(
167-
f"No hpo search space is defined for task type {task_type} and "
168-
f"model version {model_version}."
169-
)
170-
171174
# Make sure models are downloaded.
172175
if download_models_if_missing:
173176
for ckpt_name in model_source.filenames:

src/tabpfn_extensions/hpo/tuned_tabpfn.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,9 @@
3939
from __future__ import annotations
4040

4141
import logging
42+
from collections.abc import Callable
4243
from enum import Enum
43-
from typing import Any, Callable
44+
from typing import Any
4445

4546
import numpy as np
4647
import torch
@@ -183,7 +184,7 @@ def _optimize(self, X: np.ndarray, y: np.ndarray, task_type: str):
183184
for k, v_item in self.search_space.items():
184185
if isinstance(v_item, list):
185186
custom_space[k] = hp.choice(k, v_item)
186-
elif isinstance(v_item, (int, float, bool, str)) or v_item is None:
187+
elif isinstance(v_item, int | float | bool | str) or v_item is None:
187188
custom_space[k] = v_item
188189
else:
189190
custom_space[k] = v_item

src/tabpfn_extensions/interpretability/shap.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@
3333

3434
from __future__ import annotations
3535

36+
from collections.abc import Callable
3637
from multiprocessing import Pool
37-
from typing import Any, Callable
38+
from typing import Any
3839

3940
import matplotlib.pyplot as plt
4041
import numpy as np

src/tabpfn_extensions/interpretability/shapiq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ class index will be set to 1 per default for classification models. This argumen
7878
data = data.values
7979

8080
# make labels to array if it is a pandas Series
81-
if isinstance(labels, (pd.Series, pd.DataFrame)):
81+
if isinstance(labels, pd.Series | pd.DataFrame):
8282
labels = labels.values
8383

8484
# TabPFNExplainer is directly available in the shapiq module

src/tabpfn_extensions/many_class/_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import sys
43
from collections.abc import Iterable, Mapping
54
from dataclasses import dataclass
65
from typing import TYPE_CHECKING, Any
@@ -15,8 +14,7 @@
1514

1615
def _dataclass_kwargs() -> dict[str, Any]:
1716
kwargs: dict[str, Any] = {}
18-
if sys.version_info >= (3, 10): # pragma: no cover - environment dependent
19-
kwargs["slots"] = True
17+
kwargs["slots"] = True
2018
return kwargs
2119

2220

src/tabpfn_extensions/misc/sklearn_compat.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
import platform
1919
import sys
2020
import types
21+
from collections.abc import Callable
2122
from dataclasses import dataclass, field
22-
from typing import Callable, Literal
23+
from typing import Literal
2324

2425
import sklearn
2526
from sklearn.utils.fixes import parse_version
@@ -34,8 +35,6 @@
3435

3536
# tags infrastructure
3637
def _dataclass_args():
37-
if sys.version_info < (3, 10):
38-
return {}
3938
return {"slots": True}
4039

4140

@@ -196,7 +195,7 @@ def _is_fitted(estimator, attributes=None, all_or_any=all):
196195
Whether the estimator is fitted.
197196
"""
198197
if attributes is not None:
199-
if not isinstance(attributes, (list, tuple)):
198+
if not isinstance(attributes, list | tuple):
200199
attributes = [attributes]
201200
return all_or_any([hasattr(estimator, attr) for attr in attributes])
202201

src/tabpfn_extensions/post_hoc_ensembles/sklearn_interface.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import datetime
1212
from enum import Enum
13+
from pathlib import Path
1314
from typing import Any, Literal
1415

1516
import numpy as np
@@ -198,10 +199,25 @@ def fit(self, X: np.ndarray | pd.DataFrame, y: np.ndarray | pd.Series):
198199
This method should be called from the child class's fit method after validation.
199200
"""
200201
from autogluon.tabular import TabularPredictor
201-
from autogluon.tabular.models import TabPFNV2Model
202+
from autogluon.tabular.models import RealTabPFNv2Model, RealTabPFNv25Model
202203

203204
from tabpfn_extensions.post_hoc_ensembles.utils import search_space_func
204205

206+
# Route to the AutoGluon TabPFN model class that matches the requested
207+
# TabPFN model version. Each class ships with the correct per-version
208+
# max_rows/max_features/max_classes limits, so we no longer need to
209+
# override them via ag_args_fit.
210+
if self.model_version == ModelVersion.V2:
211+
ag_model_class = RealTabPFNv2Model
212+
elif self.model_version == ModelVersion.V2_5:
213+
ag_model_class = RealTabPFNv25Model
214+
else:
215+
raise NotImplementedError(
216+
f"AutoTabPFN does not support TabPFN model version "
217+
f"{self.model_version.value!r} yet. Supported versions: "
218+
f"{ModelVersion.V2.value!r}, {ModelVersion.V2_5.value!r}.",
219+
)
220+
205221
if isinstance(X, pd.DataFrame):
206222
training_df = X.copy()
207223
self._column_names = X.columns.tolist()
@@ -247,30 +263,33 @@ def fit(self, X: np.ndarray | pd.DataFrame, y: np.ndarray | pd.Series):
247263
**self.get_task_args_(),
248264
}
249265

250-
def _patch_ag_args_fit_inplace(config: dict[str, Any]) -> None:
251-
"""Patch AutoGluon's per-model params_aux for TabPFN sub-models.
252-
253-
- Disable AutoGluon's static max_rows / max_features / max_classes
254-
asserts so TabPFN's own per-checkpoint validation is the single
255-
authority. TODO: Fix upstream in AutoGluon's `TabPFNV2Model`. A
256-
single class handles all v2.x checkpoints with v2-era limits
257-
hardcoded in `_get_default_auxiliary_params`, which is wrong for
258-
v2.5+. Until that lands, we override per sub-model here.
259-
- Forward the user's `ignore_pretraining_limits` flag to TabPFN.
266+
def _adapt_config_for_autogluon_inplace(config: dict[str, Any]) -> None:
267+
"""Forward `ignore_pretraining_limits` and translate `model_path`.
268+
269+
- AutoGluon 1.5 expects the checkpoint to be passed as
270+
``zip_model_path=[classification_ckpt, regression_ckpt]`` (just
271+
the filenames; AG joins with its own resolved cache dir). The
272+
search space produces a TabPFN-compatible ``model_path=<abs path>``
273+
so that it can also be consumed directly by core TabPFN; here we
274+
translate it to the form AG expects and drop the original key.
275+
- Forward the user's ``ignore_pretraining_limits`` flag into AG's
276+
``ag_args_fit.ignore_constraints``.
260277
"""
261278
ag_args_fit = config.setdefault("ag_args_fit", {})
262-
ag_args_fit["max_rows"] = None
263-
ag_args_fit["max_features"] = None
264-
ag_args_fit["max_classes"] = None
265279
ag_args_fit["ignore_constraints"] = self.ignore_pretraining_limits
266280

281+
full_path = config.pop("model_path", None)
282+
if full_path is not None:
283+
ckpt_name = Path(full_path).name
284+
config["zip_model_path"] = [ckpt_name, ckpt_name]
285+
267286
if isinstance(tabpfn_configs, list):
268287
for cfg in tabpfn_configs:
269-
_patch_ag_args_fit_inplace(cfg)
288+
_adapt_config_for_autogluon_inplace(cfg)
270289
else:
271-
_patch_ag_args_fit_inplace(tabpfn_configs)
290+
_adapt_config_for_autogluon_inplace(tabpfn_configs)
272291

273-
hyperparameters = {TabPFNV2Model: tabpfn_configs}
292+
hyperparameters = {ag_model_class: tabpfn_configs}
274293
if isinstance(self.presets, str) and self.presets == "extreme_quality":
275294
raise ValueError(
276295
"Extreme quality preset is not supported at the moment, as it does not "

src/tabpfn_extensions/post_hoc_ensembles/utils.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,10 @@ def prepare_tabpfnv2_config(
6767
else:
6868
config.pop("balance_probabilities", None)
6969

70-
# TODO: Enable RF-PFN at some point
71-
config["model_type"] = "single"
72-
73-
# Special case for dt_pfn
74-
# TODO: This code is unused until we support RF-PFN
75-
if config.get("model_type") == "dt_pfn":
76-
config["n_ensemble_repeats"] = config["n_estimators"]
77-
config["n_estimators"] = 1
78-
79-
# Remove deprecated keys
70+
# Remove search-space-only keys that are not valid TabPFNClassifier kwargs.
71+
# `model_type` is a placeholder for future RF-PFN routing (currently unused),
72+
# and `max_depth` was for the dt_pfn variant.
73+
config.pop("model_type", None)
8074
config.pop("max_depth", None)
8175

8276
return config

0 commit comments

Comments
 (0)