Skip to content

Commit d39bb99

Browse files
committed
Merge branch 'main' of github.com:AI-Hypercomputer/maxtext into shuningjin-ckpt-opt3
2 parents b21c695 + 093ab89 commit d39bb99

544 files changed

Lines changed: 24265 additions & 7602 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

.dockerignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1 +1,2 @@
11
.git
2+
maxtext_venv

.github/CODEOWNERS

Lines changed: 18 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -1,36 +1,34 @@
1-
* @gobbleturk @khatwanimohit @bvandermoon @vipannalla @RissyRan @richjames0 @gagika @shralex @SurbhiJainUSC @hengtaoguo @A9isha @aireenmei @NuojCheng @jiangjy1982 @suexu1025 @NicoGrande @jesselu-google
1+
* @gobbleturk @khatwanimohit @bvandermoon @vipannalla @RissyRan @richjames0 @gagika @shralex @SurbhiJainUSC @hengtaoguo @A9isha @aireenmei @NuojCheng @jiangjy1982 @suexu1025 @NicoGrande @jesselu-google @dipannita08 @igorts-git
22

33
# Model bring-up
4-
src/MaxText/assets @parambole @shuningjin @RissyRan @suexu1025 @jiangjy1982 @gobbleturk @bvandermoon @gagika @shralex @richjames0 @NicoGrande
5-
src/MaxText/configs/models @parambole @shuningjin @RissyRan @suexu1025 @jiangjy1982 @gobbleturk @bvandermoon @gagika @shralex @richjames0 @NicoGrande @suexu1025 @jesselu-google
6-
src/MaxText/utils/ckpt_conversion @parambole @shuningjin @RissyRan @suexu1025 @jiangjy1982 @gobbleturk @bvandermoon @hengtaoguo @gagika @shralex @richjames0 @NicoGrande
7-
src/MaxText/layers @parambole @shuningjin @RissyRan @suexu1025 @jiangjy1982 @gobbleturk @bvandermoon @gagika @shralex @richjames0 @NicoGrande @suexu1025 @jesselu-google
4+
src/maxtext/assets @parambole @shuningjin @RissyRan @suexu1025 @jiangjy1982 @gobbleturk @bvandermoon @gagika @shralex @richjames0 @NicoGrande
5+
src/maxtext/configs/models @parambole @shuningjin @RissyRan @suexu1025 @jiangjy1982 @gobbleturk @bvandermoon @gagika @shralex @richjames0 @NicoGrande @jesselu-google @NuojCheng
6+
src/maxtext/checkpoint_conversion @parambole @shuningjin @RissyRan @suexu1025 @jiangjy1982 @gobbleturk @bvandermoon @hengtaoguo @gagika @shralex @richjames0 @NicoGrande
7+
src/maxtext/layers @parambole @shuningjin @RissyRan @suexu1025 @jiangjy1982 @gobbleturk @bvandermoon @gagika @shralex @richjames0 @NicoGrande @jesselu-google @NuojCheng
8+
src/maxtext/models @parambole @shuningjin @RissyRan @suexu1025 @jiangjy1982 @gobbleturk @bvandermoon @gagika @shralex @richjames0 @NicoGrande @jesselu-google @NuojCheng
89

910
# Features
1011
src/maxtext/experimental/rl @A9isha @khatwanimohit @xuefgu @gagika @richjames0 @shralex @NicoGrande
11-
src/MaxText/input_pipeline @aireenmei @SurbhiJainUSC @richjames0 @shralex @NicoGrande
12-
src/MaxText/kernels/megablox @RissyRan @michelle-yooh @gagika @richjames0 @shralex @suexu1025 @jesselu-google
13-
src/MaxText/kernels/ragged_attention.py @patemotter @vipannalla @richjames0 @shralex
14-
src/MaxText/layers/pipeline.py @gobbleturk @richjames0 @shralex
15-
src/MaxText/layers/moe.py @RissyRan @michelle-yooh @gagika @richjames0 @shralex @suexu1025 @jesselu-google
16-
src/MaxText/layers/multi_token_prediction.py @parambole @RissyRan @gagika @richjames0 @shralex
17-
src/MaxText/elastic_train.py @lukebaumann @shauryagup @richjames0 @shralex
18-
src/MaxText/layers/quantizations.py @khatwanimohit @jshin1394 @liudangyi @richjames0 @shralex
12+
src/maxtext/input_pipeline @aireenmei @SurbhiJainUSC @richjames0 @shralex @NicoGrande
13+
src/maxtext/kernels/megablox @RissyRan @michelle-yooh @gagika @richjames0 @shralex @suexu1025 @jesselu-google
14+
src/maxtext/kernels/ragged_attention.py @patemotter @vipannalla @richjames0 @shralex
15+
src/maxtext/layers/pipeline.py @gobbleturk @richjames0 @shralex @NuojCheng
16+
src/maxtext/layers/moe.py @RissyRan @michelle-yooh @gagika @richjames0 @shralex @suexu1025 @jesselu-google
17+
src/maxtext/layers/multi_token_prediction.py @parambole @RissyRan @gagika @richjames0 @shralex
18+
src/maxtext/layers/quantizations.py @khatwanimohit @jshin1394 @liudangyi @richjames0 @shralex
1919

2020
# Inference
21-
src/maxtext/tests/inference @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0
21+
tests/inference/ @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0
2222
src/maxtext/inference @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0
23-
src/maxtext/inference_mlperf @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0
2423

2524
# Dockerfiles and dependencies
26-
*.Dockerfile @bvandermoon @parambole @richjames0 @shralex
27-
*.txt @bvandermoon @parambole @richjames0 @shralex
25+
src/dependencies/ @bvandermoon @parambole @richjames0 @shralex
2826

2927
# Docs
30-
*.md @jacoguzo @bvandermoon @richjames0 @shralex @gobbleturk @RissyRan @gagika @A9isha @jiangjy1982 @vipannalla
28+
docs/ @jacoguzo @bvandermoon @richjames0 @shralex @gobbleturk @RissyRan @gagika @A9isha @jiangjy1982 @vipannalla
3129

3230
# Workflow files
33-
.github/workflows @gobbleturk @khatwanimohit @shralex @parambole @bvandermoon @richjames0
31+
.github/workflows/ @gobbleturk @khatwanimohit @shralex @parambole @bvandermoon @richjames0
3432

3533
# Benchmarking/Recipes
36-
benchmarks @SujeethJinesh @bvandermoon @richjames0 @shralex @vipannalla @mitalisi @RissyRan @shauryagup @NuojCheng @gobbleturk @khatwanimohit @Obliviour @notabee @suexu1025
34+
benchmarks/ @SujeethJinesh @bvandermoon @richjames0 @shralex @vipannalla @mitalisi @RissyRan @shauryagup @NuojCheng @gobbleturk @khatwanimohit @Obliviour @notabee @suexu1025

.github/workflows/UploadDockerImages.yml

Lines changed: 12 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -65,11 +65,11 @@ jobs:
6565
- device: tpu
6666
build_mode: stable
6767
image_name: maxtext_jax_stable
68-
dockerfile: ./dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile
68+
dockerfile: ./src/dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile
6969
- device: tpu
7070
build_mode: nightly
7171
image_name: maxtext_jax_nightly
72-
dockerfile: ./dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile
72+
dockerfile: ./src/dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile
7373
uses: ./.github/workflows/build_and_push_docker_image.yml
7474
with:
7575
image_name: ${{ matrix.image_name }}
@@ -79,31 +79,18 @@ jobs:
7979
maxtext_sha: ${{ needs.setup.outputs.maxtext_sha }}
8080
image_date: ${{ needs.setup.outputs.image_date }}
8181

82-
tpu-post-training:
83-
name: ${{ matrix.image_name }}
84-
needs: [setup, tpu-pre-training]
85-
strategy:
86-
fail-fast: false
87-
matrix:
88-
include:
89-
- device: tpu
90-
build_mode: post-training
91-
image_name: maxtext_post_training_stable
92-
dockerfile: ./dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile
93-
- device: tpu
94-
build_mode: post-training
95-
image_name: maxtext_post_training_nightly
96-
dockerfile: ./dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile
82+
tpu-post-training-nightly:
83+
name: tpu-post-training-nightly
84+
needs: [setup]
9785
uses: ./.github/workflows/build_and_push_docker_image.yml
9886
with:
99-
image_name: ${{ matrix.image_name }}
100-
device: ${{ matrix.device }}
101-
build_mode: ${{ matrix.build_mode }}
102-
dockerfile: ${{ matrix.dockerfile }}
87+
image_name: maxtext_post_training_nightly
88+
device: tpu
89+
build_mode: nightly
90+
workflow: post-training
91+
dockerfile: ./src/dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile
10392
maxtext_sha: ${{ needs.setup.outputs.maxtext_sha }}
10493
image_date: ${{ needs.setup.outputs.image_date }}
105-
base_image: gcr.io/tpu-prod-env-multipod/maxtext_jax_stable:${{ needs.setup.outputs.image_date }}
106-
is_post_training: true
10794

10895
gpu-pre-training:
10996
name: ${{ matrix.image_name }}
@@ -115,11 +102,11 @@ jobs:
115102
- device: gpu
116103
build_mode: stable
117104
image_name: maxtext_gpu_jax_stable
118-
dockerfile: ./dependencies/dockerfiles/maxtext_gpu_dependencies.Dockerfile
105+
dockerfile: ./src/dependencies/dockerfiles/maxtext_gpu_dependencies.Dockerfile
119106
- device: gpu
120107
build_mode: nightly
121108
image_name: maxtext_gpu_jax_nightly
122-
dockerfile: ./dependencies/dockerfiles/maxtext_gpu_dependencies.Dockerfile
109+
dockerfile: ./src/dependencies/dockerfiles/maxtext_gpu_dependencies.Dockerfile
123110
uses: ./.github/workflows/build_and_push_docker_image.yml
124111
with:
125112
image_name: ${{ matrix.image_name }}

.github/workflows/build_and_push_docker_image.yml

Lines changed: 38 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -35,16 +35,16 @@ on:
3535
required: true
3636
type: string
3737
image_date:
38-
required: true
38+
required: false
3939
type: string
40-
base_image:
40+
workflow:
4141
required: false
4242
type: string
43-
default: ''
44-
is_post_training:
43+
default: 'pre-training'
44+
version_name:
4545
required: false
46-
type: boolean
47-
default: false
46+
type: string
47+
default: ''
4848

4949
permissions:
5050
contents: read
@@ -66,13 +66,18 @@ jobs:
6666
id: check
6767
shell: bash
6868
run: |
69-
if [[ "${{ github.event_name }}" == "workflow_dispatch" && "${{ github.event.inputs.target_device }}" != "all" && "${{ github.event.inputs.target_device }}" != "${{ inputs.device }}" ]]; then
69+
if [[ "${{ github.event_name }}" == "workflow_dispatch" && "${GITHUB_EVENT_INPUTS_TARGET_DEVICE}" != "all" && "${GITHUB_EVENT_INPUTS_TARGET_DEVICE}" != "${INPUTS_DEVICE}" ]]; then
7070
echo "should_run=false" >> $GITHUB_OUTPUT
71-
echo "Skipping ${{ inputs.image_name }} build for device: ${{ inputs.device }} in ${{ inputs.build_mode }} mode."
71+
echo "Skipping ${INPUTS_IMAGE_NAME} build for device: ${INPUTS_DEVICE} in ${INPUTS_BUILD_MODE} mode."
7272
else
7373
echo "should_run=true" >> $GITHUB_OUTPUT
74-
echo "Building ${{ inputs.image_name }} for device: ${{ inputs.device }} in ${{ inputs.build_mode }} mode."
74+
echo "Building ${INPUTS_IMAGE_NAME} for device: ${INPUTS_DEVICE} in ${INPUTS_BUILD_MODE} mode."
7575
fi
76+
env:
77+
GITHUB_EVENT_INPUTS_TARGET_DEVICE: ${{ github.event.inputs.target_device }}
78+
INPUTS_DEVICE: ${{ inputs.device }}
79+
INPUTS_IMAGE_NAME: ${{ inputs.image_name }}
80+
INPUTS_BUILD_MODE: ${{ inputs.build_mode }}
7681

7782
- name: Checkout MaxText
7883
uses: actions/checkout@v5
@@ -110,40 +115,50 @@ jobs:
110115
push: true
111116
context: .
112117
file: ${{ inputs.dockerfile }}
113-
tags: gcr.io/tpu-prod-env-multipod/${{ inputs.image_name }}:latest
118+
tags: gcr.io/tpu-prod-env-multipod/${{ inputs.image_name }}:${{ github.run_id }}
114119
cache-from: type=gha
115120
outputs: type=image,compression=zstd,force-compression=true
116121
build-args: |
117122
DEVICE=${{ inputs.device }}
118123
MODE=${{ inputs.build_mode }}
124+
WORKFLOW=${{ inputs.workflow }}
119125
JAX_VERSION=NONE
120126
LIBTPU_VERSION=NONE
121127
INCLUDE_TEST_ASSETS=true
122-
${{ inputs.base_image != '' && format('BASEIMAGE={0}', inputs.base_image) || '' }}
123128
124129
- name: Add tags to Docker image
125130
if: steps.check.outputs.should_run == 'true'
126131
shell: bash
127132
run: |
128-
SOURCE_IMAGE="gcr.io/tpu-prod-env-multipod/${{ inputs.image_name }}"
133+
SOURCE_IMAGE="gcr.io/tpu-prod-env-multipod/${INPUTS_IMAGE_NAME}"
129134
130-
# Add date tag
131-
gcloud container images add-tag "$SOURCE_IMAGE:latest" "$SOURCE_IMAGE:${{ inputs.image_date }}" --quiet
135+
if [[ $INPUTS_VERSION_NAME ]]; then
136+
echo "Tagging docker images corresponding to PyPI release..."
137+
gcloud container images add-tag "$SOURCE_IMAGE:${{ github.run_id }}" "$SOURCE_IMAGE:${INPUTS_VERSION_NAME}" --quiet
138+
else
139+
echo "Tagging docker images corresponding to nightly release..."
132140
133-
# Convert date to YYYYMMDD format
134-
clean_date=$(echo "${{ inputs.image_date }}" | sed 's/[-:]//g' | cut -c1-8)
141+
# Add date tag
142+
gcloud container images add-tag "$SOURCE_IMAGE:${{ github.run_id }}" "$SOURCE_IMAGE:${INPUTS_IMAGE_DATE}" --quiet
135143
136-
# Add MaxText tag
137-
maxtext_hash=$(git rev-parse --short HEAD)
138-
gcloud container images add-tag "$SOURCE_IMAGE:latest" "$SOURCE_IMAGE:maxtext_${maxtext_hash}_${clean_date}" --quiet
144+
# Convert date to YYYYMMDD format
145+
clean_date=$(echo "${INPUTS_IMAGE_DATE}" | sed 's/[-:]//g' | cut -c1-8)
139146
147+
# Add MaxText tag
148+
maxtext_hash=$(git rev-parse --short HEAD)
149+
gcloud container images add-tag "$SOURCE_IMAGE:${{ github.run_id }}" "$SOURCE_IMAGE:maxtext_${maxtext_hash}_${clean_date}" --quiet
140150
141151
# Add post-training dependencies tags
142-
if [ "${{ inputs.is_post_training }}" == "true" ]; then
152+
if [ "${{ inputs.workflow }}" == "post-training" ]; then
143153
for dir in tunix vllm tpu-inference; do
144154
if [ -d "./$dir" ]; then
145155
dir_hash=$(git -C "$dir" rev-parse --short HEAD)
146-
gcloud container images add-tag "$SOURCE_IMAGE:latest" "$SOURCE_IMAGE:${dir}_${dir_hash}_${clean_date}" --quiet
147-
fi
148-
done
156+
gcloud container images add-tag "$SOURCE_IMAGE:${{ github.run_id }}" "$SOURCE_IMAGE:${dir}_${dir_hash}_${clean_date}" --quiet
157+
fi
158+
done
159+
fi
149160
fi
161+
env:
162+
INPUTS_IMAGE_NAME: ${{ inputs.image_name }}
163+
INPUTS_IMAGE_DATE: ${{ inputs.image_date }}
164+
INPUTS_VERSION_NAME: ${{ inputs.version_name }}

0 commit comments

Comments
 (0)