@@ -136,6 +136,10 @@ class FinetunedTabPFNBase(BaseEstimator, ABC):
136136 data batches. This is helpful in most cases because, e.g., the column order
137137 will stay the same across batches.
138138 If False, the preprocessing will use a different random seed for each batch.
139+ validation_frequency: How often (in epochs) to run validation. If set to
140+ an integer N, validation is run every N epochs. If None, validation is
141+ disabled entirely, which also disables early stopping. Defaults to 1
142+ (validate every epoch).
139143 """
140144
141145 def __init__ ( # noqa: PLR0913
@@ -163,6 +167,7 @@ def __init__( # noqa: PLR0913
163167 use_activation_checkpointing : bool = True ,
164168 save_checkpoint_interval : int | None = 10 ,
165169 use_fixed_preprocessing_seed : bool = True ,
170+ validation_frequency : int | None = 1 ,
166171 ):
167172 super ().__init__ ()
168173 self .device = device
@@ -188,6 +193,7 @@ def __init__( # noqa: PLR0913
188193 self .save_checkpoint_interval = save_checkpoint_interval
189194 self .meta_batch_size = META_BATCH_SIZE
190195 self .use_fixed_preprocessing_seed = use_fixed_preprocessing_seed
196+ self .validation_frequency = validation_frequency
191197
192198 if self .use_fixed_preprocessing_seed and not (
193199 self .n_estimators_finetune
@@ -528,16 +534,26 @@ def _fit( # noqa: C901,PLR0912
528534 use_amp = self .device .startswith ("cuda" ) and torch .cuda .is_available ()
529535 scaler = GradScaler () if use_amp else None # type: ignore
530536
531- logger .info ("--- 🚀 Eval default model ---" )
532- eval_result = self ._evaluate_model (
533- validation_eval_config ,
534- X_train , # pyright: ignore[reportArgumentType]
535- y_train , # pyright: ignore[reportArgumentType]
536- X_val , # pyright: ignore[reportArgumentType]
537- y_val , # pyright: ignore[reportArgumentType]
538- )
539- self ._log_epoch_evaluation (- 1 , eval_result , mean_train_loss = None )
540- best_metric : float = eval_result .primary
537+ if self .validation_frequency is not None :
538+ logger .info ("--- 🚀 Eval default model ---" )
539+ eval_result = self ._evaluate_model (
540+ validation_eval_config ,
541+ X_train , # pyright: ignore[reportArgumentType]
542+ y_train , # pyright: ignore[reportArgumentType]
543+ X_val , # pyright: ignore[reportArgumentType]
544+ y_val , # pyright: ignore[reportArgumentType]
545+ )
546+ self ._log_epoch_evaluation (- 1 , eval_result , mean_train_loss = None )
547+ best_metric : float = eval_result .primary
548+ else :
549+ if self .early_stopping :
550+ warnings .warn (
551+ "`early_stopping` is enabled but `validation_frequency` is None. "
552+ "Early stopping requires validation; it will be disabled." ,
553+ UserWarning ,
554+ stacklevel = 2 ,
555+ )
556+ best_metric = self ._get_initial_best_metric ()
541557
542558 static_seed , rng = infer_random_state (self .random_state )
543559 preprocessing_random_state = (
@@ -684,61 +700,67 @@ def _fit( # noqa: C901,PLR0912
684700 epoch_loss_sum / epoch_batches if epoch_batches > 0 else None
685701 )
686702
687- eval_result = self ._evaluate_model (
688- validation_eval_config ,
689- X_train , # pyright: ignore[reportArgumentType]
690- y_train , # pyright: ignore[reportArgumentType]
691- X_val , # pyright: ignore[reportArgumentType]
692- y_val , # pyright: ignore[reportArgumentType]
703+ run_validation = (
704+ self .validation_frequency is not None
705+ and (epoch + 1 ) % self .validation_frequency == 0
693706 )
694707
695- self ._log_epoch_evaluation (epoch , eval_result , mean_train_loss )
708+ if run_validation :
709+ eval_result = self ._evaluate_model (
710+ validation_eval_config ,
711+ X_train , # pyright: ignore[reportArgumentType]
712+ y_train , # pyright: ignore[reportArgumentType]
713+ X_val , # pyright: ignore[reportArgumentType]
714+ y_val , # pyright: ignore[reportArgumentType]
715+ )
696716
697- primary_metric = eval_result . primary
717+ self . _log_epoch_evaluation ( epoch , eval_result , mean_train_loss )
698718
699- if output_dir is not None and not np .isnan (primary_metric ):
700- save_interval_checkpoint = (
701- self .save_checkpoint_interval is not None
702- and (epoch + 1 ) % self .save_checkpoint_interval == 0
703- )
719+ primary_metric = eval_result .primary
704720
705- is_best = self ._is_improvement (primary_metric , best_metric )
706-
707- if save_interval_checkpoint or is_best :
708- save_checkpoint (
709- estimator = self .finetuned_estimator_ ,
710- output_dir = output_dir ,
711- epoch = epoch + 1 ,
712- optimizer = optimizer ,
713- metrics = self ._get_checkpoint_metrics (eval_result ),
714- train_size = train_size ,
715- is_best = is_best ,
716- save_interval_checkpoint = save_interval_checkpoint ,
721+ if output_dir is not None and not np .isnan (primary_metric ):
722+ save_interval_checkpoint = (
723+ self .save_checkpoint_interval is not None
724+ and (epoch + 1 ) % self .save_checkpoint_interval == 0
717725 )
718726
719- if self . early_stopping and not np . isnan (primary_metric ):
720- if self . _is_improvement ( primary_metric , best_metric ):
721- best_metric = primary_metric
722- patience_counter = 0
723- best_model = copy . deepcopy ( self .finetuned_estimator_ )
724- else :
725- patience_counter += 1
726- logger . info (
727- "⚠️ No improvement for %s epochs. Best %s: %.4f" ,
728- patience_counter ,
729- self . _metric_name ,
730- best_metric ,
731- )
727+ is_best = self . _is_improvement (primary_metric , best_metric )
728+
729+ if save_interval_checkpoint or is_best :
730+ save_checkpoint (
731+ estimator = self .finetuned_estimator_ ,
732+ output_dir = output_dir ,
733+ epoch = epoch + 1 ,
734+ optimizer = optimizer ,
735+ metrics = self . _get_checkpoint_metrics ( eval_result ) ,
736+ train_size = train_size ,
737+ is_best = is_best ,
738+ save_interval_checkpoint = save_interval_checkpoint ,
739+ )
732740
733- if patience_counter >= self .early_stopping_patience :
734- logger .info (
735- "🛑 Early stopping triggered. Best %s: %.4f" ,
736- self ._metric_name ,
737- best_metric ,
738- )
739- if best_model is not None :
740- self .finetuned_estimator_ = best_model
741- break
741+ if self .early_stopping and not np .isnan (primary_metric ):
742+ if self ._is_improvement (primary_metric , best_metric ):
743+ best_metric = primary_metric
744+ patience_counter = 0
745+ best_model = copy .deepcopy (self .finetuned_estimator_ )
746+ else :
747+ patience_counter += 1
748+ logger .info (
749+ "⚠️ No improvement for %s epochs. Best %s: %.4f" ,
750+ patience_counter ,
751+ self ._metric_name ,
752+ best_metric ,
753+ )
754+
755+ if patience_counter >= self .early_stopping_patience :
756+ logger .info (
757+ "🛑 Early stopping triggered. Best %s: %.4f" ,
758+ self ._metric_name ,
759+ best_metric ,
760+ )
761+ if best_model is not None :
762+ self .finetuned_estimator_ = best_model
763+ break
742764
743765 if self .time_limit is not None :
744766 elapsed_time = time .monotonic () - start_time
0 commit comments