Skip to content

Commit 0dff55b

Browse files
authored
Introduces API to get inference config, removes unused inference config defaults (#890)
1 parent b24c02f commit 0dff55b

8 files changed

Lines changed: 141 additions & 88 deletions

File tree

changelog/890.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add `get_inference_config()` method to `TabPFNClassifier` and `TabPFNRegressor`. This method loads the model checkpoint if needed and returns the active `InferenceConfig`, allowing inspection of preprocessing and inference settings before calling `fit()`.

changelog/890.changed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Remove unused v2.6 defaults from `InferenceConfig.get_default()`. V2.6 checkpoints always embed their own `InferenceConfig`, so these defaults were never used at inference time. The v2.6 preprocessor config factories are also removed from `tabpfn.preprocessing`.

src/tabpfn/classifier.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,20 @@ def model_(self) -> Architecture:
551551
)
552552
return self.models_[0]
553553

554+
def get_inference_config(self) -> InferenceConfig:
555+
"""Load the model if needed and return the active inference config.
556+
557+
Loads the model checkpoint without requiring fit data so the config can be
558+
inspected before calling `fit()`. Any ``inference_config`` override
559+
passed to the constructor is considered.
560+
561+
Returns:
562+
A deep copy of the active inference config.
563+
"""
564+
if not hasattr(self, "inference_config_"):
565+
self._initialize_model_variables()
566+
return copy.deepcopy(self.inference_config_)
567+
554568
# TODO: We can remove this from scikit-learn lower bound of 1.6
555569
def _more_tags(self) -> dict[str, Any]:
556570
return {

src/tabpfn/inference_config.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
PreprocessorConfig,
1616
v2_5_classifier_preprocessor_configs,
1717
v2_5_regressor_preprocessor_configs,
18-
v2_6_classifier_preprocessor_configs,
19-
v2_6_regressor_preprocessor_configs,
2018
v2_classifier_preprocessor_configs,
2119
v2_regressor_preprocessor_configs,
2220
)
@@ -260,18 +258,6 @@ def get_default(
260258
return _get_v2_5_config(v2_5_classifier_preprocessor_configs())
261259
if task_type == "regression":
262260
return _get_v2_5_config(v2_5_regressor_preprocessor_configs())
263-
elif model_version == ModelVersion.V2_6:
264-
if task_type == "multiclass":
265-
return _get_v2_6_config(
266-
preprocessor_configs=v2_6_classifier_preprocessor_configs(),
267-
task_type=task_type,
268-
)
269-
if task_type == "regression":
270-
return _get_v2_6_config(
271-
preprocessor_configs=v2_6_regressor_preprocessor_configs(),
272-
task_type=task_type,
273-
)
274-
275261
raise ValueError(
276262
f"No inference config is configured for {model_version=}. "
277263
"Please make sure you are using a correct model checkpoint that contains "
@@ -327,31 +313,3 @@ def _get_v2_5_config(preprocessor_configs: list[PreprocessorConfig]) -> Inferenc
327313
_REGRESSION_DEFAULT_OUTLIER_REMOVAL_STD=None,
328314
_CLASSIFICATION_DEFAULT_OUTLIER_REMOVAL_STD=12.0,
329315
)
330-
331-
332-
def _get_v2_6_config(
333-
preprocessor_configs: list[PreprocessorConfig],
334-
task_type: TaskType,
335-
) -> InferenceConfig:
336-
return InferenceConfig(
337-
MAX_UNIQUE_FOR_CATEGORICAL_FEATURES=30,
338-
MIN_UNIQUE_FOR_NUMERICAL_FEATURES=4,
339-
MIN_NUMBER_SAMPLES_FOR_CATEGORICAL_INFERENCE=100,
340-
OUTLIER_REMOVAL_STD=None,
341-
FEATURE_SHIFT_METHOD="shuffle",
342-
CLASS_SHIFT_METHOD="shuffle",
343-
FINGERPRINT_FEATURE=True,
344-
POLYNOMIAL_FEATURES="no" if task_type == "multiclass" else 10,
345-
SUBSAMPLE_SAMPLES=None,
346-
FEATURE_SUBSAMPLING_METHOD="random",
347-
CONSTANT_FEATURE_COUNT=50,
348-
PREPROCESS_TRANSFORMS=preprocessor_configs,
349-
REGRESSION_Y_PREPROCESS_TRANSFORMS=("none",),
350-
USE_SKLEARN_16_DECIMAL_PRECISION=False,
351-
MAX_NUMBER_OF_CLASSES=10,
352-
MAX_NUMBER_OF_FEATURES=2000,
353-
MAX_NUMBER_OF_SAMPLES=50_000,
354-
FIX_NAN_BORDERS_AFTER_TARGET_TRANSFORM=True,
355-
_REGRESSION_DEFAULT_OUTLIER_REMOVAL_STD=None,
356-
_CLASSIFICATION_DEFAULT_OUTLIER_REMOVAL_STD=12.0,
357-
)

src/tabpfn/preprocessing/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
from .presets import (
1717
v2_5_classifier_preprocessor_configs,
1818
v2_5_regressor_preprocessor_configs,
19-
v2_6_classifier_preprocessor_configs,
20-
v2_6_regressor_preprocessor_configs,
2119
v2_classifier_preprocessor_configs,
2220
v2_regressor_preprocessor_configs,
2321
)
@@ -36,8 +34,6 @@
3634
"generate_regression_ensemble_configs",
3735
"v2_5_classifier_preprocessor_configs",
3836
"v2_5_regressor_preprocessor_configs",
39-
"v2_6_classifier_preprocessor_configs",
40-
"v2_6_regressor_preprocessor_configs",
4137
"v2_classifier_preprocessor_configs",
4238
"v2_regressor_preprocessor_configs",
4339
]

src/tabpfn/preprocessing/presets.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -80,52 +80,10 @@ def v2_5_regressor_preprocessor_configs() -> list[PreprocessorConfig]:
8080
]
8181

8282

83-
def v2_6_classifier_preprocessor_configs() -> list[PreprocessorConfig]:
84-
"""Get the preprocessor configuration for classification in v2.5 of the model."""
85-
return [
86-
PreprocessorConfig(
87-
name="quantile_uni",
88-
append_original=False,
89-
categorical_name="numeric",
90-
global_transformer_name=None,
91-
max_features_per_estimator=680,
92-
),
93-
PreprocessorConfig(
94-
name="quantile_uni",
95-
append_original=False,
96-
categorical_name="ordinal_very_common_categories_shuffled",
97-
global_transformer_name="svd_quarter_components",
98-
max_features_per_estimator=500,
99-
),
100-
]
101-
102-
103-
def v2_6_regressor_preprocessor_configs() -> list[PreprocessorConfig]:
104-
"""Get the preprocessor configuration for regression in v2.5 of the model."""
105-
return [
106-
PreprocessorConfig(
107-
name="quantile_uni",
108-
append_original=False,
109-
categorical_name="numeric",
110-
global_transformer_name=None,
111-
max_features_per_estimator=680,
112-
),
113-
PreprocessorConfig(
114-
name="quantile_uni",
115-
append_original="auto",
116-
categorical_name="ordinal_very_common_categories_shuffled",
117-
global_transformer_name="svd_quarter_components",
118-
max_features_per_estimator=500,
119-
),
120-
]
121-
122-
12383
__all__ = [
12484
"_V2_FEATURE_SUBSAMPLING_THRESHOLD",
12585
"v2_5_classifier_preprocessor_configs",
12686
"v2_5_regressor_preprocessor_configs",
127-
"v2_6_classifier_preprocessor_configs",
128-
"v2_6_regressor_preprocessor_configs",
12987
"v2_classifier_preprocessor_configs",
13088
"v2_regressor_preprocessor_configs",
13189
]

src/tabpfn/regressor.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from __future__ import annotations
1919

20+
import copy
2021
import logging
2122
import typing
2223
import warnings
@@ -581,6 +582,20 @@ def bardist_(self, value: FullSupportBarDistribution) -> None:
581582
)
582583
self.znorm_space_bardist_ = value
583584

585+
def get_inference_config(self) -> InferenceConfig:
586+
"""Load the model if needed and return the active inference config.
587+
588+
Loads the model checkpoint without requiring fit data so the config can be
589+
inspected before calling `fit()`. Any ``inference_config`` override
590+
passed to the constructor is considered.
591+
592+
Returns:
593+
A deep copy of the active inference config.
594+
"""
595+
if not hasattr(self, "inference_config_"):
596+
self._initialize_model_variables()
597+
return copy.deepcopy(self.inference_config_)
598+
584599
# TODO: We can remove this from scikit-learn lower bound of 1.6
585600
def _more_tags(self) -> dict[str, Any]:
586601
return {

tests/test_inference_config.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77

88
import torch
99

10+
from tabpfn import TabPFNClassifier, TabPFNRegressor
11+
from tabpfn.architectures import base
12+
from tabpfn.architectures.base.bar_distribution import FullSupportBarDistribution
13+
from tabpfn.architectures.base.config import ModelConfig
14+
from tabpfn.base import ClassifierModelSpecs, RegressorModelSpecs
1015
from tabpfn.constants import ModelVersion
1116
from tabpfn.inference_config import InferenceConfig
1217
from tabpfn.preprocessing import PreprocessorConfig
@@ -69,3 +74,108 @@ def test__override_with_user_input__override_is_None__returns_copy_of_config() -
6974
new_config = config.override_with_user_input_and_resolve_auto(user_config=None)
7075
assert new_config is not config
7176
assert new_config == config
77+
78+
79+
def _make_classifier_specs() -> ClassifierModelSpecs:
80+
config = ModelConfig(
81+
emsize=8,
82+
features_per_group=1,
83+
max_num_classes=10,
84+
nhead=2,
85+
nlayers=2,
86+
remove_duplicate_features=True,
87+
num_buckets=100,
88+
)
89+
model = base.get_architecture(config=config, cache_trainset_representation=False)
90+
inference_config = InferenceConfig.get_default(
91+
task_type="multiclass", model_version=ModelVersion.V2_5
92+
)
93+
return ClassifierModelSpecs(
94+
model=model,
95+
architecture_config=config,
96+
inference_config=inference_config,
97+
)
98+
99+
100+
def _make_regressor_specs() -> RegressorModelSpecs:
101+
config = ModelConfig(
102+
emsize=8,
103+
features_per_group=1,
104+
max_num_classes=10,
105+
nhead=2,
106+
nlayers=2,
107+
remove_duplicate_features=True,
108+
num_buckets=100,
109+
)
110+
model = base.get_architecture(config=config, cache_trainset_representation=False)
111+
borders = torch.linspace(-3, 3, config.num_buckets + 1)
112+
norm_criterion = FullSupportBarDistribution(borders)
113+
inference_config = InferenceConfig.get_default(
114+
task_type="regression", model_version=ModelVersion.V2_5
115+
)
116+
return RegressorModelSpecs(
117+
model=model,
118+
architecture_config=config,
119+
inference_config=inference_config,
120+
norm_criterion=norm_criterion,
121+
)
122+
123+
124+
def test__classifier_get_inference_config__before_fit__returns_config() -> None:
125+
specs = _make_classifier_specs()
126+
clf = TabPFNClassifier(model_path=specs, device="cpu")
127+
assert not hasattr(clf, "inference_config_")
128+
config = clf.get_inference_config()
129+
assert isinstance(config, InferenceConfig)
130+
assert config == specs.inference_config
131+
132+
133+
def test__classifier_get_inference_config__returns_deepcopy() -> None:
134+
specs = _make_classifier_specs()
135+
clf = TabPFNClassifier(model_path=specs, device="cpu")
136+
config = clf.get_inference_config()
137+
assert config is not clf.inference_config_
138+
config.PREPROCESS_TRANSFORMS.clear()
139+
assert len(clf.inference_config_.PREPROCESS_TRANSFORMS) > 0
140+
141+
142+
def test__classifier_get_inference_config__with_override__applies_override() -> None:
143+
specs = _make_classifier_specs()
144+
clf = TabPFNClassifier(
145+
model_path=specs,
146+
device="cpu",
147+
inference_config={"POLYNOMIAL_FEATURES": "all"},
148+
)
149+
config = clf.get_inference_config()
150+
assert config.POLYNOMIAL_FEATURES == "all"
151+
assert specs.inference_config.POLYNOMIAL_FEATURES == "no"
152+
153+
154+
def test__regressor_get_inference_config__before_fit__returns_config() -> None:
155+
specs = _make_regressor_specs()
156+
reg = TabPFNRegressor(model_path=specs, device="cpu")
157+
assert not hasattr(reg, "inference_config_")
158+
config = reg.get_inference_config()
159+
assert isinstance(config, InferenceConfig)
160+
assert config == specs.inference_config
161+
162+
163+
def test__regressor_get_inference_config__returns_deepcopy() -> None:
164+
specs = _make_regressor_specs()
165+
reg = TabPFNRegressor(model_path=specs, device="cpu")
166+
config = reg.get_inference_config()
167+
assert config is not reg.inference_config_
168+
config.PREPROCESS_TRANSFORMS.clear()
169+
assert len(reg.inference_config_.PREPROCESS_TRANSFORMS) > 0
170+
171+
172+
def test__regressor_get_inference_config__with_override__applies_override() -> None:
173+
specs = _make_regressor_specs()
174+
reg = TabPFNRegressor(
175+
model_path=specs,
176+
device="cpu",
177+
inference_config={"POLYNOMIAL_FEATURES": "all"},
178+
)
179+
config = reg.get_inference_config()
180+
assert config.POLYNOMIAL_FEATURES == "all"
181+
assert specs.inference_config.POLYNOMIAL_FEATURES == "no"

0 commit comments

Comments
 (0)