Skip to content

Commit 2aed428

Browse files
author
Alex J Lennon
committed
Add --num-workers override to avoid DataLoader worker crashes (use 0 on ROCm)
Made-with: Cursor
1 parent 9109e96 commit 2aed428

File tree

2 files changed

+32
-7
lines changed

2 files changed: +32 additions, -7 deletions

training/modules/data_pipeline.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,8 @@ def create_data_loaders(config: TrainingConfiguration,
699699
data_root: Optional[str] = None,
700700
pin_memory: Optional[bool] = None,
701701
interactive: bool = True,
702-
no_download: bool = False) -> Tuple[DataLoader, DataLoader, DataLoader]:
702+
no_download: bool = False,
703+
num_workers_override: Optional[int] = None) -> Tuple[DataLoader, DataLoader, DataLoader]:
703704
"""
704705
Create training, validation, and test data loaders
705706
@@ -709,12 +710,14 @@ def create_data_loaders(config: TrainingConfiguration,
709710
pin_memory: Override pin_memory setting (optional)
710711
interactive: If False, auto-confirm dataset download/prepare (--yes)
711712
no_download: If True, use only existing prepared data; fail if any missing (--no-download)
713+
num_workers_override: If set, override config.hardware.num_workers (e.g. 0 for ROCm)
712714
713715
Returns:
714716
Tuple of (train_loader, val_loader, test_loader)
715717
"""
716718
# Use provided pin_memory override or fall back to config
717719
use_pin_memory = pin_memory if pin_memory is not None else config.hardware.pin_memory
720+
num_workers = num_workers_override if num_workers_override is not None else config.hardware.num_workers
718721

719722
# Create datasets
720723
train_datasets = []
@@ -756,7 +759,7 @@ def create_data_loaders(config: TrainingConfiguration,
756759
train_dataset,
757760
batch_size=config.training.batch_size,
758761
shuffle=True,
759-
num_workers=config.hardware.num_workers,
762+
num_workers=num_workers,
760763
pin_memory=use_pin_memory,
761764
collate_fn=collate_audio_samples,
762765
drop_last=True # Ensure consistent batch sizes
@@ -766,7 +769,7 @@ def create_data_loaders(config: TrainingConfiguration,
766769
val_dataset,
767770
batch_size=config.training.batch_size,
768771
shuffle=False,
769-
num_workers=config.hardware.num_workers,
772+
num_workers=num_workers,
770773
pin_memory=use_pin_memory,
771774
collate_fn=collate_audio_samples,
772775
drop_last=False
@@ -776,7 +779,7 @@ def create_data_loaders(config: TrainingConfiguration,
776779
test_dataset,
777780
batch_size=config.training.batch_size,
778781
shuffle=False,
779-
num_workers=config.hardware.num_workers,
782+
num_workers=num_workers,
780783
pin_memory=use_pin_memory,
781784
collate_fn=collate_audio_samples,
782785
drop_last=False

training/train.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class TCNTrainer:
4949
clear error handling and progress monitoring.
5050
"""
5151

52-
def __init__(self, config: TrainingConfiguration, resume_from: Optional[str] = None, data_root: Optional[str] = None, interactive: bool = True, no_download: bool = False):
52+
def __init__(self, config: TrainingConfiguration, resume_from: Optional[str] = None, data_root: Optional[str] = None, interactive: bool = True, no_download: bool = False, num_workers_override: Optional[int] = None):
5353
"""
5454
Initialize trainer with configuration
5555
@@ -59,6 +59,7 @@ def __init__(self, config: TrainingConfiguration, resume_from: Optional[str] = N
5959
data_root: Root directory for data (optional)
6060
interactive: If False, auto-confirm dataset download/prepare (--yes)
6161
no_download: If True, use only existing prepared data; fail if any dataset missing (--no-download)
62+
num_workers_override: If set, override config.hardware.num_workers for DataLoader (e.g. 0 for ROCm)
6263
"""
6364
self.config = config
6465

@@ -91,7 +92,13 @@ def __init__(self, config: TrainingConfiguration, resume_from: Optional[str] = N
9192
# Create data loaders
9293
self.logger.info("Setting up data pipeline...")
9394
self.train_loader, self.val_loader, self.test_loader = create_data_loaders(
94-
config, data_root=data_root, pin_memory=self.pin_memory, interactive=interactive, no_download=no_download)
95+
config,
96+
data_root=data_root,
97+
pin_memory=self.pin_memory,
98+
interactive=interactive,
99+
no_download=no_download,
100+
num_workers_override=num_workers_override,
101+
)
95102

96103
# Create model
97104
self.logger.info("Creating TCN model...")
@@ -632,6 +639,14 @@ def parse_arguments():
632639
help='Use only existing prepared data; exit with error if any requested dataset is missing or not prepared (no download or MFA)'
633640
)
634641

642+
parser.add_argument(
643+
'--num-workers',
644+
type=int,
645+
default=None,
646+
metavar='N',
647+
help='Override num_workers for DataLoader (e.g. 0 to avoid worker crashes on ROCm)'
648+
)
649+
635650
return parser.parse_args()
636651

637652

@@ -655,7 +670,14 @@ def main():
655670

656671
try:
657672
# Create trainer
658-
trainer = TCNTrainer(config, resume_from=args.resume, data_root=args.data_root, interactive=not args.yes, no_download=args.no_download)
673+
trainer = TCNTrainer(
674+
config,
675+
resume_from=args.resume,
676+
data_root=args.data_root,
677+
interactive=not args.yes,
678+
no_download=args.no_download,
679+
num_workers_override=args.num_workers,
680+
)
659681

660682
if args.test_only:
661683
# Test only mode

0 commit comments

Comments (0)