Skip to content

Commit 4aa0d89

Browse files
fix: add macos spawn fallback for pytorch dataloader workers
1 parent 1b7f611 commit 4aa0d89

5 files changed

Lines changed: 85 additions & 8 deletions

File tree

Makefile

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,21 +71,23 @@ help:
7171
.PHONY: test-integration
7272
test-integration:
7373
@echo "🧪 Running staged pytest integration suite..."
74+
@echo "Using DATALOADER_WORKERS=$${DATALOADER_WORKERS:-0}"
7475
@if [ -n "$(INTEGRATION_RUN_ID)" ]; then \
7576
echo "Using integration run id: $(INTEGRATION_RUN_ID)"; \
76-
PLEXE_IT_RUN_ID="$(INTEGRATION_RUN_ID)" bash scripts/tests/run_integration_staged.sh; \
77+
DATALOADER_WORKERS="$${DATALOADER_WORKERS:-0}" PLEXE_IT_RUN_ID="$(INTEGRATION_RUN_ID)" bash scripts/tests/run_integration_staged.sh; \
7778
else \
78-
bash scripts/tests/run_integration_staged.sh; \
79+
DATALOADER_WORKERS="$${DATALOADER_WORKERS:-0}" bash scripts/tests/run_integration_staged.sh; \
7980
fi
8081

8182
.PHONY: test-integration-verbose
8283
test-integration-verbose:
8384
@echo "🧪 Running staged pytest integration suite (verbose)..."
85+
@echo "Using DATALOADER_WORKERS=$${DATALOADER_WORKERS:-0}"
8486
@if [ -n "$(INTEGRATION_RUN_ID)" ]; then \
8587
echo "Using integration run id: $(INTEGRATION_RUN_ID)"; \
86-
PLEXE_IT_RUN_ID="$(INTEGRATION_RUN_ID)" PLEXE_IT_VERBOSE=1 bash scripts/tests/run_integration_staged.sh; \
88+
DATALOADER_WORKERS="$${DATALOADER_WORKERS:-0}" PLEXE_IT_RUN_ID="$(INTEGRATION_RUN_ID)" PLEXE_IT_VERBOSE=1 bash scripts/tests/run_integration_staged.sh; \
8789
else \
88-
PLEXE_IT_VERBOSE=1 bash scripts/tests/run_integration_staged.sh; \
90+
DATALOADER_WORKERS="$${DATALOADER_WORKERS:-0}" PLEXE_IT_VERBOSE=1 bash scripts/tests/run_integration_staged.sh; \
8991
fi
9092

9193
# Fast sanity check - 1 iteration, minimal config

plexe/CODE_INDEX.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Code Index: plexe
22

3-
> Generated on 2026-03-02 21:24:57
3+
> Generated on 2026-03-02 21:25:06
44
55
Code structure and public interface documentation for the **plexe** package.
66

plexe/templates/training/train_pytorch.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import inspect
1313
import json
1414
import logging
15+
import multiprocessing as mp
1516
import os
1617
import sys
1718
from pathlib import Path
@@ -48,6 +49,27 @@ def _is_rank0(use_ddp: bool) -> bool:
4849
return dist.get_rank() == 0
4950

5051

52+
def _resolve_num_workers(requested_workers: int) -> int:
53+
"""Resolve safe DataLoader worker count for the current runtime."""
54+
if requested_workers <= 0:
55+
return 0
56+
57+
start_method = mp.get_start_method(allow_none=True)
58+
if start_method is None:
59+
start_method = mp.get_context().get_start_method()
60+
61+
if sys.platform == "darwin" and start_method == "spawn":
62+
logger.warning(
63+
"Falling back DataLoader workers from %s to 0 on platform=%s start_method=%s",
64+
requested_workers,
65+
sys.platform,
66+
start_method,
67+
)
68+
return 0
69+
70+
return requested_workers
71+
72+
5173
def train_pytorch(
5274
untrained_model_path: Path,
5375
train_uri: str,
@@ -151,16 +173,18 @@ def train_pytorch(
151173
train_dataset = ParquetIterableDataset(train_uri, target_column, task_type)
152174
val_dataset = ParquetIterableDataset(val_uri, target_column, task_type)
153175

176+
effective_num_workers = _resolve_num_workers(num_workers)
177+
154178
train_loader = torch.utils.data.DataLoader(
155179
train_dataset,
156180
batch_size=batch_size,
157-
num_workers=num_workers,
181+
num_workers=effective_num_workers,
158182
pin_memory=device.type == "cuda",
159183
)
160184
val_loader = torch.utils.data.DataLoader(
161185
val_dataset,
162186
batch_size=batch_size,
163-
num_workers=num_workers,
187+
num_workers=effective_num_workers,
164188
pin_memory=device.type == "cuda",
165189
)
166190

@@ -172,6 +196,7 @@ def train_pytorch(
172196
logger.info("Using ParquetIterableDataset for streaming data loading")
173197
logger.info(f"Training data: {train_rows} rows, {n_features} features (streaming)")
174198
logger.info(f"Validation data: {val_rows} rows (streaming)")
199+
logger.info(f"DataLoader workers: requested={num_workers}, effective={effective_num_workers}")
175200

176201
# ============================================
177202
# Step 6: Setup mixed precision

tests/CODE_INDEX.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Code Index: tests
22

3-
> Generated on 2026-03-02 21:24:57
3+
> Generated on 2026-03-02 21:25:06
44
55
Test suite structure and test case documentation.
66

@@ -147,6 +147,16 @@ Unit tests for pipeline_runner feature name resolution.
147147
- `test_resolve_feature_names_falls_back_on_mismatch()` - Returns generic names when resolved names don't match output count.
148148
- `test_resolve_feature_names_falls_back_when_unavailable()` - Returns generic names when no get_feature_names_out is available.
149149

150+
---
151+
## `unit/templates/training/test_train_pytorch_worker_fallback.py`
152+
Unit tests for PyTorch DataLoader worker fallback behavior.
153+
154+
**Functions:**
155+
- `test_resolve_num_workers_zero_is_unchanged() -> None` - Requested zero workers should remain zero.
156+
- `test_resolve_num_workers_falls_back_on_darwin_spawn(monkeypatch) -> None` - On macOS spawn, requested workers should fall back to zero.
157+
- `test_resolve_num_workers_uses_context_when_start_method_is_none(monkeypatch) -> None` - When get_start_method returns None, context start method should be used.
158+
- `test_resolve_num_workers_kept_on_non_darwin_spawn(monkeypatch) -> None` - Spawn on non-macOS should keep the requested worker count.
159+
150160
---
151161
## `unit/test_config.py`
152162
Unit tests for config helpers.
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""Unit tests for PyTorch DataLoader worker fallback behavior."""
2+
3+
import pytest
4+
5+
pytest.importorskip("torch")
6+
7+
from plexe.templates.training import train_pytorch
8+
9+
10+
def test_resolve_num_workers_zero_is_unchanged() -> None:
11+
"""Requested zero workers should remain zero."""
12+
assert train_pytorch._resolve_num_workers(0) == 0
13+
14+
15+
def test_resolve_num_workers_falls_back_on_darwin_spawn(monkeypatch) -> None:
16+
"""On macOS spawn, requested workers should fall back to zero."""
17+
monkeypatch.setattr(train_pytorch.sys, "platform", "darwin")
18+
monkeypatch.setattr(train_pytorch.mp, "get_start_method", lambda allow_none=True: "spawn")
19+
assert train_pytorch._resolve_num_workers(4) == 0
20+
21+
22+
def test_resolve_num_workers_uses_context_when_start_method_is_none(monkeypatch) -> None:
23+
"""When get_start_method returns None, context start method should be used."""
24+
25+
class _Context:
26+
@staticmethod
27+
def get_start_method() -> str:
28+
return "spawn"
29+
30+
monkeypatch.setattr(train_pytorch.sys, "platform", "darwin")
31+
monkeypatch.setattr(train_pytorch.mp, "get_start_method", lambda allow_none=True: None)
32+
monkeypatch.setattr(train_pytorch.mp, "get_context", lambda: _Context())
33+
assert train_pytorch._resolve_num_workers(2) == 0
34+
35+
36+
def test_resolve_num_workers_kept_on_non_darwin_spawn(monkeypatch) -> None:
37+
"""Spawn on non-macOS should keep the requested worker count."""
38+
monkeypatch.setattr(train_pytorch.sys, "platform", "linux")
39+
monkeypatch.setattr(train_pytorch.mp, "get_start_method", lambda allow_none=True: "spawn")
40+
assert train_pytorch._resolve_num_workers(3) == 3

0 commit comments

Comments
 (0)