Skip to content

Commit 2b91390

Browse files
committed
Refresh target stats on every fit_with_differentiable_input call
Address gemini-code-assist review on PR #923: the second fit call previously skipped re-normalising y, leaving y_train_mean_, y_train_std_, raw_space_bardist_ stuck on the first fit's stats — silently miscaling predictions when the new target distribution differed. Split _initialize_for_differentiable_input into: - _initialize_for_differentiable_input: first-call-only setup (categorical check, feature schema, ensemble configs). Cached in self.ensemble_configs_. - _refresh_targets_for_differentiable_input: per-call setup (validate_dataset_size, z-normalise y, rebuild raw_space_bardist_, update n_train_samples_). Runs on every fit. fit_with_differentiable_input's else branch now calls the per-call helper so subsequent fits track the current target distribution while still reusing the loaded model and ensemble configs. Add test__fit_with_differentiable_input__second_call_refreshes_target_stats that fits twice with very different y distributions and checks y_train_mean_, y_train_std_, and raw_space_bardist_.borders all move.
1 parent 70d0210 commit 2b91390

2 files changed

Lines changed: 82 additions & 33 deletions

File tree

src/tabpfn/regressor.py

Lines changed: 49 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -640,33 +640,61 @@ def _initialize_model_variables(self) -> int:
640640
"""
641641
return initialize_model_variables_helper(self, self.estimator_type)
642642

643+
def _refresh_targets_for_differentiable_input(
644+
self, X: torch.Tensor, y: torch.Tensor
645+
) -> tuple[torch.Tensor, torch.Tensor]:
646+
"""Per-fit-call data-dependent setup for the differentiable path.
647+
648+
Validates input shape, z-normalises ``y`` as a torch op (preserves
649+
grads), updates the standardisation stats, and rebuilds
650+
``raw_space_bardist_`` in the caller's current target scale. Run on
651+
every ``fit_with_differentiable_input`` call so the regressor's
652+
target stats always match the data being fit; the model load and
653+
ensemble configs are cached in ``_initialize_for_differentiable_input``
654+
and run only on the first call.
655+
"""
656+
validate_dataset_size(
657+
X=X,
658+
y=y,
659+
max_num_samples=self.inference_config_.MAX_NUMBER_OF_SAMPLES,
660+
max_num_features=self.inference_config_.MAX_NUMBER_OF_FEATURES,
661+
devices=self.devices_,
662+
ignore_pretraining_limits=self.ignore_pretraining_limits,
663+
)
664+
self.n_train_samples_ = int(X.shape[0])
665+
666+
y_float = y.float() if isinstance(y, torch.Tensor) else torch.as_tensor(
667+
y, dtype=torch.float32
668+
)
669+
y_mean = y_float.mean()
670+
y_std = y_float.std() + 1e-20
671+
self.y_train_mean_ = y_mean.detach().item()
672+
self.y_train_std_ = y_std.detach().item()
673+
y_normalized = (y_float - y_mean) / y_std
674+
675+
# raw_space_bardist_ is a constant lookup in the caller's target
676+
# scale; detach so the buffer does not hold onto y's grad graph.
677+
borders = self.znorm_space_bardist_.borders.detach()
678+
self.raw_space_bardist_ = FullSupportBarDistribution(
679+
borders * self.y_train_std_ + self.y_train_mean_,
680+
).float()
681+
return X, y_normalized
682+
643683
def _initialize_for_differentiable_input(
644684
self,
645685
X: torch.Tensor,
646686
y: torch.Tensor,
647687
rng: np.random.Generator,
648688
) -> tuple[list[RegressorEnsembleConfig], torch.Tensor, torch.Tensor]:
649-
"""Initialize the model for differentiable input.
689+
"""First-call setup for the differentiable path.
650690
651691
Mirrors the classifier-side helper so that gradients can flow from a
652692
loss back to upstream torch modules feeding ``X`` (and optionally
653693
``y``). Skips the standard numpy preprocessing path and uses a
654-
differentiable identity preprocessor.
655-
656-
Returns the ensemble configs together with ``X`` and the
657-
z-normalised ``y``. The standardisation parameters are stored on
658-
``self`` so ``raw_space_bardist_`` reflects the caller's target
659-
scale.
694+
differentiable identity preprocessor. Subsequent calls reuse the
695+
feature schema and ensemble configs but re-run target normalization
696+
via ``_refresh_targets_for_differentiable_input``.
660697
"""
661-
validate_dataset_size(
662-
X=X,
663-
y=y,
664-
max_num_samples=self.inference_config_.MAX_NUMBER_OF_SAMPLES,
665-
max_num_features=self.inference_config_.MAX_NUMBER_OF_FEATURES,
666-
devices=self.devices_,
667-
ignore_pretraining_limits=self.ignore_pretraining_limits,
668-
)
669-
670698
# Minimal preprocessing for prompt tuning: no categorical features,
671699
# all-numerical schema, identity preprocessor that preserves grads.
672700
if (
@@ -680,24 +708,8 @@ def _initialize_for_differentiable_input(
680708
features = [Feature(name=None, modality=FeatureModality.NUMERICAL)] * n_features
681709
self.inferred_feature_schema_ = FeatureSchema(features=features)
682710
self.n_features_in_ = n_features
683-
self.n_train_samples_ = int(X.shape[0])
684711

685-
# z-normalise y as a torch op so that gradients flow if y has them.
686-
y_float = y.float() if isinstance(y, torch.Tensor) else torch.as_tensor(
687-
y, dtype=torch.float32
688-
)
689-
y_mean = y_float.mean()
690-
y_std = y_float.std() + 1e-20
691-
self.y_train_mean_ = y_mean.detach().item()
692-
self.y_train_std_ = y_std.detach().item()
693-
y_normalized = (y_float - y_mean) / y_std
694-
695-
# raw_space_bardist_ is a constant lookup in caller's target scale; we
696-
# detach so the buffer does not accidentally hold onto y's grad graph.
697-
borders = self.znorm_space_bardist_.borders.detach()
698-
self.raw_space_bardist_ = FullSupportBarDistribution(
699-
borders * self.y_train_std_ + self.y_train_mean_,
700-
).float()
712+
X, y_normalized = self._refresh_targets_for_differentiable_input(X, y)
701713

702714
preprocessor_configs = [PreprocessorConfig("none", differentiable=True)]
703715
# Polynomial features go through sklearn StandardScaler on numpy and
@@ -903,6 +915,10 @@ def fit_with_differentiable_input(
903915
self.inference_precision, self.devices_
904916
)
905917
ensemble_configs = self.ensemble_configs_ # Reuse from first fit
918+
# Re-validate and re-normalise y for the new fit data so that
919+
# raw_space_bardist_ and y_train_mean_/std_ track the current
920+
# targets. The model load and ensemble configs stay cached.
921+
X, y = self._refresh_targets_for_differentiable_input(X, y)
906922

907923
self.ensemble_preprocessor_ = TabPFNEnsemblePreprocessor(
908924
configs=ensemble_configs,

tests/test_regressor_interface.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,3 +1050,36 @@ def test__fit_with_differentiable_input__categorical_features_rejected() -> None
10501050
y = torch.randn(20)
10511051
with pytest.raises(ValueError, match="Categorical features"):
10521052
reg.fit_with_differentiable_input(X, y)
1053+
1054+
1055+
def test__fit_with_differentiable_input__second_call_refreshes_target_stats() -> None:
1056+
"""A second call with different y must update y_train_mean_/std_ and the
1057+
raw_space_bardist_; only the model load and ensemble configs are cached."""
1058+
torch.manual_seed(0)
1059+
reg = TabPFNRegressor(
1060+
n_estimators=1,
1061+
ignore_pretraining_limits=True,
1062+
device="cpu",
1063+
differentiable_input=True,
1064+
)
1065+
X1 = torch.randn(20, 4)
1066+
y1 = torch.randn(20) * 10.0 + 100.0 # mean ~100, std ~10
1067+
reg.fit_with_differentiable_input(X1, y1)
1068+
mean1, std1 = reg.y_train_mean_, reg.y_train_std_
1069+
bardist_borders1 = reg.raw_space_bardist_.borders.clone()
1070+
1071+
X2 = torch.randn(20, 4)
1072+
y2 = torch.randn(20) * 0.5 - 5.0 # mean ~-5, std ~0.5
1073+
reg.fit_with_differentiable_input(X2, y2)
1074+
mean2, std2 = reg.y_train_mean_, reg.y_train_std_
1075+
1076+
assert abs(mean2 - mean1) > 1.0, (
1077+
f"y_train_mean_ should reflect new y; got {mean1} -> {mean2}"
1078+
)
1079+
assert abs(std2 - std1) > 1.0, (
1080+
f"y_train_std_ should reflect new y; got {std1} -> {std2}"
1081+
)
1082+
# raw_space_bardist_ borders are derived from y stats; they must move.
1083+
assert not torch.allclose(reg.raw_space_bardist_.borders, bardist_borders1), (
1084+
"raw_space_bardist_ must be rebuilt to the new target scale"
1085+
)

0 commit comments

Comments
 (0)