Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions .github/workflows/pull_request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,13 @@ jobs:
matrix:
include:
- os: ubuntu-latest
python-version: "3.9"
python-version: "3.10"
dependency-set: minimum
- os: macos-15-intel # We need x86 as ARM is python>= 3.11 only.
# https://github.com/actions/setup-python/issues/855
python-version: "3.9"
- os: macos-latest
python-version: "3.10"
dependency-set: minimum
- os: windows-latest
python-version: "3.9"
python-version: "3.10"
dependency-set: minimum
- os: ubuntu-latest
python-version: "3.13"
Expand All @@ -64,7 +63,6 @@ jobs:
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
with:
python-version: ${{ matrix.python-version }}
architecture: x64

- name: Install uv
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
Expand Down
17 changes: 8 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ dependencies = [
"pandas>=1.4.0",
"scikit-learn>=1.6.0",
"scipy>=1.11.1",
"tabpfn>=6.0.5",
"tabpfn-common-utils[telemetry-interactive]>=0.2.0",
"tabpfn>=6.3.2",
"tabpfn-common-utils[telemetry-interactive]>=0.2.13",
]

requires-python = ">=3.9"
requires-python = ">=3.10"
Comment thread
adrian-prior marked this conversation as resolved.
authors = [
{ name = "Noah Hollmann", email = "noah.hollmann@charite.de" },
{ name = "Leo Grinsztajn" },
Expand Down Expand Up @@ -43,7 +43,6 @@ classifiers = [
'Operating System :: Unix',
'Operating System :: MacOS',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Programming Language :: Python :: 3.12',
Expand All @@ -64,7 +63,7 @@ interpretability = [
post_hoc_ensembles = [
"llvmlite",
"hyperopt>=0.2.7",
"autogluon.tabular==1.4.0"
"autogluon.tabular>=1.5.0,<1.6"
]
hpo = [
"hyperopt>=0.2.7",
Expand All @@ -89,7 +88,7 @@ all = [
"hyperopt>=0.2.7",
# https://discuss.python.org/t/pkg-resources-removal-how-to-go-from-there/106079
"setuptools>=67.0.0,<82",
"autogluon.tabular==1.4.0",
"autogluon.tabular>=1.5.0,<1.6",
# scikit-survival not included; install manually if you need SurvivalTabPFN.
]

Expand Down Expand Up @@ -117,7 +116,7 @@ addopts = "--durations=10 -vv"

# https://github.com/astral-sh/ruff
[tool.ruff]
target-version = "py39"
target-version = "py310"
line-length = 88
output-format = "full"
src = ["src", "tests", "examples"]
Expand Down Expand Up @@ -329,7 +328,7 @@ convention = "google"
max-args = 10 # Changed from default of 5

[tool.mypy]
python_version = "3.9"
python_version = "3.10"
packages = ["src/tabpfn_extensions", "tests"]

show_error_codes = true
Expand Down Expand Up @@ -368,7 +367,7 @@ ignore_missing_imports = true
[tool.pyright]
include = ["src", "tests"]

pythonVersion = "3.9"
pythonVersion = "3.10"
typeCheckingMode = "strict"

strictListInference = true
Expand Down
33 changes: 18 additions & 15 deletions src/tabpfn_extensions/hpo/search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,15 +143,24 @@ def get_param_grid_hyperopt(
if model_dir is None:
model_dir = get_cache_dir()

if task_type == "multiclass" and model_version == ModelVersion.V2:
model_source = ModelSource.get_classifier_v2()
elif task_type == "multiclass" and model_version == ModelVersion.V2_5:
model_source = ModelSource.get_classifier_v2_5()
elif task_type == "regression":
if model_version == ModelVersion.V2:
model_source = ModelSource.get_regressor_v2()
elif model_version == ModelVersion.V2_5:
model_source = ModelSource.get_regressor_v2_5()
# Resolve the model source for the (task_type, model_version) pair. Any
# combination we don't explicitly support below raises before we try to
# use `model_source`.
model_source_lookup = {
("multiclass", ModelVersion.V2): ModelSource.get_classifier_v2,
("multiclass", ModelVersion.V2_5): ModelSource.get_classifier_v2_5,
("regression", ModelVersion.V2): ModelSource.get_regressor_v2,
("regression", ModelVersion.V2_5): ModelSource.get_regressor_v2_5,
}
try:
model_source = model_source_lookup[(task_type, model_version)]()
except KeyError as err:
raise NotImplementedError(
f"No hpo search space is defined for task type {task_type!r} and "
f"model version {model_version!r}."
) from err

if task_type == "regression":
search_space["inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS"] = hp.choice(
"REGRESSION_Y_PREPROCESS_TRANSFORMS",
[
Expand All @@ -162,12 +171,6 @@ def get_param_grid_hyperopt(
],
)

else:
raise NotImplementedError(
f"No hpo search space is defined for task type {task_type} and "
f"model version {model_version}."
)

# Make sure models are downloaded.
if download_models_if_missing:
for ckpt_name in model_source.filenames:
Expand Down
5 changes: 3 additions & 2 deletions src/tabpfn_extensions/hpo/tuned_tabpfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@
from __future__ import annotations

import logging
from collections.abc import Callable
from enum import Enum
from typing import Any, Callable
from typing import Any

import numpy as np
import torch
Expand Down Expand Up @@ -183,7 +184,7 @@ def _optimize(self, X: np.ndarray, y: np.ndarray, task_type: str):
for k, v_item in self.search_space.items():
if isinstance(v_item, list):
custom_space[k] = hp.choice(k, v_item)
elif isinstance(v_item, (int, float, bool, str)) or v_item is None:
elif isinstance(v_item, int | float | bool | str) or v_item is None:
custom_space[k] = v_item
else:
custom_space[k] = v_item
Expand Down
3 changes: 2 additions & 1 deletion src/tabpfn_extensions/interpretability/shap.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@

from __future__ import annotations

from collections.abc import Callable
from multiprocessing import Pool
from typing import Any, Callable
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
Expand Down
2 changes: 1 addition & 1 deletion src/tabpfn_extensions/interpretability/shapiq.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class index will be set to 1 per default for classification models. This argumen
data = data.values

# make labels to array if it is a pandas Series
if isinstance(labels, (pd.Series, pd.DataFrame)):
if isinstance(labels, pd.Series | pd.DataFrame):
labels = labels.values

# TabPFNExplainer is directly available in the shapiq module
Expand Down
4 changes: 1 addition & 3 deletions src/tabpfn_extensions/many_class/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import sys
from collections.abc import Iterable, Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
Expand All @@ -15,8 +14,7 @@

def _dataclass_kwargs() -> dict[str, Any]:
kwargs: dict[str, Any] = {}
if sys.version_info >= (3, 10): # pragma: no cover - environment dependent
kwargs["slots"] = True
kwargs["slots"] = True
return kwargs


Expand Down
7 changes: 3 additions & 4 deletions src/tabpfn_extensions/misc/sklearn_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
import platform
import sys
import types
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Callable, Literal
from typing import Literal

import sklearn
from sklearn.utils.fixes import parse_version
Expand All @@ -34,8 +35,6 @@

# tags infrastructure
def _dataclass_args():
if sys.version_info < (3, 10):
return {}
return {"slots": True}


Expand Down Expand Up @@ -196,7 +195,7 @@ def _is_fitted(estimator, attributes=None, all_or_any=all):
Whether the estimator is fitted.
"""
if attributes is not None:
if not isinstance(attributes, (list, tuple)):
if not isinstance(attributes, list | tuple):
attributes = [attributes]
return all_or_any([hasattr(estimator, attr) for attr in attributes])

Expand Down
53 changes: 36 additions & 17 deletions src/tabpfn_extensions/post_hoc_ensembles/sklearn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Literal

import numpy as np
Expand Down Expand Up @@ -198,10 +199,25 @@ def fit(self, X: np.ndarray | pd.DataFrame, y: np.ndarray | pd.Series):
This method should be called from the child class's fit method after validation.
"""
from autogluon.tabular import TabularPredictor
from autogluon.tabular.models import TabPFNV2Model
from autogluon.tabular.models import RealTabPFNv2Model, RealTabPFNv25Model

from tabpfn_extensions.post_hoc_ensembles.utils import search_space_func

# Route to the AutoGluon TabPFN model class that matches the requested
# TabPFN model version. Each class ships with the correct per-version
# max_rows/max_features/max_classes limits, so we no longer need to
# override them via ag_args_fit.
if self.model_version == ModelVersion.V2:
ag_model_class = RealTabPFNv2Model
elif self.model_version == ModelVersion.V2_5:
ag_model_class = RealTabPFNv25Model
else:
raise NotImplementedError(
f"AutoTabPFN does not support TabPFN model version "
f"{self.model_version.value!r} yet. Supported versions: "
f"{ModelVersion.V2.value!r}, {ModelVersion.V2_5.value!r}.",
)

if isinstance(X, pd.DataFrame):
training_df = X.copy()
self._column_names = X.columns.tolist()
Expand Down Expand Up @@ -247,30 +263,33 @@ def fit(self, X: np.ndarray | pd.DataFrame, y: np.ndarray | pd.Series):
**self.get_task_args_(),
}

def _patch_ag_args_fit_inplace(config: dict[str, Any]) -> None:
"""Patch AutoGluon's per-model params_aux for TabPFN sub-models.

- Disable AutoGluon's static max_rows / max_features / max_classes
asserts so TabPFN's own per-checkpoint validation is the single
authority. TODO: Fix upstream in AutoGluon's `TabPFNV2Model`. A
single class handles all v2.x checkpoints with v2-era limits
hardcoded in `_get_default_auxiliary_params`, which is wrong for
v2.5+. Until that lands, we override per sub-model here.
- Forward the user's `ignore_pretraining_limits` flag to TabPFN.
def _adapt_config_for_autogluon_inplace(config: dict[str, Any]) -> None:
"""Forward `ignore_pretraining_limits` and translate `model_path`.

- AutoGluon 1.5 expects the checkpoint to be passed as
``zip_model_path=[classification_ckpt, regression_ckpt]`` (just
the filenames; AG joins with its own resolved cache dir). The
search space produces a TabPFN-compatible ``model_path=<abs path>``
so that it can also be consumed directly by core TabPFN; here we
translate it to the form AG expects and drop the original key.
- Forward the user's ``ignore_pretraining_limits`` flag into AG's
``ag_args_fit.ignore_constraints``.
"""
ag_args_fit = config.setdefault("ag_args_fit", {})
ag_args_fit["max_rows"] = None
ag_args_fit["max_features"] = None
ag_args_fit["max_classes"] = None
ag_args_fit["ignore_constraints"] = self.ignore_pretraining_limits

full_path = config.pop("model_path", None)
if full_path is not None:
ckpt_name = Path(full_path).name
config["zip_model_path"] = [ckpt_name, ckpt_name]

if isinstance(tabpfn_configs, list):
for cfg in tabpfn_configs:
_patch_ag_args_fit_inplace(cfg)
_adapt_config_for_autogluon_inplace(cfg)
else:
_patch_ag_args_fit_inplace(tabpfn_configs)
_adapt_config_for_autogluon_inplace(tabpfn_configs)

hyperparameters = {TabPFNV2Model: tabpfn_configs}
hyperparameters = {ag_model_class: tabpfn_configs}
if isinstance(self.presets, str) and self.presets == "extreme_quality":
raise ValueError(
"Extreme quality preset is not supported at the moment, as it does not "
Expand Down
14 changes: 4 additions & 10 deletions src/tabpfn_extensions/post_hoc_ensembles/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,10 @@ def prepare_tabpfnv2_config(
else:
config.pop("balance_probabilities", None)

# TODO: Enable RF-PFN at some point
config["model_type"] = "single"

# Special case for dt_pfn
# TODO: This code is unused until we support RF-PFN
if config.get("model_type") == "dt_pfn":
config["n_ensemble_repeats"] = config["n_estimators"]
config["n_estimators"] = 1

# Remove deprecated keys
# Remove search-space-only keys that are not valid TabPFNClassifier kwargs.
# `model_type` is a placeholder for future RF-PFN routing (currently unused),
# and `max_depth` was for the dt_pfn variant.
config.pop("model_type", None)
Comment thread
adrian-prior marked this conversation as resolved.
config.pop("max_depth", None)

return config
Expand Down
2 changes: 1 addition & 1 deletion src/tabpfn_extensions/pval_crt/crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def tabpfn_crt(
# ---------------------------
# Multi-feature support
# ---------------------------
if isinstance(j, Sequence) and not isinstance(j, (str, bytes)):
if isinstance(j, Sequence) and not isinstance(j, str | bytes):
results = {}

for feat in j:
Expand Down
6 changes: 3 additions & 3 deletions src/tabpfn_extensions/pval_crt/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import Any, Union
from typing import Any

import numpy as np
import pandas as pd
import torch

FeatureSpec = Union[int, str]
FeatureSpec = int | str


def coerce_X_y_to_numpy(
Expand All @@ -23,7 +23,7 @@ def coerce_X_y_to_numpy(
else:
X_np = np.asarray(X)

if pd is not None and isinstance(y, (pd.Series, pd.DataFrame)):
if pd is not None and isinstance(y, pd.Series | pd.DataFrame):
y_np = np.asarray(y).reshape(-1)
else:
y_np = np.asarray(y).reshape(-1)
Expand Down
1 change: 1 addition & 0 deletions src/tabpfn_extensions/sklearn_ensembles/meta_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def get_tabpfn_outer_ensemble(config: configs.TabPFNConfig, **kwargs):
for ensemble_member_index, sub_config in zip(
range(config.n_estimators),
relevant_config_product,
strict=False,
):
member_config = copy.deepcopy(config)
for k, v in sub_config.items():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def fit(self, X, y):
# Prune classifiers with weights below the threshold
pruned_classifiers = []
pruned_weights = []
for clf, weight in zip(self.estimators, weights):
for clf, weight in zip(self.estimators, weights, strict=True):
if weight >= self.weight_threshold:
pruned_classifiers.append(clf)
pruned_weights.append(weight)
Expand All @@ -40,7 +40,7 @@ def fit(self, X, y):
pruned_classifiers = [
clf
for _, clf in sorted(
zip(pruned_weights, pruned_classifiers),
zip(pruned_weights, pruned_classifiers, strict=True),
key=lambda pair: pair[0],
)
]
Expand Down
Loading