|
10 | 10 |
|
11 | 11 | import datetime |
12 | 12 | from enum import Enum |
| 13 | +from pathlib import Path |
13 | 14 | from typing import Any, Literal |
14 | 15 |
|
15 | 16 | import numpy as np |
@@ -198,10 +199,25 @@ def fit(self, X: np.ndarray | pd.DataFrame, y: np.ndarray | pd.Series): |
198 | 199 | This method should be called from the child class's fit method after validation. |
199 | 200 | """ |
200 | 201 | from autogluon.tabular import TabularPredictor |
201 | | - from autogluon.tabular.models import TabPFNV2Model |
| 202 | + from autogluon.tabular.models import RealTabPFNv2Model, RealTabPFNv25Model |
202 | 203 |
|
203 | 204 | from tabpfn_extensions.post_hoc_ensembles.utils import search_space_func |
204 | 205 |
|
| 206 | + # Route to the AutoGluon TabPFN model class that matches the requested |
| 207 | + # TabPFN model version. Each class ships with the correct per-version |
| 208 | + # max_rows/max_features/max_classes limits, so we no longer need to |
| 209 | + # override them via ag_args_fit. |
| 210 | + if self.model_version == ModelVersion.V2: |
| 211 | + ag_model_class = RealTabPFNv2Model |
| 212 | + elif self.model_version == ModelVersion.V2_5: |
| 213 | + ag_model_class = RealTabPFNv25Model |
| 214 | + else: |
| 215 | + raise NotImplementedError( |
| 216 | + f"AutoTabPFN does not support TabPFN model version " |
| 217 | + f"{self.model_version.value!r} yet. Supported versions: " |
| 218 | + f"{ModelVersion.V2.value!r}, {ModelVersion.V2_5.value!r}.", |
| 219 | + ) |
| 220 | + |
205 | 221 | if isinstance(X, pd.DataFrame): |
206 | 222 | training_df = X.copy() |
207 | 223 | self._column_names = X.columns.tolist() |
@@ -247,30 +263,33 @@ def fit(self, X: np.ndarray | pd.DataFrame, y: np.ndarray | pd.Series): |
247 | 263 | **self.get_task_args_(), |
248 | 264 | } |
249 | 265 |
|
250 | | - def _patch_ag_args_fit_inplace(config: dict[str, Any]) -> None: |
251 | | - """Patch AutoGluon's per-model params_aux for TabPFN sub-models. |
252 | | -
|
253 | | - - Disable AutoGluon's static max_rows / max_features / max_classes |
254 | | - asserts so TabPFN's own per-checkpoint validation is the single |
255 | | - authority. TODO: Fix upstream in AutoGluon's `TabPFNV2Model`. A |
256 | | - single class handles all v2.x checkpoints with v2-era limits |
257 | | - hardcoded in `_get_default_auxiliary_params`, which is wrong for |
258 | | - v2.5+. Until that lands, we override per sub-model here. |
259 | | - - Forward the user's `ignore_pretraining_limits` flag to TabPFN. |
| 266 | + def _adapt_config_for_autogluon_inplace(config: dict[str, Any]) -> None: |
| 267 | + """Forward `ignore_pretraining_limits` and translate `model_path`. |
| 268 | +
|
| 269 | + - AutoGluon 1.5 expects the checkpoint to be passed as |
| 270 | + ``zip_model_path=[classification_ckpt, regression_ckpt]`` (just |
| 271 | + the filenames; AG joins with its own resolved cache dir). The |
| 272 | + search space produces a TabPFN-compatible ``model_path=<abs path>`` |
| 273 | + so that it can also be consumed directly by core TabPFN; here we |
| 274 | + translate it to the form AG expects and drop the original key. |
| 275 | + - Forward the user's ``ignore_pretraining_limits`` flag into AG's |
| 276 | + ``ag_args_fit.ignore_constraints``. |
260 | 277 | """ |
261 | 278 | ag_args_fit = config.setdefault("ag_args_fit", {}) |
262 | | - ag_args_fit["max_rows"] = None |
263 | | - ag_args_fit["max_features"] = None |
264 | | - ag_args_fit["max_classes"] = None |
265 | 279 | ag_args_fit["ignore_constraints"] = self.ignore_pretraining_limits |
266 | 280 |
|
| 281 | + full_path = config.pop("model_path", None) |
| 282 | + if full_path is not None: |
| 283 | + ckpt_name = Path(full_path).name |
| 284 | + config["zip_model_path"] = [ckpt_name, ckpt_name] |
| 285 | + |
267 | 286 | if isinstance(tabpfn_configs, list): |
268 | 287 | for cfg in tabpfn_configs: |
269 | | - _patch_ag_args_fit_inplace(cfg) |
| 288 | + _adapt_config_for_autogluon_inplace(cfg) |
270 | 289 | else: |
271 | | - _patch_ag_args_fit_inplace(tabpfn_configs) |
| 290 | + _adapt_config_for_autogluon_inplace(tabpfn_configs) |
272 | 291 |
|
273 | | - hyperparameters = {TabPFNV2Model: tabpfn_configs} |
| 292 | + hyperparameters = {ag_model_class: tabpfn_configs} |
274 | 293 | if isinstance(self.presets, str) and self.presets == "extreme_quality": |
275 | 294 | raise ValueError( |
276 | 295 | "Extreme quality preset is not supported at the moment, as it does not " |
|
0 commit comments