Skip to content

Commit 4e927f5

Browse files
Merge pull request #2857 from AI-Hypercomputer:ci_image
PiperOrigin-RevId: 846918748
2 parents 6abcea0 + eddf738 commit 4e927f5

2 files changed

Lines changed: 67 additions & 104 deletions

File tree

.github/workflows/UploadDockerImages.yml

Lines changed: 62 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
16-
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
15+
# This workflow builds and pushes MaxText images for both TPU and GPU devices.
16+
# It runs automatically daily at 12am UTC, on Pull Requests, or manually via Workflow Dispatch.
1717

1818
name: Build Images
1919

@@ -34,14 +34,8 @@ on:
3434
- gpu
3535

3636
jobs:
37-
build-tpu:
38-
# This job will only run for 'tpu', 'all', schedule, or PR triggers.
39-
if: >
40-
github.event_name == 'schedule' ||
41-
github.event_name == 'pull_request' ||
42-
github.event.inputs.target_device == 'all' ||
43-
github.event.inputs.target_device == 'tpu'
44-
37+
build:
38+
name: Build ${{ matrix.device }}-${{ matrix.build_mode }} Image
4539
runs-on: linux-x86-n2-16-buildkit
4640
container: google/cloud-sdk:524.0.0
4741

@@ -51,127 +45,91 @@ jobs:
5145
matrix:
5246
include:
5347
# TPU Image Builds
54-
- image_name: maxtext_jax_stable
48+
- device: tpu
49+
build_mode: stable
50+
image_name: maxtext_jax_stable
5551
dockerfile: ./dependencies/dockerfiles/maxtext_dependencies.Dockerfile
56-
build_args: |
57-
MODE=stable
58-
JAX_VERSION=NONE
59-
LIBTPU_GCS_PATH=NONE
60-
- image_name: maxtext_jax_nightly
52+
- device: tpu
53+
build_mode: nightly
54+
image_name: maxtext_jax_nightly
6155
dockerfile: ./dependencies/dockerfiles/maxtext_dependencies.Dockerfile
62-
build_args: |
63-
MODE=nightly
64-
JAX_VERSION=NONE
65-
LIBTPU_GCS_PATH=NONE
66-
# TPU Image builds using JAX AI Image
67-
- image_name: maxtext_jax_stable_stack
68-
dockerfile: ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile
69-
base_image: us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
70-
- image_name: maxtext_stable_stack_nightly_jax
71-
dockerfile: ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile
72-
base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu/jax_nightly:latest
73-
- image_name: maxtext_stable_stack_candidate
74-
dockerfile: ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile
75-
base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest
76-
77-
# Setup for GKE runners per b/412986220#comment82 and b/412986220#comment90
78-
steps:
79-
- uses: actions/checkout@v5
80-
- name: Mark git repository as safe
81-
run: git config --global --add safe.directory ${GITHUB_WORKSPACE}
82-
- name: Configure Docker
83-
run: gcloud auth configure-docker us-docker.pkg.dev,gcr.io -q
84-
- name: Set up Docker BuildX
85-
uses: docker/setup-buildx-action@v3.11.1
86-
with:
87-
driver: remote
88-
endpoint: tcp://localhost:1234
89-
# Env variables to be passed to Dockerfile
90-
- name: Get short commit hash
91-
id: vars
92-
run: echo "sha_short=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
93-
- name: Get current date
94-
id: date
95-
run: echo "image_date=$(date +%Y-%m-%d)" >> $GITHUB_OUTPUT
96-
# Docker BuildX command config
97-
- name: Build and Push Docker Image
98-
uses: docker/build-push-action@v6
99-
with:
100-
push: true
101-
context: .
102-
file: ${{ matrix.dockerfile }}
103-
tags: |
104-
gcr.io/tpu-prod-env-multipod/${{ matrix.image_name }}:${{ steps.date.outputs.image_date }}
105-
gcr.io/tpu-prod-env-multipod/${{ matrix.image_name }}:latest
106-
cache-from: type=gha
107-
cache-to: type=gha,mode=max
108-
provenance: false
109-
build-args: |
110-
${{ matrix.build_args }}
111-
JAX_AI_IMAGE_BASEIMAGE=${{ matrix.base_image }}
112-
COMMIT_HASH=${{ steps.vars.outputs.sha_short }}
113-
DEVICE=tpu
114-
TEST_TYPE=xlml
56+
# GPU Image Builds
57+
- device: gpu
58+
build_mode: stable
59+
image_name: maxtext_gpu_jax_stable
60+
dockerfile: ./dependencies/dockerfiles/maxtext_gpu_dependencies.Dockerfile
61+
- device: gpu
62+
build_mode: nightly
63+
image_name: maxtext_gpu_jax_nightly
64+
dockerfile: ./dependencies/dockerfiles/maxtext_gpu_dependencies.Dockerfile
11565

116-
# Same as tpu-build step but mirrored for GPUs
117-
build-gpu:
11866
if: >
11967
github.event_name == 'schedule' ||
12068
github.event_name == 'pull_request' ||
121-
github.event.inputs.target_device == 'all' ||
122-
github.event.inputs.target_device == 'gpu'
69+
github.event_name == 'workflow_dispatch' && (
70+
github.event.inputs.target_device == 'all' ||
71+
github.event.inputs.target_device == 'tpu' ||
72+
github.event.inputs.target_device == 'gpu'
73+
)
12374
124-
runs-on: linux-x86-n2-16-buildkit
125-
container: google/cloud-sdk:524.0.0
75+
# Setup for GKE runners per b/412986220#comment82 and b/412986220#comment90
76+
steps:
77+
- name: Check if build should run
78+
id: check
79+
shell: bash
80+
run: |
81+
if [[ "${{ github.event_name }}" == "workflow_dispatch" && "${{ github.event.inputs.target_device }}" != "all" && "${{ github.event.inputs.target_device }}" != "${{ matrix.device }}" ]]; then
82+
echo "should_run=false" >> $GITHUB_OUTPUT
83+
echo "Skipping build for device: ${{ matrix.device }} in ${{ matrix.build_mode }} mode."
84+
else
85+
echo "should_run=true" >> $GITHUB_OUTPUT
86+
echo "Building for device: ${{ matrix.device }} in ${{ matrix.build_mode }} mode."
87+
fi
12688
127-
strategy:
128-
fail-fast: false
129-
matrix:
130-
# GPU Image Builds using JAX AI Image
131-
include:
132-
- image_name: maxtext_gpu_jax_stable_stack
133-
dockerfile: ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile
134-
base_image: us-central1-docker.pkg.dev/deeplearning-images/jax-ai-image/gpu:latest
135-
- image_name: maxtext_gpu_stable_stack_nightly_jax
136-
dockerfile: ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile
137-
base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/gpu/jax_nightly:latest
138-
- image_name: maxtext_stable_stack_candidate_gpu
139-
dockerfile: ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile
140-
base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:latest
89+
- name: Checkout git repository
90+
uses: actions/checkout@v5
91+
if: steps.check.outputs.should_run == 'true'
14192

142-
steps:
143-
- uses: actions/checkout@v5
14493
- name: Mark git repository as safe
94+
if: steps.check.outputs.should_run == 'true'
14595
run: git config --global --add safe.directory ${GITHUB_WORKSPACE}
96+
14697
- name: Configure Docker
147-
run: gcloud auth configure-docker us-docker.pkg.dev,gcr.io,us-central1-docker.pkg.dev -q
98+
if: steps.check.outputs.should_run == 'true'
99+
run: gcloud auth configure-docker us-docker.pkg.dev,gcr.io -q
100+
148101
- name: Set up Docker BuildX
149102
uses: docker/setup-buildx-action@v3.11.1
103+
if: steps.check.outputs.should_run == 'true'
150104
with:
151105
driver: remote
152106
endpoint: tcp://localhost:1234
153-
- name: Get short commit hash
107+
108+
# Env variables to be passed to Dockerfile
109+
- name: Get metadata
154110
id: vars
155-
run: echo "sha_short=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
156-
- name: Get current date
157-
id: date
158-
run: echo "image_date=$(date +%Y-%m-%d)" >> $GITHUB_OUTPUT
111+
if: steps.check.outputs.should_run == 'true'
112+
run: |
113+
echo "commit_hash=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
114+
echo "image_date=$(date +%Y-%m-%d)" >> $GITHUB_OUTPUT
115+
116+
# Docker BuildX command config
159117
- name: Build and Push Docker Image
160118
uses: docker/build-push-action@v6
119+
if: steps.check.outputs.should_run == 'true'
161120
with:
162121
push: true
163122
context: .
164123
file: ${{ matrix.dockerfile }}
165124
tags: |
166-
gcr.io/tpu-prod-env-multipod/${{ matrix.image_name }}:maxtext_${{ steps.vars.outputs.sha_short }}
167-
gcr.io/tpu-prod-env-multipod/${{ matrix.image_name }}:${{ steps.date.outputs.image_date }}
125+
gcr.io/tpu-prod-env-multipod/${{ matrix.image_name }}:maxtext_${{ steps.vars.outputs.commit_hash }}
126+
gcr.io/tpu-prod-env-multipod/${{ matrix.image_name }}:${{ steps.vars.outputs.image_date }}
168127
gcr.io/tpu-prod-env-multipod/${{ matrix.image_name }}:latest
169128
cache-from: type=gha
170129
cache-to: type=gha,mode=max
171130
provenance: false
172131
build-args: |
173-
${{ matrix.build_args }}
174-
JAX_AI_IMAGE_BASEIMAGE=${{ matrix.base_image }}
175-
COMMIT_HASH=${{ steps.vars.outputs.sha_short }}
176-
DEVICE=gpu
177-
TEST_TYPE=xlml
132+
DEVICE=${{ matrix.device }}
133+
MODE=${{ matrix.build_mode }}
134+
JAX_VERSION=NONE
135+
LIBTPU_GCS_PATH=NONE

dependencies/dockerfiles/maxtext_gpu_dependencies.Dockerfile

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
ARG BASEIMAGE=ghcr.io/nvidia/jax:base
33
FROM $BASEIMAGE
44

5+
# Move the 'EXTERNALLY-MANAGED' file to allow system-wide pip installs
6+
RUN if [ -f /usr/lib/python3.12/EXTERNALLY-MANAGED ]; then \
7+
mv /usr/lib/python3.12/EXTERNALLY-MANAGED /usr/lib/python3.12/EXTERNALLY-MANAGED.old; \
8+
fi
9+
510
# Stopgaps measure to circumvent gpg key setup issue.
611
RUN echo "deb [trusted=yes] https://developer.download.nvidia.com/devtools/repos/ubuntu2204/amd64/ /" > /etc/apt/sources.list.d/devtools-ubuntu2204-amd64.list
712

0 commit comments

Comments
 (0)