Skip to content

Commit 1daec88

Browse files
committed
feat: add process injection support via LiveServerlessMixin
Replace pre-built Docker images with runtime tarball injection. The LiveServerlessMixin now generates dockerArgs that download, extract, and bootstrap the flash-worker tarball at container start time. - Add injection.py with build_injection_cmd() for dockerArgs generation - Add base image constants (FLASH_GPU_BASE_IMAGE, FLASH_CPU_BASE_IMAGE) - Update LiveServerlessMixin to configure dockerArgs on templates - Add _default_base_image and _legacy_image properties to all Live* classes - Update tests for injection-based template configuration - Revert InjectableWorkerMixin rename back to LiveServerlessMixin
1 parent 269181d commit 1daec88

8 files changed

Lines changed: 299 additions & 145 deletions

File tree

src/runpod_flash/core/resources/constants.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,20 @@ def get_image_name(
142142
f"runpod/flash-lb-cpu:py{DEFAULT_PYTHON_VERSION}-{_RESOLVED_TAG}",
143143
)
144144

145+
# Base images for process injection (no flash-worker baked in)
146+
FLASH_GPU_BASE_IMAGE = os.environ.get(
147+
"FLASH_GPU_BASE_IMAGE", "pytorch/pytorch:2.9.1-cuda12.8-cudnn9-runtime"
148+
)
149+
FLASH_CPU_BASE_IMAGE = os.environ.get("FLASH_CPU_BASE_IMAGE", "python:3.11-slim")
150+
151+
# Worker tarball for process injection
152+
FLASH_WORKER_VERSION = os.environ.get("FLASH_WORKER_VERSION", "1.1.1")
153+
FLASH_WORKER_TARBALL_URL_TEMPLATE = os.environ.get(
154+
"FLASH_WORKER_TARBALL_URL",
155+
"https://github.com/runpod/flash-worker/releases/download/"
156+
"v{version}/flash-worker-v{version}-py3.11-linux-x86_64.tar.gz",
157+
)
158+
145159
# Worker configuration defaults
146160
DEFAULT_WORKERS_MIN = 0
147161
DEFAULT_WORKERS_MAX = 1
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
"""Process injection utilities for flash-worker tarball delivery."""
2+
3+
from .constants import FLASH_WORKER_TARBALL_URL_TEMPLATE, FLASH_WORKER_VERSION
4+
5+
6+
def build_injection_cmd(
7+
worker_version: str = FLASH_WORKER_VERSION,
8+
tarball_url: str | None = None,
9+
) -> str:
10+
"""Build the dockerArgs command that downloads, extracts, and runs flash-worker.
11+
12+
Supports remote URLs (curl/wget) and local file paths (file://) for testing.
13+
Includes version-based caching to skip re-extraction on warm workers.
14+
Network volume caching stores extracted tarball at /runpod-volume/.flash-worker/v{version}.
15+
"""
16+
if tarball_url is None:
17+
tarball_url = FLASH_WORKER_TARBALL_URL_TEMPLATE.format(version=worker_version)
18+
19+
if tarball_url.startswith("file://"):
20+
local_path = tarball_url[7:]
21+
return (
22+
"bash -c '"
23+
"set -e; FW_DIR=/opt/flash-worker; "
24+
"mkdir -p $FW_DIR; "
25+
f"tar xzf {local_path} -C $FW_DIR --strip-components=1; "
26+
"exec $FW_DIR/bootstrap.sh'"
27+
)
28+
29+
return (
30+
"bash -c '"
31+
f"set -e; FW_DIR=/opt/flash-worker; FW_VER={worker_version}; "
32+
# Network volume cache check
33+
'NV_CACHE="/runpod-volume/.flash-worker/v$FW_VER"; '
34+
'if [ -d "$NV_CACHE" ] && [ -f "$NV_CACHE/.version" ]; then '
35+
'cp -r "$NV_CACHE" "$FW_DIR"; '
36+
# Local cache check (container disk persistence between restarts)
37+
'elif [ -f "$FW_DIR/.version" ] && [ "$(cat $FW_DIR/.version)" = "$FW_VER" ]; then '
38+
"true; "
39+
"else "
40+
"mkdir -p $FW_DIR; "
41+
f'DL_URL="{tarball_url}"; '
42+
'(command -v curl >/dev/null 2>&1 && curl -sSL "$DL_URL" || wget -qO- "$DL_URL") '
43+
"| tar xz -C $FW_DIR --strip-components=1; "
44+
# Cache to network volume if available
45+
"if [ -d /runpod-volume ]; then "
46+
'mkdir -p "$NV_CACHE" && cp -r "$FW_DIR"/* "$NV_CACHE/" 2>/dev/null || true; fi; '
47+
"fi; "
48+
"exec $FW_DIR/bootstrap.sh'"
49+
)

src/runpod_flash/core/resources/live_serverless.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,24 @@
88
get_image_name,
99
local_python_version,
1010
)
11+
from .injection import build_injection_cmd
1112
from .load_balancer_sls_resource import (
1213
CpuLoadBalancerSlsResource,
1314
LoadBalancerSlsResource,
1415
)
1516
from .serverless import ServerlessEndpoint
1617
from .serverless_cpu import CpuServerlessEndpoint
18+
from .template import PodTemplate
1719

1820

1921
class LiveServerlessMixin:
20-
"""Common mixin for live serverless endpoints that locks the image."""
22+
"""Configures process injection via dockerArgs for any base image.
23+
24+
Sets a default base image (user can override via imageName) and generates
25+
dockerArgs to download, extract, and run the flash-worker tarball at container
26+
start time. QB vs LB mode is determined by FLASH_ENDPOINT_TYPE env var at
27+
runtime, not by the Docker image.
28+
"""
2129

2230
_image_type: ClassVar[str] = (
2331
"" # Override in subclasses: 'gpu', 'cpu', 'lb', 'lb-cpu'
@@ -42,6 +50,18 @@ def imageName(self):
4250
def imageName(self, value):
4351
pass
4452

53+
def _create_new_template(self) -> PodTemplate:
54+
"""Create template with dockerArgs for process injection."""
55+
template = super()._create_new_template() # type: ignore[misc]
56+
template.dockerArgs = build_injection_cmd()
57+
return template
58+
59+
def _configure_existing_template(self) -> None:
60+
"""Configure existing template, adding dockerArgs for injection if not user-set."""
61+
super()._configure_existing_template() # type: ignore[misc]
62+
if self.template is not None and not self.template.dockerArgs: # type: ignore[attr-defined]
63+
self.template.dockerArgs = build_injection_cmd() # type: ignore[attr-defined]
64+
4565

4666
class LiveServerless(LiveServerlessMixin, ServerlessEndpoint):
4767
"""GPU-only live serverless endpoint."""

tests/integration/test_cpu_disk_sizing.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,11 @@ def test_live_serverless_cpu_integration(self):
125125
)
126126

127127
# Verify integration:
128-
# 1. Uses CPU image (locked)
128+
# 1. Uses CPU base image (default)
129129
# 2. CPU utilities calculate minimum disk size
130130
# 3. Template creation with auto-sizing
131131
# 4. Validation passes
132-
assert "flash-cpu:" in live_serverless.imageName
132+
assert live_serverless.imageName == "python:3.12-slim"
133133
assert live_serverless.instanceIds == [
134134
CpuInstanceType.CPU5C_1_2,
135135
CpuInstanceType.CPU5C_2_4,
@@ -244,28 +244,24 @@ def test_mixed_cpu_generations_integration(self):
244244
assert "cpu5c-1-2: max 15GB" in error_msg
245245

246246

247-
class TestLiveServerlessImageLockingIntegration:
248-
"""Test image locking integration in live serverless variants."""
247+
class TestLiveServerlessImageDefaultsIntegration:
248+
"""Test image defaults in live serverless variants."""
249249

250-
def test_live_serverless_image_consistency(self):
251-
"""Test that LiveServerless variants maintain image consistency."""
250+
def test_live_serverless_image_defaults(self):
251+
"""Test that LiveServerless variants use correct base images."""
252252
gpu_live = LiveServerless(name="gpu-live")
253253
cpu_live = CpuLiveServerless(name="cpu-live")
254254

255-
# Verify different images are used
255+
# Verify different base images are used
256256
assert gpu_live.imageName != cpu_live.imageName
257-
assert "flash:" in gpu_live.imageName
258-
assert "flash-cpu:" in cpu_live.imageName
257+
assert "pytorch" in gpu_live.imageName
258+
assert "python" in cpu_live.imageName
259259

260-
# Verify images remain locked despite attempts to change
261-
original_gpu_image = gpu_live.imageName
262-
original_cpu_image = cpu_live.imageName
263-
264-
gpu_live.imageName = "custom/image:latest"
265-
cpu_live.imageName = "custom/image:latest"
266-
267-
assert gpu_live.imageName == original_gpu_image
268-
assert cpu_live.imageName == original_cpu_image
260+
# Verify images can be overridden (BYOI)
261+
custom_gpu = LiveServerless(
262+
name="custom-gpu", imageName="nvidia/cuda:12.8.0-runtime"
263+
)
264+
assert custom_gpu.imageName == "nvidia/cuda:12.8.0-runtime"
269265

270266
def test_live_serverless_template_integration(self):
271267
"""Test live serverless template integration with disk sizing."""

tests/integration/test_lb_remote_execution.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -114,22 +114,18 @@ async def echo(message: str):
114114
# Verify resource is correctly configured
115115
# Note: name may have "-fb" appended by flash boot validator
116116
assert "test-live-api" in lb.name
117-
assert "flash-lb" in lb.imageName
117+
assert "pytorch" in lb.imageName # GPU base image
118118
assert echo.__remote_config__["method"] == "POST"
119119

120-
def test_live_load_balancer_image_locked(self):
121-
"""Test that LiveLoadBalancer locks the image to Flash LB image."""
120+
def test_live_load_balancer_default_image(self):
121+
"""Test that LiveLoadBalancer uses GPU base image by default."""
122122
lb = LiveLoadBalancer(name="test-api")
123+
assert "pytorch" in lb.imageName
123124

124-
# Verify image is locked and cannot be overridden
125-
original_image = lb.imageName
126-
assert "flash-lb" in original_image
127-
128-
# Try to set a different image (should be ignored due to property)
129-
lb.imageName = "custom-image:latest"
130-
131-
# Image should still be locked to Flash
132-
assert lb.imageName == original_image
125+
def test_live_load_balancer_allows_custom_image(self):
126+
"""Test that LiveLoadBalancer allows user to set custom image (BYOI)."""
127+
lb = LiveLoadBalancer(name="test-api", imageName="custom-image:latest")
128+
assert lb.imageName == "custom-image:latest"
133129

134130
def test_load_balancer_vs_queue_based_endpoints(self):
135131
"""Test that LB and QB endpoints have different characteristics."""
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""Unit tests for process injection utilities."""
2+
3+
from runpod_flash.core.resources.injection import build_injection_cmd
4+
5+
6+
class TestBuildInjectionCmd:
7+
"""Test build_injection_cmd() output format."""
8+
9+
def test_default_remote_url(self):
10+
"""Test default remote URL generation."""
11+
cmd = build_injection_cmd(worker_version="1.1.1")
12+
13+
assert cmd.startswith("bash -c '")
14+
assert "FW_VER=1.1.1" in cmd
15+
assert "flash-worker/releases/download/v1.1.1/" in cmd
16+
assert "bootstrap.sh'" in cmd
17+
18+
def test_custom_tarball_url(self):
19+
"""Test custom tarball URL."""
20+
url = "https://example.com/worker.tar.gz"
21+
cmd = build_injection_cmd(worker_version="2.0.0", tarball_url=url)
22+
23+
assert "FW_VER=2.0.0" in cmd
24+
assert url in cmd
25+
26+
def test_file_url_for_local_testing(self):
27+
"""Test file:// URL generates local extraction command."""
28+
cmd = build_injection_cmd(
29+
worker_version="1.0.0",
30+
tarball_url="file:///tmp/flash-worker.tar.gz",
31+
)
32+
33+
assert "tar xzf /tmp/flash-worker.tar.gz" in cmd
34+
assert "curl" not in cmd
35+
assert "wget" not in cmd
36+
assert "bootstrap.sh'" in cmd
37+
38+
def test_version_caching_logic(self):
39+
"""Test that version-based cache check is included."""
40+
cmd = build_injection_cmd(worker_version="1.1.1")
41+
42+
# Should check .version file
43+
assert ".version" in cmd
44+
assert "FW_VER" in cmd
45+
46+
def test_network_volume_caching(self):
47+
"""Test network volume cache path is included."""
48+
cmd = build_injection_cmd(worker_version="1.1.1")
49+
50+
assert "/runpod-volume/.flash-worker/" in cmd
51+
assert "NV_CACHE" in cmd
52+
53+
def test_curl_wget_fallback(self):
54+
"""Test curl/wget fallback logic."""
55+
cmd = build_injection_cmd(worker_version="1.0.0")
56+
57+
assert "curl -sSL" in cmd
58+
assert "wget -qO-" in cmd
59+
60+
def test_default_uses_constants(self):
61+
"""Test that calling with no args uses module-level constants."""
62+
from runpod_flash.core.resources.constants import FLASH_WORKER_VERSION
63+
64+
cmd = build_injection_cmd()
65+
66+
assert f"FW_VER={FLASH_WORKER_VERSION}" in cmd
67+
assert f"v{FLASH_WORKER_VERSION}" in cmd
68+
69+
def test_strip_components_in_remote_extraction(self):
70+
"""Test tar uses --strip-components=1 for remote downloads."""
71+
cmd = build_injection_cmd(worker_version="1.0.0")
72+
73+
assert "--strip-components=1" in cmd
74+
75+
def test_strip_components_in_local_extraction(self):
76+
"""Test tar uses --strip-components=1 for local file extraction."""
77+
cmd = build_injection_cmd(
78+
worker_version="1.0.0",
79+
tarball_url="file:///tmp/fw.tar.gz",
80+
)
81+
82+
assert "--strip-components=1" in cmd

0 commit comments

Comments
 (0)