Skip to content

Commit b1c0145

Browse files
committed
feat: add logging for chunk switches and initial positions in CombinedDataset
1 parent bf0e6d3 commit b1c0145

2 files changed

Lines changed: 12 additions & 1 deletion

File tree

src/modalities/config/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,8 @@ class PackedMemMapDatasetMegatronConfig(BaseModel):
448448

449449
class CombinedDatasetConfig(BaseModel):
    """Configuration for building a CombinedDataset from multiple datasets.

    Attributes:
        datasets: The list of datasets to combine into a single dataset.
        log_chunk_switch: If True, the CombinedDataset logs an access whose
            local index is 0, i.e. the first sample of a sub-dataset
            ("chunk"). Defaults to False.
        log_initial_pos: If True, the CombinedDataset logs the resolved
            (chunk index, local index) position of accesses. Defaults to
            False.
    """

    datasets: list[PydanticDatasetIFType]
    log_chunk_switch: bool = False
    log_initial_pos: bool = False
451453

452454

453455
class BatchSamplerConfig(BaseModel):

src/modalities/dataloader/dataset.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from modalities.dataloader.create_packed_data import EmbeddedStreamData
1515
from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader
1616
from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper
17+
from modalities.utils.logger_utils import get_logger
1718

1819

1920
class Dataset(TorchdataSet):
class CombinedDataset(Dataset):
    """Combines multiple datasets into a single dataset with a global index space.

    Global index 0 .. len(datasets[0]) - 1 maps to the first dataset, the next
    len(datasets[1]) indices to the second, and so on.
    In the Dataloader, a batch will still contain packed samples from different datasets.
    """

    def __init__(self, datasets: list[Dataset], log_chunk_switch: bool = False, log_initial_pos: bool = False):
        """Initializes the CombinedDataset object, combining multiple datasets.

        Args:
            datasets (list[Dataset]): A list of datasets to combine.
            log_chunk_switch (bool): If True, log accesses that land on the first
                sample (local index 0) of a sub-dataset. Defaults to False.
            log_initial_pos (bool): If True, log the resolved chunk/local position
                of accesses. Defaults to False.
        """
        self.log_chunk_switch = log_chunk_switch
        self.log_initial_pos = log_initial_pos
        self.datasets = datasets
        # Cumulative end offsets of each dataset within the global index space;
        # int64 avoids overflow for very large combined datasets.
        self.cumulative_sizes = np.cumsum([len(ds) for ds in datasets], dtype=np.int64)
        self.logger = get_logger(__name__)

    def __len__(self) -> int:
        # The last cumulative offset equals the total number of samples.
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx: int) -> dict:
        """Returns the sample at the given global index.

        Args:
            idx (int): Global index into the combined dataset.

        Returns:
            dict: The sample from the sub-dataset that owns this index.
        """
        # side="right" ensures idx == cumulative_sizes[k] resolves to dataset k+1,
        # since cumulative_sizes holds exclusive end offsets.
        dataset_idx = np.searchsorted(self.cumulative_sizes, idx, side="right")
        local_idx = idx - (self.cumulative_sizes[dataset_idx - 1] if dataset_idx > 0 else 0)
        if self.log_chunk_switch and local_idx == 0:
            # Lazy %-args: the message is only formatted if the log record is emitted,
            # which matters in this per-sample hot path.
            self.logger.info("global_index=%s chunk index=%s, local index=%s", idx, dataset_idx, local_idx)

        if self.log_initial_pos:
            # NOTE(review): this message is byte-identical to the chunk-switch log
            # above and fires on EVERY access, not only the initial position —
            # confirm whether that is the intended behavior.
            self.logger.info("global_index=%s chunk index=%s, local index=%s", idx, dataset_idx, local_idx)

        return self.datasets[dataset_idx][local_idx]

0 commit comments

Comments
 (0)