Skip to content

Commit 6affc25

Browse files
committed
feat: add validation_frequency option to fine-tuning (#811)
- Add a `validation_frequency: int | None = 1` parameter to `FinetunedTabPFNBase`, `FinetunedTabPFNClassifier`, and `FinetunedTabPFNRegressor`.
- When `validation_frequency=None`, validation is disabled entirely; this also disables early stopping and emits a `UserWarning` if `early_stopping` was enabled.
- When `validation_frequency=N`, validation runs every N epochs.
- The default of 1 preserves the existing behaviour (validate every epoch).

Closes #811
1 parent ea9f11e commit 6affc25

3 files changed

Lines changed: 91 additions & 57 deletions

File tree

src/tabpfn/finetuning/finetuned_base.py

Lines changed: 79 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,10 @@ class FinetunedTabPFNBase(BaseEstimator, ABC):
136136
data batches. This is helpful in most cases because, e.g., the column order
137137
will stay the same across batches.
138138
If False, the preprocessing will use a different random seed for each batch.
139+
validation_frequency: How often (in epochs) to run validation. If set to
140+
an integer N, validation is run every N epochs. If None, validation is
141+
disabled entirely, which also disables early stopping. Defaults to 1
142+
(validate every epoch).
139143
"""
140144

141145
def __init__( # noqa: PLR0913
@@ -163,6 +167,7 @@ def __init__( # noqa: PLR0913
163167
use_activation_checkpointing: bool = True,
164168
save_checkpoint_interval: int | None = 10,
165169
use_fixed_preprocessing_seed: bool = True,
170+
validation_frequency: int | None = 1,
166171
):
167172
super().__init__()
168173
self.device = device
@@ -188,6 +193,7 @@ def __init__( # noqa: PLR0913
188193
self.save_checkpoint_interval = save_checkpoint_interval
189194
self.meta_batch_size = META_BATCH_SIZE
190195
self.use_fixed_preprocessing_seed = use_fixed_preprocessing_seed
196+
self.validation_frequency = validation_frequency
191197

192198
if self.use_fixed_preprocessing_seed and not (
193199
self.n_estimators_finetune
@@ -528,16 +534,26 @@ def _fit( # noqa: C901,PLR0912
528534
use_amp = self.device.startswith("cuda") and torch.cuda.is_available()
529535
scaler = GradScaler() if use_amp else None # type: ignore
530536

531-
logger.info("--- 🚀 Eval default model ---")
532-
eval_result = self._evaluate_model(
533-
validation_eval_config,
534-
X_train, # pyright: ignore[reportArgumentType]
535-
y_train, # pyright: ignore[reportArgumentType]
536-
X_val, # pyright: ignore[reportArgumentType]
537-
y_val, # pyright: ignore[reportArgumentType]
538-
)
539-
self._log_epoch_evaluation(-1, eval_result, mean_train_loss=None)
540-
best_metric: float = eval_result.primary
537+
if self.validation_frequency is not None:
538+
logger.info("--- 🚀 Eval default model ---")
539+
eval_result = self._evaluate_model(
540+
validation_eval_config,
541+
X_train, # pyright: ignore[reportArgumentType]
542+
y_train, # pyright: ignore[reportArgumentType]
543+
X_val, # pyright: ignore[reportArgumentType]
544+
y_val, # pyright: ignore[reportArgumentType]
545+
)
546+
self._log_epoch_evaluation(-1, eval_result, mean_train_loss=None)
547+
best_metric: float = eval_result.primary
548+
else:
549+
if self.early_stopping:
550+
warnings.warn(
551+
"`early_stopping` is enabled but `validation_frequency` is None. "
552+
"Early stopping requires validation; it will be disabled.",
553+
UserWarning,
554+
stacklevel=2,
555+
)
556+
best_metric = self._get_initial_best_metric()
541557

542558
static_seed, rng = infer_random_state(self.random_state)
543559
preprocessing_random_state = (
@@ -684,61 +700,67 @@ def _fit( # noqa: C901,PLR0912
684700
epoch_loss_sum / epoch_batches if epoch_batches > 0 else None
685701
)
686702

687-
eval_result = self._evaluate_model(
688-
validation_eval_config,
689-
X_train, # pyright: ignore[reportArgumentType]
690-
y_train, # pyright: ignore[reportArgumentType]
691-
X_val, # pyright: ignore[reportArgumentType]
692-
y_val, # pyright: ignore[reportArgumentType]
703+
run_validation = (
704+
self.validation_frequency is not None
705+
and (epoch + 1) % self.validation_frequency == 0
693706
)
694707

695-
self._log_epoch_evaluation(epoch, eval_result, mean_train_loss)
708+
if run_validation:
709+
eval_result = self._evaluate_model(
710+
validation_eval_config,
711+
X_train, # pyright: ignore[reportArgumentType]
712+
y_train, # pyright: ignore[reportArgumentType]
713+
X_val, # pyright: ignore[reportArgumentType]
714+
y_val, # pyright: ignore[reportArgumentType]
715+
)
696716

697-
primary_metric = eval_result.primary
717+
self._log_epoch_evaluation(epoch, eval_result, mean_train_loss)
698718

699-
if output_dir is not None and not np.isnan(primary_metric):
700-
save_interval_checkpoint = (
701-
self.save_checkpoint_interval is not None
702-
and (epoch + 1) % self.save_checkpoint_interval == 0
703-
)
719+
primary_metric = eval_result.primary
704720

705-
is_best = self._is_improvement(primary_metric, best_metric)
706-
707-
if save_interval_checkpoint or is_best:
708-
save_checkpoint(
709-
estimator=self.finetuned_estimator_,
710-
output_dir=output_dir,
711-
epoch=epoch + 1,
712-
optimizer=optimizer,
713-
metrics=self._get_checkpoint_metrics(eval_result),
714-
train_size=train_size,
715-
is_best=is_best,
716-
save_interval_checkpoint=save_interval_checkpoint,
721+
if output_dir is not None and not np.isnan(primary_metric):
722+
save_interval_checkpoint = (
723+
self.save_checkpoint_interval is not None
724+
and (epoch + 1) % self.save_checkpoint_interval == 0
717725
)
718726

719-
if self.early_stopping and not np.isnan(primary_metric):
720-
if self._is_improvement(primary_metric, best_metric):
721-
best_metric = primary_metric
722-
patience_counter = 0
723-
best_model = copy.deepcopy(self.finetuned_estimator_)
724-
else:
725-
patience_counter += 1
726-
logger.info(
727-
"⚠️ No improvement for %s epochs. Best %s: %.4f",
728-
patience_counter,
729-
self._metric_name,
730-
best_metric,
731-
)
727+
is_best = self._is_improvement(primary_metric, best_metric)
728+
729+
if save_interval_checkpoint or is_best:
730+
save_checkpoint(
731+
estimator=self.finetuned_estimator_,
732+
output_dir=output_dir,
733+
epoch=epoch + 1,
734+
optimizer=optimizer,
735+
metrics=self._get_checkpoint_metrics(eval_result),
736+
train_size=train_size,
737+
is_best=is_best,
738+
save_interval_checkpoint=save_interval_checkpoint,
739+
)
732740

733-
if patience_counter >= self.early_stopping_patience:
734-
logger.info(
735-
"🛑 Early stopping triggered. Best %s: %.4f",
736-
self._metric_name,
737-
best_metric,
738-
)
739-
if best_model is not None:
740-
self.finetuned_estimator_ = best_model
741-
break
741+
if self.early_stopping and not np.isnan(primary_metric):
742+
if self._is_improvement(primary_metric, best_metric):
743+
best_metric = primary_metric
744+
patience_counter = 0
745+
best_model = copy.deepcopy(self.finetuned_estimator_)
746+
else:
747+
patience_counter += 1
748+
logger.info(
749+
"⚠️ No improvement for %s epochs. Best %s: %.4f",
750+
patience_counter,
751+
self._metric_name,
752+
best_metric,
753+
)
754+
755+
if patience_counter >= self.early_stopping_patience:
756+
logger.info(
757+
"🛑 Early stopping triggered. Best %s: %.4f",
758+
self._metric_name,
759+
best_metric,
760+
)
761+
if best_model is not None:
762+
self.finetuned_estimator_ = best_model
763+
break
742764

743765
if self.time_limit is not None:
744766
elapsed_time = time.monotonic() - start_time

src/tabpfn/finetuning/finetuned_classifier.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ class FinetunedTabPFNClassifier(FinetunedTabPFNBase, ClassifierMixin):
115115
data batches. This is helpful in most cases because, e.g., the column order
116116
will stay the same across batches.
117117
If False, the preprocessing will use a different random seed for each batch.
118+
validation_frequency: How often (in epochs) to run validation. If set to
119+
an integer N, validation is run every N epochs. If None, validation is
120+
disabled entirely, which also disables early stopping. Defaults to 1
121+
(validate every epoch).
118122
119123
FinetunedTabPFNClassifier specific arguments:
120124
@@ -150,6 +154,7 @@ def __init__( # noqa: PLR0913
150154
use_activation_checkpointing: bool = True,
151155
save_checkpoint_interval: int | None = 10,
152156
use_fixed_preprocessing_seed: bool = True,
157+
validation_frequency: int | None = 1,
153158
extra_classifier_kwargs: dict[str, Any] | None = None,
154159
eval_metric: Literal["roc_auc", "log_loss"] | None = None,
155160
):
@@ -176,6 +181,7 @@ def __init__( # noqa: PLR0913
176181
use_activation_checkpointing=use_activation_checkpointing,
177182
save_checkpoint_interval=save_checkpoint_interval,
178183
use_fixed_preprocessing_seed=use_fixed_preprocessing_seed,
184+
validation_frequency=validation_frequency,
179185
)
180186
self.extra_classifier_kwargs = extra_classifier_kwargs
181187
self.eval_metric = eval_metric

src/tabpfn/finetuning/finetuned_regressor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,10 @@ class FinetunedTabPFNRegressor(FinetunedTabPFNBase, RegressorMixin):
282282
data batches. This is helpful in most cases because, e.g., the column order
283283
will stay the same across batches.
284284
If False, the preprocessing will use a different random seed for each batch.
285+
validation_frequency: How often (in epochs) to run validation. If set to
286+
an integer N, validation is run every N epochs. If None, validation is
287+
disabled entirely, which also disables early stopping. Defaults to 1
288+
(validate every epoch).
285289
286290
FinetunedTabPFNRegressor specific arguments:
287291
@@ -333,6 +337,7 @@ def __init__( # noqa: PLR0913
333337
use_activation_checkpointing: bool = True,
334338
save_checkpoint_interval: int | None = 10,
335339
use_fixed_preprocessing_seed: bool = True,
340+
validation_frequency: int | None = 1,
336341
extra_regressor_kwargs: dict[str, Any] | None = None,
337342
ce_loss_weight: float = 0.0,
338343
crps_loss_weight: float = 1.0,
@@ -366,6 +371,7 @@ def __init__( # noqa: PLR0913
366371
use_activation_checkpointing=use_activation_checkpointing,
367372
save_checkpoint_interval=save_checkpoint_interval,
368373
use_fixed_preprocessing_seed=use_fixed_preprocessing_seed,
374+
validation_frequency=validation_frequency,
369375
)
370376
self.extra_regressor_kwargs = extra_regressor_kwargs
371377
self.eval_metric = eval_metric

0 commit comments

Comments
 (0)