Skip to content

Commit 2a32a62

Browse files
committed
Merge pull request #1353 from rohitc33:rohit-colocated-2
GitOrigin-RevId: efecde8
2 parents c30522c + e3ac888 commit 2a32a62

11 files changed

Lines changed: 702 additions & 61 deletions

File tree

Dockerfile

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
ARG TARGET=base
44
ARG BASE_IMAGE=ubuntu:22.04
5+
ARG BASE_IMAGE_COLOCATED=us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/sidecar:2025_10_29-python_3.10-jax_0.6.2
56

67
FROM ${BASE_IMAGE} AS base
78

@@ -102,11 +103,39 @@ COPY pyproject.toml README.md /root/
102103
RUN uv pip install -qq --prerelease=allow .[core,tpu] && uv cache clean
103104
RUN if [ -n "$EXTRAS" ]; then uv pip install -qq .[$EXTRAS] && uv cache clean; fi
104105
RUN if [ "$INSTALL_PATHWAYS_JAXLIB" = "true" ]; then \
105-
uv pip install --prerelease=allow "jaxlib==0.5.3.dev20250918" \
106+
uv pip install --prerelease=allow "jaxlib==0.6.2.dev20251021" \
106107
--find-links https://storage.googleapis.com/axlearn-wheels/wheels.html; \
107108
fi
108109
COPY . .
109110

111+
################################################################################
112+
# Colocated Python container spec. #
113+
################################################################################
114+
115+
FROM ${BASE_IMAGE_COLOCATED} AS colocated-python
116+
117+
WORKDIR /app
118+
COPY . .
119+
120+
# Install the additional user-provided dependencies, strictly enforcing the rules
121+
# from the base image's constraints file.
122+
RUN \
123+
# 1. Install user-provided dependencies with modified constraints
124+
grep -v "^numpy" /opt/venv/server_constraints.txt | grep -v "^scipy" > /tmp/modified_constraints.txt && \
125+
echo "--> Installing user-provided dependencies..." && \
126+
uv pip install ".[core,gcp]" -c /tmp/modified_constraints.txt && \
127+
\
128+
# 2. Override numpy and scipy with specific versions
129+
uv pip install numpy==2.1.1 scipy==1.15.3 && \
130+
\
131+
# 3. Verify that the colocated_python_cpu_client is present.
132+
echo "--> Verifying JAX patch integrity..." && \
133+
python -c "from jax._src.lib import _jax; _jax.colocated_python_cpu_client" && \
134+
echo "--> JAX patch verification successful." && \
135+
\
136+
# 4. Clean the cache to keep the image slim.
137+
uv cache clean
138+
110139
################################################################################
111140
# GPU container spec. #
112141
################################################################################

axlearn/cloud/common/bundler.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,8 @@ class Config(Bundler.Config):
301301
cache_from: Optional[Sequence[str]] = None
302302
# Skip the build + push step (e.g., using a pre-built image).
303303
skip_bundle: bool = False
304+
# Sidecar names to build images for.
305+
sidecars: list[str] = []
304306

305307
def __init__(self, cfg: Config):
306308
super().__init__(cfg)
@@ -337,6 +339,7 @@ def from_spec(cls, spec: list[str], *, fv: Optional[flags.FlagValues]) -> Config
337339
- platform: The image target platform.
338340
- allow_dirty: Whether to ignore dirty git status.
339341
- cache_from: A comma-separated list of cache sources.
342+
- sidecars: A comma-separated list of sidecar names.
340343
- skip_bundle: Whether to skip the build + push. This option is intended to be used when an
341344
image has already been pre-built offline, in which case we may still want to leverage
342345
the install commands implemented by the bundler.
@@ -346,13 +349,15 @@ def from_spec(cls, spec: list[str], *, fv: Optional[flags.FlagValues]) -> Config
346349
cfg: BaseDockerBundler.Config = super().from_spec(spec, fv=fv)
347350
kwargs = parse_kv_flags(spec, delimiter="=")
348351
cache_from = canonicalize_to_list(kwargs.pop("cache_from", None))
352+
sidecars = canonicalize_to_list(kwargs.pop("sidecars", None))
349353
skip_bundle = to_bool(kwargs.pop("skip_bundle", False))
350354
allow_dirty = to_bool(kwargs.pop("allow_dirty", False))
351355
# Non-config specs are treated as build args.
352356
build_args = {k: kwargs.pop(k) for k in list(kwargs.keys()) if k not in cfg}
353357
return cfg.set(
354358
build_args=build_args,
355359
cache_from=cache_from,
360+
sidecars=sidecars,
356361
skip_bundle=skip_bundle,
357362
allow_dirty=allow_dirty,
358363
**kwargs,
@@ -485,6 +490,16 @@ def _build_and_push(
485490
labels: dict[str, str],
486491
) -> str:
487492
cfg: DockerBundler.Config = self.config
493+
494+
_, tag = image.rsplit(":", maxsplit=1)
495+
for sidecar in cfg.sidecars:
496+
sidecar_bundler = cfg.set(
497+
image=sidecar,
498+
target=sidecar,
499+
sidecars=[],
500+
).instantiate()
501+
sidecar_bundler.bundle(tag=tag)
502+
488503
return docker_push(
489504
docker_build(
490505
dockerfile=dockerfile,

axlearn/cloud/gcp/bundler.py

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -129,14 +129,14 @@ class Config(BaseDockerBundler.Config):
129129
from flags.
130130
is_async: Whether to build asynchronously. If True, callers should invoke
131131
`wait_until_finished()` to wait for bundling to complete.
132+
private_worker_pool: If provided, should be the identifier of a private worker pool.
133+
See: https://cloud.google.com/build/docs/private-pools/private-pools-overview
132134
"""
133135

134136
# GCP project.
135137
project: Required[str] = REQUIRED
136138
# Build image asynchronously.
137139
is_async: bool = True
138-
# If provided, should be the identifier of a private worker pool.
139-
# See: https://cloud.google.com/build/docs/private-pools/private-pools-overview
140140
private_worker_pool: Optional[str] = None
141141

142142
@classmethod
@@ -175,9 +175,14 @@ def _build_and_push(
175175
)
176176
image_path, image_tag = image.rsplit(":", maxsplit=1)
177177
latest_tag = f"{image_path}:latest"
178-
cloudbuild_yaml = f"""
179-
steps:
180-
- name: "gcr.io/cloud-builders/docker"
178+
179+
# Build steps - start with main image
180+
build_steps = []
181+
images_list = [f'"{image}"', f'"{latest_tag}"']
182+
183+
# Main image build step
184+
build_steps.append(
185+
f"""- name: "gcr.io/cloud-builders/docker"
181186
args: [
182187
"build",
183188
"-f", "{os.path.relpath(dockerfile, context)}",
@@ -193,11 +198,44 @@ def _build_and_push(
193198
"."
194199
]
195200
env:
196-
- "DOCKER_BUILDKIT=1"
201+
- "DOCKER_BUILDKIT=1\""""
202+
)
203+
204+
# Add sidecar image build steps
205+
for sidecar in cfg.sidecars:
206+
sidecar_target = sidecar
207+
sidecar_image_path = f"{cfg.repo}/{sidecar}"
208+
sidecar_image = f"{sidecar_image_path}:{image_tag}"
209+
sidecar_latest_image = f"{sidecar_image_path}:latest"
210+
211+
build_steps.append(
212+
f"""- name: "gcr.io/cloud-builders/docker"
213+
args: [
214+
"build",
215+
"-f", "{os.path.relpath(dockerfile, context)}",
216+
"-t", "{sidecar_image}",
217+
"-t", "{sidecar_latest_image}",
218+
"--target", "{sidecar_target}",
219+
"--cache-from", "{sidecar_image}",
220+
"--cache-from", "{sidecar_latest_image}",
221+
{cache_from}
222+
{build_platform}
223+
{build_args}
224+
{labels}
225+
"."
226+
]
227+
env:
228+
- "DOCKER_BUILDKIT=1\""""
229+
)
230+
231+
images_list.extend([f'"{sidecar_image}"', f'"{sidecar_latest_image}"'])
232+
233+
cloudbuild_yaml = f"""
234+
steps:
235+
{chr(10).join(build_steps)}
197236
timeout: 3600s
198237
images:
199-
- "{image}"
200-
- "{latest_tag}"
238+
{chr(10).join([f"- {img}" for img in images_list])}
201239
tags: [{image_tag}]
202240
options:
203241
logging: CLOUD_LOGGING_ONLY

0 commit comments

Comments
 (0)