Skip to content

Commit e97578d

Browse files
committed
feat: introduced enforce_enough_tokens_in_dataset for enabling a check if the dataset provides enough tokens.
1 parent 14287e3 commit e97578d

1 file changed

Lines changed: 12 additions & 9 deletions

File tree

src/modalities/config/instantiation_models.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
import os
23
from pathlib import Path
34
from typing import Annotated, Any, Optional
@@ -27,6 +28,8 @@
2728
from modalities.util import warn_rank_0
2829
from modalities.utils.profilers.profilers import SteppableNoProfiler
2930

31+
logger = logging.getLogger(__name__)
32+
3033

3134
class CudaEnvSettings(BaseModel):
3235
local_rank: Annotated[int, Field(strict=True, ge=0)]
@@ -46,6 +49,7 @@ class ConsistencyEnforcement(BaseModel):
4649
enforce_last_step_logged: bool = True
4750
enforce_last_step_evaluated: bool = True
4851
enforce_last_step_checkpointed: bool = True
52+
enforce_enough_tokens_in_dataset: bool = True
4953

5054

5155
class Intervals(BaseModel):
@@ -192,15 +196,14 @@ def _check_last_step_checkpointed(self) -> "TrainingComponentsInstantiationModel
192196

193197
@model_validator(mode="after")
194198
def _check_token_amount_in_dataset(self) -> "TrainingComponentsInstantiationModel":
195-
if (
196-
len(self.train_dataset) * self.settings.step_profile.sequence_length
197-
< self.settings.training_target.num_target_tokens
198-
):
199-
raise ValueError(
200-
"Not enough tokens in the dataset. "
201-
f"Actual: {len(self.train_dataset) * self.settings.step_profile.sequence_length}, "
202-
f"Expected: >={self.settings.training_target.num_target_tokens}"
203-
)
199+
dataset_tokens = len(self.train_dataset) * self.settings.step_profile.sequence_length
200+
expected_tokens = self.settings.training_target.num_target_tokens
201+
if dataset_tokens < expected_tokens:
202+
msg = f"Not enough tokens in dataset. Actual: {dataset_tokens}, Expected: >={expected_tokens}"
203+
if self.settings.consistency_enforcement.enforce_enough_tokens_in_dataset:
204+
raise ValueError(msg)
205+
else:
206+
logger.warning(msg)
204207
return self
205208

206209

0 commit comments

Comments
 (0)