1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
16- # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
15+ # This workflow builds and pushes MaxText images for both TPU and GPU devices.
16+ # It runs automatically daily at 12am UTC, on Pull Requests, or manually via Workflow Dispatch.
1717
1818name : Build Images
1919
3434 - gpu
3535
3636jobs :
37- build-tpu :
38- # This job will only run for 'tpu', 'all', schedule, or PR triggers.
39- if : >
40- github.event_name == 'schedule' ||
41- github.event_name == 'pull_request' ||
42- github.event.inputs.target_device == 'all' ||
43- github.event.inputs.target_device == 'tpu'
44-
37+ build :
38+ name : Build ${{ matrix.device }}-${{ matrix.build_mode }} Image
4539 runs-on : linux-x86-n2-16-buildkit
4640 container : google/cloud-sdk:524.0.0
4741
@@ -51,127 +45,91 @@ jobs:
5145 matrix :
5246 include :
5347 # TPU Image Builds
54- - image_name : maxtext_jax_stable
48+ - device : tpu
49+ build_mode : stable
50+ image_name : maxtext_jax_stable
5551 dockerfile : ./dependencies/dockerfiles/maxtext_dependencies.Dockerfile
56- build_args : |
57- MODE=stable
58- JAX_VERSION=NONE
59- LIBTPU_GCS_PATH=NONE
60- - image_name : maxtext_jax_nightly
52+ - device : tpu
53+ build_mode : nightly
54+ image_name : maxtext_jax_nightly
6155 dockerfile : ./dependencies/dockerfiles/maxtext_dependencies.Dockerfile
62- build_args : |
63- MODE=nightly
64- JAX_VERSION=NONE
65- LIBTPU_GCS_PATH=NONE
66- # TPU Image builds using JAX AI Image
67- - image_name : maxtext_jax_stable_stack
68- dockerfile : ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile
69- base_image : us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
70- - image_name : maxtext_stable_stack_nightly_jax
71- dockerfile : ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile
72- base_image : us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu/jax_nightly:latest
73- - image_name : maxtext_stable_stack_candidate
74- dockerfile : ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile
75- base_image : us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest
76-
77- # Setup for GKE runners per b/412986220#comment82 and b/412986220#comment90
78- steps :
79- - uses : actions/checkout@v5
80- - name : Mark git repository as safe
81- run : git config --global --add safe.directory ${GITHUB_WORKSPACE}
82- - name : Configure Docker
83- run : gcloud auth configure-docker us-docker.pkg.dev,gcr.io -q
84- - name : Set up Docker BuildX
85- uses : docker/setup-buildx-action@v3.11.1
86- with :
87- driver : remote
88- endpoint : tcp://localhost:1234
89- # Env variables to be passed to Dockerfile
90- - name : Get short commit hash
91- id : vars
92- run : echo "sha_short=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
93- - name : Get current date
94- id : date
95- run : echo "image_date=$(date +%Y-%m-%d)" >> $GITHUB_OUTPUT
96- # Docker BuildX command config
97- - name : Build and Push Docker Image
98- uses : docker/build-push-action@v6
99- with :
100- push : true
101- context : .
102- file : ${{ matrix.dockerfile }}
103- tags : |
104- gcr.io/tpu-prod-env-multipod/${{ matrix.image_name }}:${{ steps.date.outputs.image_date }}
105- gcr.io/tpu-prod-env-multipod/${{ matrix.image_name }}:latest
106- cache-from : type=gha
107- cache-to : type=gha,mode=max
108- provenance : false
109- build-args : |
110- ${{ matrix.build_args }}
111- JAX_AI_IMAGE_BASEIMAGE=${{ matrix.base_image }}
112- COMMIT_HASH=${{ steps.vars.outputs.sha_short }}
113- DEVICE=tpu
114- TEST_TYPE=xlml
56+ # GPU Image Builds
57+ - device : gpu
58+ build_mode : stable
59+ image_name : maxtext_gpu_jax_stable
60+ dockerfile : ./dependencies/dockerfiles/maxtext_gpu_dependencies.Dockerfile
61+ - device : gpu
62+ build_mode : nightly
63+ image_name : maxtext_gpu_jax_nightly
64+ dockerfile : ./dependencies/dockerfiles/maxtext_gpu_dependencies.Dockerfile
11565
116- # Same as tpu-build step but mirrored for GPUs
117- build-gpu :
11866 if : >
11967 github.event_name == 'schedule' ||
12068 github.event_name == 'pull_request' ||
121- github.event.inputs.target_device == 'all' ||
122- github.event.inputs.target_device == 'gpu'
69+ github.event_name == 'workflow_dispatch' && (
70+ github.event.inputs.target_device == 'all' ||
71+ github.event.inputs.target_device == 'tpu' ||
72+ github.event.inputs.target_device == 'gpu'
73+ )
12374
124- runs-on : linux-x86-n2-16-buildkit
125- container : google/cloud-sdk:524.0.0
75+ # Setup for GKE runners per b/412986220#comment82 and b/412986220#comment90
76+ steps :
77+ - name : Check if build should run
78+ id : check
79+ shell : bash
80+ run : |
81+ if [[ "${{ github.event_name }}" == "workflow_dispatch" && "${{ github.event.inputs.target_device }}" != "all" && "${{ github.event.inputs.target_device }}" != "${{ matrix.device }}" ]]; then
82+ echo "should_run=false" >> $GITHUB_OUTPUT
83+ echo "Skipping build for device: ${{ matrix.device }} in ${{ matrix.build_mode }} mode."
84+ else
85+ echo "should_run=true" >> $GITHUB_OUTPUT
86+ echo "Building for device: ${{ matrix.device }} in ${{ matrix.build_mode }} mode."
87+ fi
12688
127- strategy :
128- fail-fast : false
129- matrix :
130- # GPU Image Builds using JAX AI Image
131- include :
132- - image_name : maxtext_gpu_jax_stable_stack
133- dockerfile : ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile
134- base_image : us-central1-docker.pkg.dev/deeplearning-images/jax-ai-image/gpu:latest
135- - image_name : maxtext_gpu_stable_stack_nightly_jax
136- dockerfile : ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile
137- base_image : us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/gpu/jax_nightly:latest
138- - image_name : maxtext_stable_stack_candidate_gpu
139- dockerfile : ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile
140- base_image : us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:latest
89+ - name : Checkout git repository
90+ uses : actions/checkout@v5
91+ if : steps.check.outputs.should_run == 'true'
14192
142- steps :
143- - uses : actions/checkout@v5
14493 - name : Mark git repository as safe
94+ if : steps.check.outputs.should_run == 'true'
14595 run : git config --global --add safe.directory ${GITHUB_WORKSPACE}
96+
14697 - name : Configure Docker
147- run : gcloud auth configure-docker us-docker.pkg.dev,gcr.io,us-central1-docker.pkg.dev -q
98+ if : steps.check.outputs.should_run == 'true'
99+ run : gcloud auth configure-docker us-docker.pkg.dev,gcr.io -q
100+
148101 - name : Set up Docker BuildX
149102 uses : docker/setup-buildx-action@v3.11.1
103+ if : steps.check.outputs.should_run == 'true'
150104 with :
151105 driver : remote
152106 endpoint : tcp://localhost:1234
153- - name : Get short commit hash
107+
108+ # Env variables to be passed to Dockerfile
109+ - name : Get metadata
154110 id : vars
155- run : echo "sha_short=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
156- - name : Get current date
157- id : date
158- run : echo "image_date=$(date +%Y-%m-%d)" >> $GITHUB_OUTPUT
111+ if : steps.check.outputs.should_run == 'true'
112+ run : |
113+ echo "commit_hash=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
114+ echo "image_date=$(date +%Y-%m-%d)" >> $GITHUB_OUTPUT
115+
116+ # Docker BuildX command config
159117 - name : Build and Push Docker Image
160118 uses : docker/build-push-action@v6
119+ if : steps.check.outputs.should_run == 'true'
161120 with :
162121 push : true
163122 context : .
164123 file : ${{ matrix.dockerfile }}
165124 tags : |
166- gcr.io/tpu-prod-env-multipod/${{ matrix.image_name }}:maxtext_${{ steps.vars.outputs.sha_short }}
167- gcr.io/tpu-prod-env-multipod/${{ matrix.image_name }}:${{ steps.date .outputs.image_date }}
125+ gcr.io/tpu-prod-env-multipod/${{ matrix.image_name }}:maxtext_${{ steps.vars.outputs.commit_hash }}
126+ gcr.io/tpu-prod-env-multipod/${{ matrix.image_name }}:${{ steps.vars .outputs.image_date }}
168127 gcr.io/tpu-prod-env-multipod/${{ matrix.image_name }}:latest
169128 cache-from : type=gha
170129 cache-to : type=gha,mode=max
171130 provenance : false
172131 build-args : |
173- ${{ matrix.build_args }}
174- JAX_AI_IMAGE_BASEIMAGE=${{ matrix.base_image }}
175- COMMIT_HASH=${{ steps.vars.outputs.sha_short }}
176- DEVICE=gpu
177- TEST_TYPE=xlml
132+ DEVICE=${{ matrix.device }}
133+ MODE=${{ matrix.build_mode }}
134+ JAX_VERSION=NONE
135+ LIBTPU_GCS_PATH=NONE
0 commit comments