|
14 | 14 | from modalities.dataloader.create_packed_data import EmbeddedStreamData |
15 | 15 | from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader |
16 | 16 | from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper |
| 17 | +from modalities.utils.logger_utils import get_logger |
17 | 18 |
|
18 | 19 |
|
19 | 20 | class Dataset(TorchdataSet): |
@@ -445,20 +446,28 @@ class CombinedDataset(Dataset): |
445 | 446 | In the Dataloader, a batch will still contain packed samples from different datasets. |
446 | 447 | """ |
447 | 448 |
|
def __init__(self, datasets: list[Dataset], log_chunk_switch: bool = False, log_initial_pos: bool = False):
    """Builds one contiguous index space on top of several datasets.

    Args:
        datasets (list[Dataset]): The datasets to combine, in lookup order.
        log_chunk_switch (bool): If True, `__getitem__` logs every access that
            resolves to local index 0 of a dataset. Defaults to False.
        log_initial_pos (bool): If True, `__getitem__` logs the resolved
            position on every access. Defaults to False.
    """
    self.log_chunk_switch = log_chunk_switch
    self.log_initial_pos = log_initial_pos
    self.datasets = datasets
    # Running totals of the dataset lengths: entry i holds the number of
    # samples contained in datasets[0] .. datasets[i].
    sizes = [len(dataset) for dataset in datasets]
    self.cumulative_sizes = np.cumsum(sizes, dtype=np.int64)
    self.logger = get_logger(__name__)
456 | 460 |
|
def __len__(self) -> int:
    """Returns the total number of samples across all combined datasets.

    Returns:
        int: The last entry of the cumulative sizes as a built-in `int`,
            or 0 when no datasets were combined.
    """
    # An empty dataset list yields an empty cumulative array; indexing it
    # with [-1] would raise IndexError instead of reporting length 0.
    if len(self.cumulative_sizes) == 0:
        return 0
    # Convert the numpy scalar so the declared `-> int` return type holds.
    return int(self.cumulative_sizes[-1])
459 | 463 |
|
def __getitem__(self, idx: int) -> dict:
    """Returns the sample stored at a global index of the combined dataset.

    The global index is mapped to the dataset that contains it and to the
    position inside that dataset via the cumulative dataset sizes.

    Args:
        idx (int): Global sample index. Negative values count from the end
            of the combined dataset.

    Returns:
        dict: Whatever the selected underlying dataset returns for the
            resolved local index.

    Raises:
        IndexError: If idx lies outside [-len, len).
    """
    total = int(self.cumulative_sizes[-1]) if len(self.cumulative_sizes) > 0 else 0
    requested_idx = idx
    if idx < 0:
        # Without normalization a negative index resolves to dataset 0 with a
        # negative local index and silently returns the wrong sample.
        idx += total
    if not 0 <= idx < total:
        raise IndexError(f"Index {requested_idx} is out of range for combined dataset of size {total}.")

    # side="right" makes an index equal to a cumulative boundary fall into the
    # following dataset, which also skips over zero-length datasets.
    dataset_idx = int(np.searchsorted(self.cumulative_sizes, idx, side="right"))
    offset = self.cumulative_sizes[dataset_idx - 1] if dataset_idx > 0 else 0
    local_idx = int(idx - offset)

    # Both flags emit the same message, so log it once even when both apply.
    # NOTE(review): log_initial_pos logs on every access, not only the first
    # one - confirm this is the intended verbosity.
    if self.log_initial_pos or (self.log_chunk_switch and local_idx == 0):
        self.logger.info(f"global_index={idx} chunk index={dataset_idx}, local index={local_idx}")

    return self.datasets[dataset_idx][local_idx]
0 commit comments