Skip to content

Commit 9cc31e9

Browse files
committed
refactor(dockerfile): native per-version GPU base on nvidia/cuda
Replace the runpod/pytorch + side-by-side install hack with a native per-version GPU base built directly on nvidia/cuda. Each image variant has exactly one Python interpreter at /usr/local/bin/python (3.10 from upstream jammy, 3.11/3.12/3.13 from deadsnakes), with torch installed natively for that interpreter from the cu128 wheel index. Eliminates the ~7 GB cold-start tax on non-3.12 images and decouples flash-worker from runpod/pytorch's Python release cadence. Adding 3.13 (or future 3.14/3.15) is now a CI matrix entry, not an upstream wait. Refs AE-2827.
1 parent 736e0d6 commit 9cc31e9

1 file changed

Lines changed: 54 additions & 27 deletions

File tree

Dockerfile

Lines changed: 54 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,24 @@
1010
FROM runpod/pytorch:1.0.3-cu1281-torch291-ubuntu2204
1111

1212
# Target Python version for the worker runtime.
13+
# Native per-version GPU base. One Python interpreter per image, installed
14+
# directly into /usr/local/bin/python. No side-by-side, no symlink dance,
15+
# no 7 GB cold-start tax.
16+
#
17+
# - nvidia/cuda:12.8.1-cudnn-runtime-ubuntu22.04 provides the CUDA + cuDNN
18+
# runtime libraries needed by torch's cu128 wheels.
19+
# - On jammy (22.04), python3.10 ships from upstream Ubuntu (system Python);
20+
# python3.11/3.12/3.13 come from the deadsnakes PPA. The same apt-get
21+
# invocation below resolves both sources transparently.
22+
# - pip is bootstrapped via get-pip.py (urllib stdlib): the Ubuntu system
23+
# python3.10 has ensurepip disabled by Debian policy, and deadsnakes
24+
# interpreters do not ship pip by default. get-pip.py works for any
25+
# interpreter regardless of distro patching.
1326
ARG PYTHON_VERSION=3.12
1427
ARG TORCH_VERSION=2.9.1+cu128
1528
ARG TORCH_INDEX_URL=https://download.pytorch.org/whl/cu128
1629

17-
# Expose the target version to the running worker for startup validation.
18-
ENV FLASH_PYTHON_VERSION=${PYTHON_VERSION}
30+
FROM nvidia/cuda:12.8.1-cudnn-runtime-ubuntu22.04
1931

2032
# Validate the base image provides the requested interpreter and activate it.
2133
# For non-3.12 targets, install torch for the selected Python and repoint
@@ -42,49 +54,64 @@ RUN python${PYTHON_VERSION} --version \
4254
&& ln -sf "$(which python${PYTHON_VERSION})" /usr/local/bin/python \
4355
&& ln -sf "$(which python${PYTHON_VERSION})" /usr/local/bin/python3; \
4456
fi
57+
# Re-declare ARGs after FROM so they're visible in this build stage.
58+
ARG PYTHON_VERSION
59+
ARG TORCH_VERSION
60+
ARG TORCH_INDEX_URL
4561

46-
WORKDIR /app
47-
48-
# Prevent interactive prompts during package installation
62+
ENV FLASH_PYTHON_VERSION=${PYTHON_VERSION}
4963
ENV DEBIAN_FRONTEND=noninteractive
50-
# Set timezone to avoid tzdata prompts
5164
ENV TZ=Etc/UTC
52-
53-
# Enable HuggingFace transfer acceleration
5465
ENV HF_HUB_ENABLE_HF_TRANSFER=1
55-
# Relocate HuggingFace cache outside /root/.cache to exclude from volume sync
5666
ENV HF_HOME=/hf-cache
5767

58-
# Configure APT cache to persist under /root/.cache for volume sync
68+
# Install ONE Python natively. 3.10 from upstream Ubuntu (jammy ships it as
69+
# system Python); 3.11/3.12/3.13 from deadsnakes.
70+
RUN apt-get update \
71+
&& apt-get install -y --no-install-recommends \
72+
software-properties-common ca-certificates curl gnupg \
73+
&& add-apt-repository -y ppa:deadsnakes/ppa \
74+
&& apt-get update \
75+
&& apt-get install -y --no-install-recommends \
76+
python${PYTHON_VERSION} \
77+
python${PYTHON_VERSION}-venv \
78+
python${PYTHON_VERSION}-dev \
79+
git \
80+
&& ln -sf "$(which python${PYTHON_VERSION})" /usr/local/bin/python \
81+
&& ln -sf "$(which python${PYTHON_VERSION})" /usr/local/bin/python3 \
82+
&& apt-get clean && rm -rf /var/lib/apt/lists/*
83+
84+
# Bootstrap pip via get-pip.py.
85+
RUN python -c "import urllib.request; urllib.request.urlretrieve('https://bootstrap.pypa.io/get-pip.py', '/tmp/get-pip.py')" \
86+
&& python /tmp/get-pip.py --no-cache-dir \
87+
&& rm -f /tmp/get-pip.py
88+
89+
# Install torch natively for the active interpreter.
90+
RUN python -m pip install --no-cache-dir \
91+
--index-url ${TORCH_INDEX_URL} \
92+
"torch==${TORCH_VERSION}"
93+
94+
WORKDIR /app
95+
96+
# Configure APT cache to persist under /root/.cache for volume sync.
5997
RUN mkdir -p /root/.cache/apt/archives/partial \
6098
&& echo 'Dir::Cache "/root/.cache/apt";' > /etc/apt/apt.conf.d/01cache
6199

62-
# Install system dependencies and uv
63-
# Note: build-essential not pre-installed to reduce image size (400MB savings)
64-
# Automatic detection will install it when needed (no manual action required)
65-
# Advanced: Users can pre-install via system_dependencies=["build-essential"]
66-
RUN DEBIAN_FRONTEND=noninteractive apt-get update && apt-get install -y --no-install-recommends \
67-
curl ca-certificates git \
68-
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
100+
# Install uv for downstream dependency installation.
101+
RUN curl -LsSf https://astral.sh/uv/install.sh | sh \
69102
&& cp ~/.local/bin/uv /usr/local/bin/uv \
70-
&& chmod +x /usr/local/bin/uv \
71-
&& apt-get clean \
72-
&& rm -rf /var/lib/apt/lists/*
103+
&& chmod +x /usr/local/bin/uv
73104

74-
# Copy app code and install dependencies
75-
# Use --python to target the active interpreter (preserves torch in its site-packages)
105+
# Copy app code and install worker dependencies into the active interpreter.
76106
COPY README.md pyproject.toml uv.lock ./
77107
COPY src/ ./
78108
RUN uv export --format requirements-txt --no-dev --no-hashes > requirements.txt \
79109
&& uv pip install --python $(which python) --break-system-packages -r requirements.txt
80110

81-
# Install numpy for the active Python version.
82-
# The runpod/pytorch image ships torch but not numpy. Flash build excludes numpy
83-
# from tarballs (BASE_IMAGE_PACKAGES) to save tarball space (~30 MB), so numpy
84-
# must be provided here in the base image.
111+
# Install numpy for the active Python (excluded from flash tarballs).
85112
RUN python -m pip install --no-cache-dir numpy
86113

87-
# Verify torch, numpy, and the expected Python version are available.
114+
# Verify torch, numpy, and the expected interpreter are wired correctly.
88115
RUN python -c "import sys; actual = f'{sys.version_info.major}.{sys.version_info.minor}'; expected = '${PYTHON_VERSION}'; assert actual == expected, f'Expected Python {expected}, got {actual}'; print(f'Python {actual} OK')" \
89116
&& python -c "import torch; print(f'torch {torch.__version__} CUDA {torch.cuda.is_available()}')" \
90117
&& python -c "import numpy; print(f'numpy {numpy.__version__}')"

0 commit comments

Comments
 (0)