Skip to content

Commit d5032d0

Browse files
feat: add gpu-aware keras/pytorch training runtime (#183)
* feat: add gpu-aware keras/pytorch training runtime * fix: align launcher and epoch defaults * fix: correct multiclass label handling in streaming trainers * fix: add macos spawn fallback for pytorch dataloader workers * chore: bump version to 1.3.7 * chore: bump version to 1.4.0 * fix: tune keras early stopping and harden task fallback
1 parent b021b02 commit d5032d0

28 files changed

Lines changed: 1606 additions & 294 deletions

Dockerfile

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,47 @@
33
# base → shared dependencies (no Spark provider)
44
# pyspark → local PySpark execution (DEFAULT)
55
# databricks→ remote Databricks Connect execution
6+
# VARIANT=cpu (default) or VARIANT=gpu
67
#
78
# Usage:
89
# docker build . # default: pyspark
910
# docker build --target databricks . # databricks-connect
11+
# docker build --build-arg VARIANT=gpu . # GPU-enabled PySpark image (amd64 only)
12+
13+
# ============================================
14+
# Stage: base selection (cpu/gpu)
15+
# ============================================
16+
ARG PYTHON_VERSION=3.12
17+
ARG VARIANT=cpu
18+
19+
FROM python:${PYTHON_VERSION}-slim-bookworm AS base-cpu
20+
21+
FROM nvidia/cuda:12.9.0-runtime-ubuntu24.04 AS base-gpu
22+
ARG PYTHON_VERSION=3.12
23+
RUN apt-get update && apt-get install -y --no-install-recommends \
24+
software-properties-common \
25+
&& add-apt-repository ppa:deadsnakes/ppa \
26+
&& apt-get update && apt-get install -y --no-install-recommends \
27+
python${PYTHON_VERSION} \
28+
python${PYTHON_VERSION}-venv \
29+
python${PYTHON_VERSION}-dev \
30+
python3-pip \
31+
&& ln -sf /usr/bin/python${PYTHON_VERSION} /usr/bin/python3 \
32+
&& ln -sf /usr/bin/python3 /usr/bin/python \
33+
&& rm -rf /var/lib/apt/lists/* \
34+
&& rm -f /usr/lib/python${PYTHON_VERSION}/EXTERNALLY-MANAGED
1035

1136
# ============================================
1237
# Stage: base (shared across all variants)
1338
# ============================================
39+
FROM base-${VARIANT} AS base
40+
ARG TARGETARCH
41+
ARG VARIANT=cpu
1442
ARG PYTHON_VERSION=3.12
15-
FROM python:${PYTHON_VERSION}-slim-bookworm AS base
1643

44+
# System dependencies
1745
WORKDIR /code
1846

19-
# System dependencies
2047
RUN apt-get update && apt-get install -y \
2148
build-essential \
2249
gcc \
@@ -29,14 +56,17 @@ RUN curl https://sh.rustup.rs -sSf | bash -s -- -y
2956
ENV PATH="/root/.cargo/bin:${PATH}"
3057

3158
# Python tooling
32-
RUN pip install --no-cache-dir --upgrade pip && \
59+
RUN rm -rf /usr/lib/python3/dist-packages/*.dist-info 2>/dev/null; \
60+
pip install --no-cache-dir pip && \
3361
pip install --no-cache-dir poetry && \
3462
poetry config virtualenvs.create false
3563

3664
# Install large stable dependencies before poetry to maximize build cache reuse.
37-
# INSTALL_PYTORCH controls whether CPU-only PyTorch is installed.
65+
# INSTALL_PYTORCH controls whether PyTorch is installed.
3866
ARG INSTALL_PYTORCH="true"
39-
RUN if [ "$INSTALL_PYTORCH" = "true" ]; then \
67+
RUN if [ "$VARIANT" = "gpu" ] && [ "$INSTALL_PYTORCH" = "true" ]; then \
68+
pip install --no-cache-dir torch==2.7.1; \
69+
elif [ "$INSTALL_PYTORCH" = "true" ]; then \
4070
pip install --no-cache-dir torch==2.7.1 \
4171
--index-url https://download.pytorch.org/whl/cpu \
4272
--extra-index-url https://pypi.org/simple; \
@@ -101,6 +131,10 @@ RUN mkdir -p /opt/spark-jars && \
101131

102132
# Spark configuration for local mode
103133
ARG PYTHON_VERSION=3.12
134+
# GPU variant (Ubuntu) may install to dist-packages. Symlink ensures stable SPARK_HOME.
135+
RUN mkdir -p /usr/local/lib/python${PYTHON_VERSION}/site-packages && \
136+
ln -sf $(python3 -c "import pyspark; print(pyspark.__path__[0])") \
137+
/usr/local/lib/python${PYTHON_VERSION}/site-packages/pyspark 2>/dev/null || true
104138
ENV SPARK_HOME="/usr/local/lib/python${PYTHON_VERSION}/site-packages/pyspark"
105139
ENV PYSPARK_PYTHON="python3"
106140
ENV PYSPARK_DRIVER_PYTHON="python3"

Makefile

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ help:
4949
@echo ""
5050
@echo "🏗️ Building:"
5151
@echo " make build Build default image (PySpark)"
52+
@echo " make build-gpu Build GPU variant (CUDA + GPU PyTorch, amd64)"
5253
@echo " make build-databricks Build Databricks variant"
5354
@echo ""
5455
@echo "🧹 Cleanup:"
@@ -70,21 +71,23 @@ help:
7071
.PHONY: test-integration
7172
test-integration:
7273
@echo "🧪 Running staged pytest integration suite..."
74+
@echo "Using DATALOADER_WORKERS=$${DATALOADER_WORKERS:-0}"
7375
@if [ -n "$(INTEGRATION_RUN_ID)" ]; then \
7476
echo "Using integration run id: $(INTEGRATION_RUN_ID)"; \
75-
PLEXE_IT_RUN_ID="$(INTEGRATION_RUN_ID)" bash scripts/tests/run_integration_staged.sh; \
77+
DATALOADER_WORKERS="$${DATALOADER_WORKERS:-0}" PLEXE_IT_RUN_ID="$(INTEGRATION_RUN_ID)" bash scripts/tests/run_integration_staged.sh; \
7678
else \
77-
bash scripts/tests/run_integration_staged.sh; \
79+
DATALOADER_WORKERS="$${DATALOADER_WORKERS:-0}" bash scripts/tests/run_integration_staged.sh; \
7880
fi
7981

8082
.PHONY: test-integration-verbose
8183
test-integration-verbose:
8284
@echo "🧪 Running staged pytest integration suite (verbose)..."
85+
@echo "Using DATALOADER_WORKERS=$${DATALOADER_WORKERS:-0}"
8386
@if [ -n "$(INTEGRATION_RUN_ID)" ]; then \
8487
echo "Using integration run id: $(INTEGRATION_RUN_ID)"; \
85-
PLEXE_IT_RUN_ID="$(INTEGRATION_RUN_ID)" PLEXE_IT_VERBOSE=1 bash scripts/tests/run_integration_staged.sh; \
88+
DATALOADER_WORKERS="$${DATALOADER_WORKERS:-0}" PLEXE_IT_RUN_ID="$(INTEGRATION_RUN_ID)" PLEXE_IT_VERBOSE=1 bash scripts/tests/run_integration_staged.sh; \
8689
else \
87-
PLEXE_IT_VERBOSE=1 bash scripts/tests/run_integration_staged.sh; \
90+
DATALOADER_WORKERS="$${DATALOADER_WORKERS:-0}" PLEXE_IT_VERBOSE=1 bash scripts/tests/run_integration_staged.sh; \
8891
fi
8992

9093
# Fast sanity check - 1 iteration, minimal config
@@ -368,6 +371,17 @@ build:
368371
-f Dockerfile .
369372
@echo "✅ Build complete: plexe:py$(PYTHON_VERSION)"
370373

374+
# Build GPU variant (NVIDIA CUDA + CUDA-enabled PyTorch, amd64 only)
375+
.PHONY: build-gpu
376+
build-gpu:
377+
@echo "🏗️ Building GPU variant (Python $(PYTHON_VERSION), CUDA)..."
378+
docker buildx build --platform linux/amd64 --output type=docker --provenance=false \
379+
--build-arg PYTHON_VERSION=$(PYTHON_VERSION) \
380+
--build-arg VARIANT=gpu \
381+
-t plexe:py$(PYTHON_VERSION)-gpu \
382+
-f Dockerfile .
383+
@echo "✅ Build complete: plexe:py$(PYTHON_VERSION)-gpu"
384+
371385

372386
# Build Databricks variant
373387
.PHONY: build-databricks

config.yaml.template

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@
3434

3535
# Default epochs for neural network training (Keras, PyTorch)
3636
# Type: integer
37-
# Default: 25
38-
# nn_default_epochs: 25
37+
# Default: 10
38+
# nn_default_epochs: 10
3939

4040
# Maximum epochs for neural network training (Keras, PyTorch)
4141
# Type: integer

plexe/CODE_INDEX.md

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Code Index: plexe
22

3-
> Generated on 2026-03-02 19:57:53
3+
> Generated on 2026-03-02 22:03:39
44
55
Code structure and public interface documentation for the **plexe** package.
66

@@ -207,14 +207,14 @@ Local process runner - executes training in subprocess.
207207

208208
**`LocalProcessRunner`** - Runs training in local subprocess.
209209
- `__init__(self, work_dir: str)`
210-
- `run_training(self, template: str, model: Any, feature_pipeline: Pipeline, train_uri: str, val_uri: str, timeout: int, target_columns: list[str], optimizer: Any, loss: Any, epochs: int, batch_size: int, group_column: str | None) -> Path` - Execute training in subprocess.
210+
- `run_training(self, template: str, model: Any, feature_pipeline: Pipeline, train_uri: str, val_uri: str, timeout: int, target_columns: list[str], task_type: str, optimizer: Any, loss: Any, epochs: int, batch_size: int, group_column: str | None, mixed_precision: bool, dataloader_workers: int) -> Path` - Execute training in subprocess.
211211

212212
---
213213
## `execution/training/runner.py`
214214
Training runner abstract base class.
215215

216216
**`TrainingRunner`** - Abstract base class for training execution environments.
217-
- `run_training(self, template: str, model: Any, feature_pipeline: Pipeline, train_uri: str, val_uri: str, timeout: int, target_columns: list[str]) -> Path` - Execute model training and return path to artifacts.
217+
- `run_training(self, template: str, model: Any, feature_pipeline: Pipeline, train_uri: str, val_uri: str, timeout: int, target_columns: list[str], task_type: str) -> Path` - Execute model training and return path to artifacts.
218218

219219
---
220220
## `helpers.py`
@@ -312,6 +312,9 @@ Simple dataclasses for model building workflow.
312312

313313
**`DataLayout`** - Physical structure of dataset (not semantic meaning).
314314

315+
**`TaskType`** - Canonical ML task type determined during Phase 1.
316+
- `is_classification(self) -> bool` - No description
317+
315318
**`Metric`** - Evaluation metric definition.
316319

317320
**`BuildContext`** - Context passed through workflow phases.
@@ -452,6 +455,7 @@ Standard Keras predictor - NO Plexe dependencies.
452455
**`KerasPredictor`** - Standalone Keras predictor.
453456
- `__init__(self, model_dir: str)`
454457
- `predict(self, x: pd.DataFrame) -> pd.DataFrame` - Make predictions on input DataFrame.
458+
- `predict_proba(self, x: pd.DataFrame) -> pd.DataFrame` - Predict per-class probabilities on input DataFrame.
455459

456460
---
457461
## `templates/inference/lightgbm_predictor.py`
@@ -468,6 +472,7 @@ Standard PyTorch predictor - NO Plexe dependencies.
468472
**`PyTorchPredictor`** - Standalone PyTorch predictor.
469473
- `__init__(self, model_dir: str)`
470474
- `predict(self, x: pd.DataFrame) -> pd.DataFrame` - Make predictions on input DataFrame.
475+
- `predict_proba(self, x: pd.DataFrame) -> pd.DataFrame` - Predict per-class probabilities on input DataFrame.
471476

472477
---
473478
## `templates/inference/xgboost_predictor.py`
@@ -489,37 +494,37 @@ Model card template generator.
489494
Hardcoded robust CatBoost training loop.
490495

491496
**Functions:**
492-
- `train_catboost(untrained_model_path: Path, train_uri: str, val_uri: str, output_dir: Path, target_column: str) -> dict` - Train CatBoost model directly (no Spark).
497+
- `train_catboost(untrained_model_path: Path, train_uri: str, val_uri: str, output_dir: Path, target_column: str, task_type: str | None) -> dict` - Train CatBoost model directly (no Spark).
493498
- `main()` - No description
494499

495500
---
496501
## `templates/training/train_keras.py`
497-
Hardcoded robust Keras training loop.
502+
Keras training template with streaming data loading, multi-GPU (MirroredStrategy), and mixed precision.
498503

499504
**Functions:**
500-
- `train_keras(untrained_model_path: Path, train_uri: str, val_uri: str, output_dir: Path, target_column: str, epochs: int, batch_size: int) -> dict` - Train Keras model directly.
505+
- `train_keras(untrained_model_path: Path, train_uri: str, val_uri: str, output_dir: Path, target_column: str, epochs: int, batch_size: int, use_multi_gpu: bool, use_mixed_precision: bool, task_type: str | None) -> dict` - Train Keras model with streaming data, optional multi-GPU, and mixed precision.
501506

502507
---
503508
## `templates/training/train_lightgbm.py`
504509
Hardcoded robust LightGBM training loop.
505510

506511
**Functions:**
507-
- `train_lightgbm(untrained_model_path: Path, train_uri: str, val_uri: str, output_dir: Path, target_column: str, group_column: str | None) -> dict` - Train LightGBM model directly (no Spark).
512+
- `train_lightgbm(untrained_model_path: Path, train_uri: str, val_uri: str, output_dir: Path, target_column: str, group_column: str | None, task_type: str | None) -> dict` - Train LightGBM model directly (no Spark).
508513
- `main()` - No description
509514

510515
---
511516
## `templates/training/train_pytorch.py`
512-
Hardcoded robust PyTorch training loop.
517+
PyTorch training template with streaming data loading, multi-GPU (DDP), and mixed precision.
513518

514519
**Functions:**
515-
- `train_pytorch(untrained_model_path: Path, train_uri: str, val_uri: str, output_dir: Path, target_column: str, epochs: int, batch_size: int) -> dict` - Train PyTorch model directly.
520+
- `train_pytorch(untrained_model_path: Path, train_uri: str, val_uri: str, output_dir: Path, target_column: str, epochs: int, batch_size: int, num_workers: int, use_ddp: bool, use_mixed_precision: bool, task_type: str | None) -> dict` - Train PyTorch model with streaming data, optional DDP, and mixed precision.
516521

517522
---
518523
## `templates/training/train_xgboost.py`
519524
Hardcoded robust XGBoost training loop.
520525

521526
**Functions:**
522-
- `train_xgboost(untrained_model_path: Path, train_uri: str, val_uri: str, output_dir: Path, target_column: str, group_column: str | None) -> dict` - Train XGBoost model directly (no Spark).
527+
- `train_xgboost(untrained_model_path: Path, train_uri: str, val_uri: str, output_dir: Path, target_column: str, group_column: str | None, task_type: str | None) -> dict` - Train XGBoost model directly (no Spark).
523528
- `main()` - No description
524529

525530
---
@@ -624,7 +629,7 @@ Utility functions for dashboard data loading.
624629
- `load_report(exp_path: Path, report_name: str) -> dict | None` - Load YAML report from DirNames.BUILD_DIR/reports/.
625630
- `load_code_file(file_path: Path) -> str | None` - Load Python code file.
626631
- `load_parquet_sample(uri: str, limit: int) -> pd.DataFrame | None` - Load first N rows from parquet file.
627-
- `get_parquet_row_count(uri: str) -> int | None` - Get row count from parquet file.
632+
- `get_parquet_row_count(uri: str) -> int | None` - Get row count from parquet metadata without reading data.
628633
- `load_json_file(file_path: Path) -> dict | None` - Load JSON file.
629634

630635
---
@@ -636,6 +641,21 @@ LiteLLM model wrapper with retry logic and optional post-call hook.
636641
- `generate(self)` - Generate with automatic retries, header injection, and post-call hook.
637642
- `chat(self)` - Chat with automatic retries, header injection, and post-call hook.
638643

644+
---
645+
## `utils/parquet_dataset.py`
646+
Streaming parquet data loading utilities for large-dataset training.
647+
648+
**`ParquetIterableDataset`** - Streaming parquet dataset for PyTorch DataLoader.
649+
- `__init__(self, uri: str, target_column: str, task_type: str)`
650+
- `total_rows(self) -> int` - No description
651+
652+
**Functions:**
653+
- `get_parquet_row_count(uri: str) -> int` - Get total row count from parquet metadata without reading data.
654+
- `get_dataset_size_bytes(uri: str) -> int` - Get dataset size in bytes for a local file or directory of parquet files.
655+
- `parquet_batch_generator(uri: str, target_column: str, batch_size: int, task_type: str | None) -> Iterator[tuple[np.ndarray, np.ndarray]]` - Streaming parquet batch generator for Keras/TensorFlow.
656+
- `get_parquet_feature_count(uri: str, target_column: str) -> int` - Get number of feature columns (total columns minus target).
657+
- `get_steps_per_epoch(uri: str, batch_size: int) -> int` - Compute number of steps per epoch for a parquet dataset.
658+
639659
---
640660
## `utils/reporting.py`
641661
Utilities for saving agent reports to disk.

plexe/config.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,12 +309,21 @@ class Config(BaseSettings):
309309
# Training settings
310310
training_timeout: int = Field(default=1800, description="Timeout for training runs (seconds)", gt=0)
311311
nn_default_epochs: int = Field(
312-
default=25, description="Default epochs for neural network training (Keras, PyTorch)"
312+
default=10, description="Default epochs for neural network training (Keras, PyTorch)"
313313
)
314314
nn_max_epochs: int = Field(default=50, description="Maximum epochs for neural network training (Keras, PyTorch)")
315315
nn_default_batch_size: int = Field(
316316
default=32, description="Default batch size for neural network training (Keras, PyTorch)"
317317
)
318+
nn_training_timeout: int = Field(
319+
default=14400, description="Timeout for neural network training on full dataset (seconds)", gt=0
320+
)
321+
mixed_precision: bool = Field(
322+
default=True, description="Use mixed precision (FP16) when GPU available (auto-disabled on CPU)"
323+
)
324+
dataloader_workers: int = Field(
325+
default=4, description="Number of DataLoader worker processes for streaming data loading", ge=0
326+
)
318327

319328
# LLM settings (per agent role)
320329
statistical_analysis_llm: str = Field(

0 commit comments

Comments
 (0)