Skip to content

Commit 401d96d

Browse files
committed
Address gemini and Copilot review on PR #923
Fixes the medium-severity comments raised on the differentiable_input regressor path: 1. Feature instances per column: replace `[Feature(...)] * n_features` with a list comprehension so each column has its own dataclass and a later in-place update on one column does not leak across all columns. 2. y stats numerical robustness: switch `y_float.std()` (PyTorch's default `correction=1`, which differs from `np.std` and returns NaN for N=1) to `clamp(y_float.std(correction=0), min=1e-20)`. This matches the standard `fit()` path's `np.std` semantics and stays finite for single-sample input. 3. Constant-target guard: a constant y collapses the bardist borders to a single point and trips `FullSupportBarDistribution`'s strictly-increasing assertion. `fit()` short-circuits this with `is_constant_target_`; the differentiable path has no analogue, so reject up front with a clear ValueError pointing users at `fit()`. 4. Sequential preprocessing for diff input: force `n_preprocessing_jobs=1` inside `fit_with_differentiable_input`. When X carries an autograd graph, joblib's process-boundary pickling breaks the graph; sequential execution preserves it. The detach-then-`.item()` of `y_train_mean_/std_` is intentional and not changed: `raw_space_bardist_` is a frozen lookup buffer that should not hold a y-grad graph; users wanting fully differentiable target scaling should z-normalise y externally so mean/std become constants here. Documented inline. New tests: - feature_schema_columns_are_independent: catches the alias bug. - std_matches_population_definition: locks in `np.std` semantics. - constant_target_rejected: locks in the explicit guard. - single_sample_y_does_not_nan: confirms N=1 hits the guard cleanly rather than producing NaN deep in the bardist. All 9 differentiable_input tests pass on CPU and CUDA.
1 parent 2b91390 commit 401d96d

2 files changed

Lines changed: 102 additions & 3 deletions

File tree

src/tabpfn/regressor.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -667,7 +667,24 @@ def _refresh_targets_for_differentiable_input(
667667
y, dtype=torch.float32
668668
)
669669
y_mean = y_float.mean()
670-
y_std = y_float.std() + 1e-20
670+
# Match the standard fit path's np.std (population std, ddof=0).
671+
# torch.std defaults to correction=1 (sample std), which differs from
672+
# numpy and returns NaN for N=1; clamp keeps the divisor non-zero.
673+
y_std = torch.clamp(y_float.std(correction=0), min=1e-20)
674+
# Constant targets would collapse the bardist borders to a single
675+
# point; the differentiable path has no analogue of fit()'s
676+
# is_constant_target_ short-circuit, so reject up front.
677+
if y_std.detach().item() <= 1e-12:
678+
raise ValueError(
679+
"Constant or near-constant target (std≈0) is not supported "
680+
"by fit_with_differentiable_input; there is no signal to "
681+
"predict differentiably. Use fit() for constant-target data."
682+
)
683+
# Detach when storing as Python floats — raw_space_bardist_ is a
684+
# frozen lookup table and must not hold a y-grad graph. Users who
685+
# need fully differentiable target scaling should z-normalise y
686+
# themselves before calling fit_with_differentiable_input so the
687+
# mean/std are constants here.
671688
self.y_train_mean_ = y_mean.detach().item()
672689
self.y_train_std_ = y_std.detach().item()
673690
y_normalized = (y_float - y_mean) / y_std
@@ -705,7 +722,13 @@ def _initialize_for_differentiable_input(
705722
"Categorical features are not supported for differentiable input."
706723
)
707724
n_features = X.shape[1]
708-
features = [Feature(name=None, modality=FeatureModality.NUMERICAL)] * n_features
725+
# One Feature instance per column — list multiplication would share
726+
# the same dataclass and any later in-place update would leak across
727+
# columns.
728+
features = [
729+
Feature(name=None, modality=FeatureModality.NUMERICAL)
730+
for _ in range(n_features)
731+
]
709732
self.inferred_feature_schema_ = FeatureSchema(features=features)
710733
self.n_features_in_ = n_features
711734

@@ -920,12 +943,15 @@ def fit_with_differentiable_input(
920943
# targets. The model load and ensemble configs stay cached.
921944
X, y = self._refresh_targets_for_differentiable_input(X, y)
922945

946+
# Force sequential preprocessing: with differentiable input, X carries
947+
# an autograd graph that does not survive joblib's process-boundary
948+
# pickling. Sequential execution preserves the graph in-process.
923949
self.ensemble_preprocessor_ = TabPFNEnsemblePreprocessor(
924950
configs=ensemble_configs,
925951
n_samples=X.shape[0],
926952
feature_schema=self.inferred_feature_schema_,
927953
random_state=static_seed,
928-
n_preprocessing_jobs=self.n_preprocessing_jobs,
954+
n_preprocessing_jobs=1,
929955
feature_subsampling_method=FeatureSubsamplingMethod(
930956
self.inference_config_.FEATURE_SUBSAMPLING_METHOD
931957
),

tests/test_regressor_interface.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,6 +1052,79 @@ def test__fit_with_differentiable_input__categorical_features_rejected() -> None
10521052
reg.fit_with_differentiable_input(X, y)
10531053

10541054

1055+
def test__fit_with_differentiable_input__constant_target_rejected() -> None:
1056+
"""A constant-target y has no signal to predict differentiably and would
1057+
collapse the bardist borders; reject with a clear error."""
1058+
reg = TabPFNRegressor(
1059+
n_estimators=1,
1060+
ignore_pretraining_limits=True,
1061+
device="cpu",
1062+
differentiable_input=True,
1063+
)
1064+
X = torch.randn(5, 4)
1065+
y = torch.full((5,), 3.14)
1066+
with pytest.raises(ValueError, match="Constant or near-constant target"):
1067+
reg.fit_with_differentiable_input(X, y)
1068+
1069+
1070+
def test__fit_with_differentiable_input__single_sample_y_does_not_nan() -> None:
1071+
"""torch.std defaults to sample std (correction=1) which returns NaN for
1072+
N=1. Our path uses correction=0 (population std) so std is well defined
1073+
even for a single sample (it just collapses to 0, which then trips the
1074+
constant-target guard — what we want). Verify the failure mode is the
1075+
explicit ValueError, not a downstream NaN."""
1076+
reg = TabPFNRegressor(
1077+
n_estimators=1,
1078+
ignore_pretraining_limits=True,
1079+
device="cpu",
1080+
differentiable_input=True,
1081+
)
1082+
X = torch.randn(1, 4)
1083+
y = torch.tensor([2.0])
1084+
with pytest.raises(ValueError, match="Constant or near-constant target"):
1085+
reg.fit_with_differentiable_input(X, y)
1086+
1087+
1088+
def test__fit_with_differentiable_input__std_matches_population_definition() -> None:
1089+
"""The differentiable path's y_train_std_ should match np.std (population
1090+
std, ddof=0), not torch's default sample std (correction=1), so it lines
1091+
up with the standard fit() path."""
1092+
reg = TabPFNRegressor(
1093+
n_estimators=1,
1094+
ignore_pretraining_limits=True,
1095+
device="cpu",
1096+
differentiable_input=True,
1097+
)
1098+
X = torch.randn(20, 4)
1099+
y_np = np.random.default_rng(0).standard_normal(20).astype(np.float32)
1100+
y = torch.from_numpy(y_np)
1101+
reg.fit_with_differentiable_input(X, y)
1102+
expected = float(np.std(y_np)) # ddof=0
1103+
assert abs(reg.y_train_std_ - expected) < 1e-5, (
1104+
f"y_train_std_ should equal np.std(y) (population std); "
1105+
f"got {reg.y_train_std_}, expected {expected}"
1106+
)
1107+
1108+
1109+
def test__fit_with_differentiable_input__feature_schema_columns_are_independent() -> None:
1110+
"""Each column's Feature must be a distinct instance — list multiplication
1111+
`[Feature(...)] * n` would alias all columns to one mutable dataclass."""
1112+
reg = TabPFNRegressor(
1113+
n_estimators=1,
1114+
ignore_pretraining_limits=True,
1115+
device="cpu",
1116+
differentiable_input=True,
1117+
)
1118+
X = torch.randn(10, 4)
1119+
y = torch.randn(10)
1120+
reg.fit_with_differentiable_input(X, y)
1121+
feats = reg.inferred_feature_schema_.features
1122+
assert len(feats) == 4
1123+
# Distinct instances, not aliases.
1124+
ids = {id(f) for f in feats}
1125+
assert len(ids) == 4, "feature columns share the same Feature instance"
1126+
1127+
10551128
def test__fit_with_differentiable_input__second_call_refreshes_target_stats() -> None:
10561129
"""A second call with different y must update y_train_mean_/std_ and the
10571130
raw_space_bardist_; only the model load and ensemble configs are cached."""

0 commit comments

Comments
 (0)