Skip to content

Commit 2a03ed1

Browse files
committed
Update JAX to 0.9.2 for post-training
1 parent 3ebb185 commit 2a03ed1

9 files changed

Lines changed: 478 additions & 466 deletions

File tree

.github/workflows/run_jupyter_notebooks.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,8 @@ jobs:
105105
106106
for notebook in "$MAXTEXT_NOTEBOOKS_ROOT"/{sft,rl}*.ipynb; do
107107
filename=$(basename "$notebook")
108-
if [[ "$filename" == "sft_llama3_demo_gpu.ipynb" || "$filename" == "maxtext_with_gepa.ipynb" ]]; then
108+
# TODO: Update runner to v6e-8 as RL with Llama3.1-8b doesn't fit on v6e-4
109+
if [[ "$filename" == "sft_llama3_demo_gpu.ipynb" || "$filename" == "maxtext_with_gepa.ipynb" || "$filename" == "rl_llama3_demo.ipynb" ]]; then
109110
echo "Skipping $filename"
110111
continue
111112
fi

.github/workflows/run_tests_coordinator.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,24 +102,24 @@ jobs:
102102
${{ fromJSON('{
103103
"tpu-unit": "not cpu_only and not gpu_only and not integration_test and not post_training",
104104
"tpu-integration": "not cpu_only and not gpu_only and integration_test and not post_training",
105-
"tpu-post-training-unit": "not cpu_only and not gpu_only and not integration_test",
105+
"tpu-post-training-unit": "not cpu_only and not gpu_only and not integration_test and post_training",
106106
"tpu-post-training-integration": "not cpu_only and not gpu_only and integration_test",
107107
"gpu-unit": "not cpu_only and not tpu_only and not integration_test and not post_training",
108108
"gpu-integration": "not cpu_only and not tpu_only and integration_test and not post_training",
109109
"cpu-unit": "cpu_only and not post_training",
110-
"cpu-post-training-unit": "cpu_only"
110+
"cpu-post-training-unit": "cpu_only and post_training"
111111
}')[inputs.flavor] }}
112112
113113
pytest_addopts: >-
114114
${{ fromJSON('{
115115
"tpu-unit": "",
116116
"tpu-integration": "",
117-
"tpu-post-training-unit": "tests/post_training/unit",
117+
"tpu-post-training-unit": "tests/post_training/unit tests/unit",
118118
"tpu-post-training-integration": "tests/post_training/integration",
119119
"gpu-unit": "",
120120
"gpu-integration": "",
121121
"cpu-unit": "",
122-
"cpu-post-training-unit": "tests/post_training/unit"
122+
"cpu-post-training-unit": "tests/post_training/unit tests/unit"
123123
}')[inputs.flavor] }}
124124
125125
pytest_extra_args: >-
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
google-tunix @ https://github.com/google/tunix/archive/336d102fe32ca0edbe42a8f66ff0fd533cebdf52.zip
1+
google-tunix @ https://github.com/google/tunix/archive/f0102a7b0dccc0020503c0617869883f16b3b4ed.zip
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-r post_train_base_deps.txt
22
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
33
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
4-
tpu-inference @ https://github.com/vllm-project/tpu-inference/archive/0cae84fc9a883ba1bde02d4f07930e6af9e92958.zip
5-
vllm @ git+https://github.com/vllm-project/vllm@ee8a29511fc69e3f0f6291fa6ff1cf6e47f7750d
4+
tpu-inference @ https://github.com/vllm-project/tpu-inference/archive/40876e81f04226f9b7b1e4bbdc9051d6b1364b9d.zip
5+
vllm @ git+https://github.com/vllm-project/vllm@595562651a5a4539ffa910d8570c08fb5169bdc9

src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,8 @@ isort>=7.0.0
140140
jaraco.classes>=3.4.0
141141
jaraco.context>=6.1.0
142142
jaraco.functools>=4.3.0
143-
jax>=0.8.3
144-
jaxlib>=0.8.3
143+
jax>=0.9.2
144+
jaxlib>=0.9.2
145145
jaxtyping>=0.3.3
146146
jedi>=0.19.2
147147
jeepney>=0.9.0
@@ -164,7 +164,7 @@ lark>=1.2.2
164164
latex2sympy2_extended>=1.11.0
165165
libclang>=18.1.1
166166
libcst>=1.8.6
167-
libtpu>=0.0.32
167+
libtpu>=0.0.39
168168
llguidance>=1.3.0
169169
llvmlite>=0.45.1
170170
lm-format-enforcer>=0.11.3
@@ -180,7 +180,7 @@ matplotlib-inline>=0.2.1
180180
mccabe>=0.7.0
181181
mcp>=1.26.0
182182
mdurl>=0.1.2
183-
mistral_common>=1.9.1
183+
mistral_common>=1.11.0
184184
ml_collections>=1.1.0
185185
ml_dtypes>=0.5.4
186186
ml_goodput_measurement>=0.0.15
@@ -250,7 +250,7 @@ parameterized>=0.9.0
250250
parso>=0.8.6
251251
partial-json-parser>=0.2.1.1.post7
252252
pathspec>=0.12.1
253-
pathwaysutils>=0.1.4
253+
pathwaysutils>=0.1.7
254254
perfetto>=0.16.0
255255
pexpect>=4.9.0
256256
pillow>=12.0.0
@@ -355,19 +355,19 @@ tensorflow-text>=2.20.0
355355
tensorstore>=0.1.79
356356
termcolor>=3.2.0
357357
tiktoken>=0.12.0
358-
tokamax>=0.0.8
358+
tokamax>=0.0.12
359359
tokenizers>=0.22.1
360360
toml>=0.10.2
361361
tomlkit>=0.13.3
362362
toolz>=1.1.0
363-
torch>=2.9.0
363+
torch==2.10.0
364364
torchax>=0.0.11
365-
torchvision>=0.24.0
365+
torchvision==0.25.0
366366
tornado>=6.5.4
367367
tpu-info>=0.7.1
368368
tqdm>=4.67.3
369369
traitlets>=5.14.3
370-
transformers>=4.57.1
370+
transformers>=5.5.4
371371
treescope>=0.1.10
372372
triton>=3.5.0
373373
typeguard>=2.13.3

src/maxtext/examples/rl_llama3_demo.ipynb

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,28 @@
135135
"execution_count": null,
136136
"metadata": {},
137137
"outputs": [],
138-
"source": "import datetime\nimport os\nimport sys\nimport subprocess\nfrom pathlib import Path\nfrom huggingface_hub import login\nfrom etils import epath\nimport jax\n\nfrom maxtext.trainers.post_train.rl.train_rl import rl_train\nfrom maxtext.utils.model_creation_utils import setup_configs_and_devices\nfrom maxtext.utils.globals import MAXTEXT_REPO_ROOT, MAXTEXT_PKG_DIR\n\nos.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"0\"\nos.environ[\"SKIP_JAX_PRECOMPILE\"] = \"1\" # Faster startup for vLLM\n# Suppress vLLM logging with a severity level below ERROR\nos.environ[\"VLLM_LOGGING_LEVEL\"] = \"ERROR\"\n\n\nprint(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")"
138+
"source": [
139+
"import datetime\n",
140+
"import os\n",
141+
"import sys\n",
142+
"import subprocess\n",
143+
"from pathlib import Path\n",
144+
"from huggingface_hub import login\n",
145+
"from etils import epath\n",
146+
"import jax\n",
147+
"\n",
148+
"from maxtext.trainers.post_train.rl.train_rl import rl_train\n",
149+
"from maxtext.utils.model_creation_utils import setup_configs_and_devices\n",
150+
"from maxtext.utils.globals import MAXTEXT_REPO_ROOT, MAXTEXT_PKG_DIR\n",
151+
"\n",
152+
"os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"0\"\n",
153+
"os.environ[\"SKIP_JAX_PRECOMPILE\"] = \"1\" # Faster startup for vLLM\n",
154+
"# Suppress vLLM logging with a severity level below ERROR\n",
155+
"os.environ[\"VLLM_LOGGING_LEVEL\"] = \"ERROR\"\n",
156+
"\n",
157+
"\n",
158+
"print(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")"
159+
]
139160
},
140161
{
141162
"cell_type": "code",
@@ -188,8 +209,7 @@
188209
"metadata": {},
189210
"outputs": [],
190211
"source": [
191-
"MODEL_NAME = \"llama3.1-8b\"\n",
192-
"TOKENIZER_PATH = \"meta-llama/Llama-3.1-8B-Instruct\"\n",
212+
"MODEL_NAME = \"llama3.1-8b-Instruct\"\n",
193213
"RUN_NAME = datetime.datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")\n",
194214
"LOSS_ALGO=\"grpo\" # or \"gspo-token\" if you want to use GSPO\n",
195215
"\n",
@@ -270,35 +290,25 @@
270290
"metadata": {},
271291
"outputs": [],
272292
"source": [
273-
"# Load configuration for RL training\n",
293+
"# Configuration for RL training\n",
274294
"config_argv = [\n",
275295
" \"\",\n",
276296
" f\"{MAXTEXT_PKG_DIR}/configs/post_train/rl.yml\",\n",
277297
" f\"model_name={MODEL_NAME}\",\n",
278-
" f\"tokenizer_path={TOKENIZER_PATH}\",\n",
279298
" f\"run_name={RUN_NAME}\",\n",
280299
" f\"chat_template_path={CHAT_TEMPLATE_PATH}\",\n",
281300
" f\"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items\",\n",
282301
" f\"base_output_directory={OUTPUT_DIRECTORY}\",\n",
283302
" f\"hf_access_token={HF_TOKEN}\",\n",
284303
" \"debug.rl=False\",\n",
285304
" f\"rl.loss_algo={LOSS_ALGO}\",\n",
286-
" \"use_pathways=False\"\n",
305+
" \"use_pathways=False\",\n",
306+
" \"log_config=False\",\n",
287307
"]\n",
288308
"\n",
289-
"trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(config_argv)\n",
290-
"\n",
291-
"rl_train_steps = int(\n",
292-
" trainer_config.num_batches\n",
293-
" * trainer_config.rl.num_iterations\n",
294-
" * trainer_config.train_fraction\n",
295-
" * trainer_config.num_epoch\n",
296-
")\n",
297-
"\n",
298309
"print(\"✓ Configuration initialized successfully\")\n",
299-
"print(f\"📁 Output directory: {trainer_config.base_output_directory}\")\n",
300-
"print(f\"🤖 Model: {trainer_config.model_name}\")\n",
301-
"print(f\"📊 RL Train Steps: {rl_train_steps}\")"
310+
"print(f\"📁 Output directory: {OUTPUT_DIRECTORY}\")\n",
311+
"print(f\"🤖 Model: {MODEL_NAME}\")"
302312
]
303313
},
304314
{
@@ -314,23 +324,22 @@
314324
"metadata": {},
315325
"outputs": [],
316326
"source": [
327+
"import traceback\n",
328+
"\n",
317329
"print(\"\\n\" + \"=\" * 80)\n",
318330
"print(f\"🚀 Starting {LOSS_ALGO} Training...\")\n",
319331
"print(\"=\" * 80)\n",
320332
"try:\n",
321-
" rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices)\n",
333+
" rl_train(argv=config_argv, kwargs={})\n",
322334
" print(\"\\n\" + \"=\" * 80)\n",
323335
" print(\"✅ Training Completed Successfully!\")\n",
324-
" print(f\"✍️ Note the improved evaluation accuracy metrics with just {rl_train_steps} RL training steps!\")\n",
325336
" print(\"=\" * 80)\n",
326-
" print(f\"📁 Checkpoints saved to: {trainer_config.checkpoint_dir}\")\n",
327-
" print(f\"📊 TensorBoard logs: {trainer_config.tensorboard_dir}\")\n",
328-
" print(f\"🎯 Model ready for inference!\")\n",
329-
"except Exception as e:\n",
337+
"except Exception:\n",
330338
" print(\"\\n\" + \"=\" * 80)\n",
331339
" print(\"❌Training Failed!\")\n",
332340
" print(\"=\" * 80)\n",
333-
" print(f\"Error: {str(e)}\")"
341+
" traceback.print_exc()\n",
342+
" sys.exit(1)"
334343
]
335344
},
336345
{
@@ -347,7 +356,7 @@
347356
],
348357
"metadata": {
349358
"kernelspec": {
350-
"display_name": ".venv",
359+
"display_name": "Python 3",
351360
"language": "python",
352361
"name": "python3"
353362
},
@@ -361,9 +370,9 @@
361370
"name": "python",
362371
"nbconvert_exporter": "python",
363372
"pygments_lexer": "ipython3",
364-
"version": "3.10.12"
373+
"version": "3.12.11"
365374
}
366375
},
367376
"nbformat": 4,
368377
"nbformat_minor": 4
369-
}
378+
}

0 commit comments

Comments
 (0)