Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/890.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `get_inference_config()` method to `TabPFNClassifier` and `TabPFNRegressor`. This method loads the model checkpoint if needed and returns the active `InferenceConfig`, allowing inspection of preprocessing and inference settings before calling `fit()`.
1 change: 1 addition & 0 deletions changelog/890.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Remove unused v2.6 defaults from `InferenceConfig.get_default()`. V2.6 checkpoints always embed their own `InferenceConfig`, so these defaults were never used at inference time. The v2.6 preprocessor config factories are also removed from `tabpfn.preprocessing`.
14 changes: 14 additions & 0 deletions src/tabpfn/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,20 @@ def model_(self) -> Architecture:
)
return self.models_[0]

def get_inference_config(self) -> InferenceConfig:
    """Return a deep copy of the active inference config, loading the model if needed.

    The model checkpoint is loaded lazily and without any fit data, so the
    preprocessing and inference settings can be inspected before `fit()` is
    called. Any ``inference_config`` override given to the constructor is
    taken into account.

    Returns:
        A deep copy of the active inference config.
    """
    config_already_loaded = hasattr(self, "inference_config_")
    if not config_already_loaded:
        # Loading the checkpoint populates ``inference_config_``.
        self._initialize_model_variables()
    return copy.deepcopy(self.inference_config_)

# TODO: We can remove this from scikit-learn lower bound of 1.6
def _more_tags(self) -> dict[str, Any]:
return {
Expand Down
42 changes: 0 additions & 42 deletions src/tabpfn/inference_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
PreprocessorConfig,
v2_5_classifier_preprocessor_configs,
v2_5_regressor_preprocessor_configs,
v2_6_classifier_preprocessor_configs,
v2_6_regressor_preprocessor_configs,
v2_classifier_preprocessor_configs,
v2_regressor_preprocessor_configs,
)
Expand Down Expand Up @@ -260,18 +258,6 @@ def get_default(
return _get_v2_5_config(v2_5_classifier_preprocessor_configs())
if task_type == "regression":
return _get_v2_5_config(v2_5_regressor_preprocessor_configs())
elif model_version == ModelVersion.V2_6:
if task_type == "multiclass":
return _get_v2_6_config(
preprocessor_configs=v2_6_classifier_preprocessor_configs(),
task_type=task_type,
)
if task_type == "regression":
return _get_v2_6_config(
preprocessor_configs=v2_6_regressor_preprocessor_configs(),
task_type=task_type,
)

raise ValueError(
f"No inference config is configured for {model_version=}. "
"Please make sure you are using a correct model checkpoint that contains "
Expand Down Expand Up @@ -327,31 +313,3 @@ def _get_v2_5_config(preprocessor_configs: list[PreprocessorConfig]) -> Inferenc
_REGRESSION_DEFAULT_OUTLIER_REMOVAL_STD=None,
_CLASSIFICATION_DEFAULT_OUTLIER_REMOVAL_STD=12.0,
)


def _get_v2_6_config(
    preprocessor_configs: list[PreprocessorConfig],
    task_type: TaskType,
) -> InferenceConfig:
    """Build the default ``InferenceConfig`` for v2.6 model checkpoints.

    Args:
        preprocessor_configs: Per-estimator preprocessing pipelines to embed
            in the config (see ``PREPROCESS_TRANSFORMS``).
        task_type: Task the config is built for. Classification
            ("multiclass") disables polynomial features; regression enables
            10 of them.

    Returns:
        The populated inference config.
    """
    return InferenceConfig(
        MAX_UNIQUE_FOR_CATEGORICAL_FEATURES=30,
        MIN_UNIQUE_FOR_NUMERICAL_FEATURES=4,
        MIN_NUMBER_SAMPLES_FOR_CATEGORICAL_INFERENCE=100,
        OUTLIER_REMOVAL_STD=None,
        FEATURE_SHIFT_METHOD="shuffle",
        CLASS_SHIFT_METHOD="shuffle",
        FINGERPRINT_FEATURE=True,
        # Only regression uses polynomial features in the v2.6 defaults.
        POLYNOMIAL_FEATURES="no" if task_type == "multiclass" else 10,
        SUBSAMPLE_SAMPLES=None,
        FEATURE_SUBSAMPLING_METHOD="random",
        CONSTANT_FEATURE_COUNT=50,
        PREPROCESS_TRANSFORMS=preprocessor_configs,
        REGRESSION_Y_PREPROCESS_TRANSFORMS=("none",),
        USE_SKLEARN_16_DECIMAL_PRECISION=False,
        MAX_NUMBER_OF_CLASSES=10,
        MAX_NUMBER_OF_FEATURES=2000,
        MAX_NUMBER_OF_SAMPLES=50_000,
        FIX_NAN_BORDERS_AFTER_TARGET_TRANSFORM=True,
        _REGRESSION_DEFAULT_OUTLIER_REMOVAL_STD=None,
        _CLASSIFICATION_DEFAULT_OUTLIER_REMOVAL_STD=12.0,
    )
4 changes: 0 additions & 4 deletions src/tabpfn/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
from .presets import (
v2_5_classifier_preprocessor_configs,
v2_5_regressor_preprocessor_configs,
v2_6_classifier_preprocessor_configs,
v2_6_regressor_preprocessor_configs,
v2_classifier_preprocessor_configs,
v2_regressor_preprocessor_configs,
)
Expand All @@ -36,8 +34,6 @@
"generate_regression_ensemble_configs",
"v2_5_classifier_preprocessor_configs",
"v2_5_regressor_preprocessor_configs",
"v2_6_classifier_preprocessor_configs",
"v2_6_regressor_preprocessor_configs",
"v2_classifier_preprocessor_configs",
"v2_regressor_preprocessor_configs",
]
42 changes: 0 additions & 42 deletions src/tabpfn/preprocessing/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,52 +80,10 @@ def v2_5_regressor_preprocessor_configs() -> list[PreprocessorConfig]:
]


def v2_6_classifier_preprocessor_configs() -> list[PreprocessorConfig]:
    """Get the preprocessor configuration for classification in v2.6 of the model."""
    return [
        PreprocessorConfig(
            name="quantile_uni",
            append_original=False,
            categorical_name="numeric",
            global_transformer_name=None,
            max_features_per_estimator=680,
        ),
        PreprocessorConfig(
            name="quantile_uni",
            append_original=False,
            categorical_name="ordinal_very_common_categories_shuffled",
            global_transformer_name="svd_quarter_components",
            max_features_per_estimator=500,
        ),
    ]


def v2_6_regressor_preprocessor_configs() -> list[PreprocessorConfig]:
    """Get the preprocessor configuration for regression in v2.6 of the model."""
    return [
        PreprocessorConfig(
            name="quantile_uni",
            append_original=False,
            categorical_name="numeric",
            global_transformer_name=None,
            max_features_per_estimator=680,
        ),
        PreprocessorConfig(
            name="quantile_uni",
            append_original="auto",
            categorical_name="ordinal_very_common_categories_shuffled",
            global_transformer_name="svd_quarter_components",
            max_features_per_estimator=500,
        ),
    ]


__all__ = [
"_V2_FEATURE_SUBSAMPLING_THRESHOLD",
"v2_5_classifier_preprocessor_configs",
"v2_5_regressor_preprocessor_configs",
"v2_6_classifier_preprocessor_configs",
"v2_6_regressor_preprocessor_configs",
"v2_classifier_preprocessor_configs",
"v2_regressor_preprocessor_configs",
]
15 changes: 15 additions & 0 deletions src/tabpfn/regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

import copy
import logging
import typing
import warnings
Expand Down Expand Up @@ -581,6 +582,20 @@ def bardist_(self, value: FullSupportBarDistribution) -> None:
)
self.znorm_space_bardist_ = value

def get_inference_config(self) -> InferenceConfig:
    """Return a deep copy of the active inference config, loading the model if needed.

    The model checkpoint is loaded lazily and without any fit data, so the
    preprocessing and inference settings can be inspected before `fit()` is
    called. Any ``inference_config`` override given to the constructor is
    taken into account.

    Returns:
        A deep copy of the active inference config.
    """
    config_already_loaded = hasattr(self, "inference_config_")
    if not config_already_loaded:
        # Loading the checkpoint populates ``inference_config_``.
        self._initialize_model_variables()
    return copy.deepcopy(self.inference_config_)

# TODO: We can remove this from scikit-learn lower bound of 1.6
def _more_tags(self) -> dict[str, Any]:
return {
Expand Down
110 changes: 110 additions & 0 deletions tests/test_inference_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@

import torch

from tabpfn import TabPFNClassifier, TabPFNRegressor
from tabpfn.architectures import base
from tabpfn.architectures.base.bar_distribution import FullSupportBarDistribution
from tabpfn.architectures.base.config import ModelConfig
from tabpfn.base import ClassifierModelSpecs, RegressorModelSpecs
from tabpfn.constants import ModelVersion
from tabpfn.inference_config import InferenceConfig
from tabpfn.preprocessing import PreprocessorConfig
Expand Down Expand Up @@ -69,3 +74,108 @@ def test__override_with_user_input__override_is_None__returns_copy_of_config() -
new_config = config.override_with_user_input_and_resolve_auto(user_config=None)
assert new_config is not config
assert new_config == config


def _make_classifier_specs() -> ClassifierModelSpecs:
    """Build a tiny in-memory classifier checkpoint for these tests."""
    architecture_config = ModelConfig(
        emsize=8,
        features_per_group=1,
        max_num_classes=10,
        nhead=2,
        nlayers=2,
        remove_duplicate_features=True,
        num_buckets=100,
    )
    return ClassifierModelSpecs(
        model=base.get_architecture(
            config=architecture_config, cache_trainset_representation=False
        ),
        architecture_config=architecture_config,
        inference_config=InferenceConfig.get_default(
            task_type="multiclass", model_version=ModelVersion.V2_5
        ),
    )


def _make_regressor_specs() -> RegressorModelSpecs:
    """Build a tiny in-memory regressor checkpoint for these tests."""
    architecture_config = ModelConfig(
        emsize=8,
        features_per_group=1,
        max_num_classes=10,
        nhead=2,
        nlayers=2,
        remove_duplicate_features=True,
        num_buckets=100,
    )
    bucket_borders = torch.linspace(-3, 3, architecture_config.num_buckets + 1)
    return RegressorModelSpecs(
        model=base.get_architecture(
            config=architecture_config, cache_trainset_representation=False
        ),
        architecture_config=architecture_config,
        inference_config=InferenceConfig.get_default(
            task_type="regression", model_version=ModelVersion.V2_5
        ),
        norm_criterion=FullSupportBarDistribution(bucket_borders),
    )


def test__classifier_get_inference_config__before_fit__returns_config() -> None:
    """The config is available before fit() and matches the checkpoint's."""
    specs = _make_classifier_specs()
    classifier = TabPFNClassifier(model_path=specs, device="cpu")
    assert not hasattr(classifier, "inference_config_")
    loaded_config = classifier.get_inference_config()
    assert isinstance(loaded_config, InferenceConfig)
    assert loaded_config == specs.inference_config


def test__classifier_get_inference_config__returns_deepcopy() -> None:
    """Mutating the returned config must not affect the estimator's own copy."""
    specs = _make_classifier_specs()
    classifier = TabPFNClassifier(model_path=specs, device="cpu")
    returned = classifier.get_inference_config()
    assert returned is not classifier.inference_config_
    returned.PREPROCESS_TRANSFORMS.clear()
    assert len(classifier.inference_config_.PREPROCESS_TRANSFORMS) > 0


def test__classifier_get_inference_config__with_override__applies_override() -> None:
    """A constructor ``inference_config`` override shows up in the result."""
    specs = _make_classifier_specs()
    classifier = TabPFNClassifier(
        model_path=specs,
        device="cpu",
        inference_config={"POLYNOMIAL_FEATURES": "all"},
    )
    overridden = classifier.get_inference_config()
    assert overridden.POLYNOMIAL_FEATURES == "all"
    # The checkpoint's own config must stay untouched.
    assert specs.inference_config.POLYNOMIAL_FEATURES == "no"


def test__regressor_get_inference_config__before_fit__returns_config() -> None:
    """The config is available before fit() and matches the checkpoint's."""
    specs = _make_regressor_specs()
    regressor = TabPFNRegressor(model_path=specs, device="cpu")
    assert not hasattr(regressor, "inference_config_")
    loaded_config = regressor.get_inference_config()
    assert isinstance(loaded_config, InferenceConfig)
    assert loaded_config == specs.inference_config


def test__regressor_get_inference_config__returns_deepcopy() -> None:
    """Mutating the returned config must not affect the estimator's own copy."""
    specs = _make_regressor_specs()
    regressor = TabPFNRegressor(model_path=specs, device="cpu")
    returned = regressor.get_inference_config()
    assert returned is not regressor.inference_config_
    returned.PREPROCESS_TRANSFORMS.clear()
    assert len(regressor.inference_config_.PREPROCESS_TRANSFORMS) > 0


def test__regressor_get_inference_config__with_override__applies_override() -> None:
    """A constructor ``inference_config`` override shows up in the result."""
    specs = _make_regressor_specs()
    regressor = TabPFNRegressor(
        model_path=specs,
        device="cpu",
        inference_config={"POLYNOMIAL_FEATURES": "all"},
    )
    overridden = regressor.get_inference_config()
    assert overridden.POLYNOMIAL_FEATURES == "all"
    # The checkpoint's own config must stay untouched.
    assert specs.inference_config.POLYNOMIAL_FEATURES == "no"
Loading