1+ import logging
12import os
23from pathlib import Path
34from typing import Annotated , Any , Optional
2728from modalities .util import warn_rank_0
2829from modalities .utils .profilers .profilers import SteppableNoProfiler
2930
31+ logger = logging .getLogger (__name__ )
32+
3033
3134class CudaEnvSettings (BaseModel ):
3235 local_rank : Annotated [int , Field (strict = True , ge = 0 )]
@@ -46,6 +49,7 @@ class ConsistencyEnforcement(BaseModel):
4649 enforce_last_step_logged : bool = True
4750 enforce_last_step_evaluated : bool = True
4851 enforce_last_step_checkpointed : bool = True
52+ enforce_enough_tokens_in_dataset : bool = True
4953
5054
5155class Intervals (BaseModel ):
@@ -192,15 +196,14 @@ def _check_last_step_checkpointed(self) -> "TrainingComponentsInstantiationModel
192196
193197 @model_validator (mode = "after" )
194198 def _check_token_amount_in_dataset (self ) -> "TrainingComponentsInstantiationModel" :
195- if (
196- len (self .train_dataset ) * self .settings .step_profile .sequence_length
197- < self .settings .training_target .num_target_tokens
198- ):
199- raise ValueError (
200- "Not enough tokens in the dataset. "
201- f"Actual: { len (self .train_dataset ) * self .settings .step_profile .sequence_length } , "
202- f"Expected: >={ self .settings .training_target .num_target_tokens } "
203- )
199+ dataset_tokens = len (self .train_dataset ) * self .settings .step_profile .sequence_length
200+ expected_tokens = self .settings .training_target .num_target_tokens
201+ if dataset_tokens < expected_tokens :
202+ msg = f"Not enough tokens in dataset. Actual: { dataset_tokens } , Expected: >={ expected_tokens } "
203+ if self .settings .consistency_enforcement .enforce_enough_tokens_in_dataset :
204+ raise ValueError (msg )
205+ else :
206+ logger .warning (msg )
204207 return self
205208
206209
0 commit comments