@@ -699,7 +699,8 @@ def create_data_loaders(config: TrainingConfiguration,
699699 data_root : Optional [str ] = None ,
700700 pin_memory : Optional [bool ] = None ,
701701 interactive : bool = True ,
702- no_download : bool = False ) -> Tuple [DataLoader , DataLoader , DataLoader ]:
702+ no_download : bool = False ,
703+ num_workers_override : Optional [int ] = None ) -> Tuple [DataLoader , DataLoader , DataLoader ]:
703704 """
704705 Create training, validation, and test data loaders
705706
@@ -709,12 +710,14 @@ def create_data_loaders(config: TrainingConfiguration,
709710 pin_memory: Override pin_memory setting (optional)
710711 interactive: If False, auto-confirm dataset download/prepare (--yes)
711712 no_download: If True, use only existing prepared data; fail if any missing (--no-download)
713+ num_workers_override: If set, override config.hardware.num_workers (e.g. 0 for ROCm)
712714
713715 Returns:
714716 Tuple of (train_loader, val_loader, test_loader)
715717 """
716718 # Use provided pin_memory override or fall back to config
717719 use_pin_memory = pin_memory if pin_memory is not None else config .hardware .pin_memory
720+ num_workers = num_workers_override if num_workers_override is not None else config .hardware .num_workers
718721
719722 # Create datasets
720723 train_datasets = []
@@ -756,7 +759,7 @@ def create_data_loaders(config: TrainingConfiguration,
756759 train_dataset ,
757760 batch_size = config .training .batch_size ,
758761 shuffle = True ,
759- num_workers = config . hardware . num_workers ,
762+ num_workers = num_workers ,
760763 pin_memory = use_pin_memory ,
761764 collate_fn = collate_audio_samples ,
762765 drop_last = True # Ensure consistent batch sizes
@@ -766,7 +769,7 @@ def create_data_loaders(config: TrainingConfiguration,
766769 val_dataset ,
767770 batch_size = config .training .batch_size ,
768771 shuffle = False ,
769- num_workers = config . hardware . num_workers ,
772+ num_workers = num_workers ,
770773 pin_memory = use_pin_memory ,
771774 collate_fn = collate_audio_samples ,
772775 drop_last = False
@@ -776,7 +779,7 @@ def create_data_loaders(config: TrainingConfiguration,
776779 test_dataset ,
777780 batch_size = config .training .batch_size ,
778781 shuffle = False ,
779- num_workers = config . hardware . num_workers ,
782+ num_workers = num_workers ,
780783 pin_memory = use_pin_memory ,
781784 collate_fn = collate_audio_samples ,
782785 drop_last = False
0 commit comments