135135 "execution_count" : null ,
136136 "metadata" : {},
137137 "outputs" : [],
138- "source" : " import datetime\n import os\n import sys\n import subprocess\n from pathlib import Path\n from huggingface_hub import login\n from etils import epath\n import jax\n\n from maxtext.trainers.post_train.rl.train_rl import rl_train\n from maxtext.utils.model_creation_utils import setup_configs_and_devices\n from maxtext.utils.globals import MAXTEXT_REPO_ROOT, MAXTEXT_PKG_DIR\n\n os.environ[\" TF_CPP_MIN_LOG_LEVEL\" ] = \" 0\"\n os.environ[\" SKIP_JAX_PRECOMPILE\" ] = \" 1\" # Faster startup for vLLM\n # Suppress vLLM logging with a severity level below ERROR\n os.environ[\" VLLM_LOGGING_LEVEL\" ] = \" ERROR\"\n\n\n print(f\" MaxText installation path: {MAXTEXT_PKG_DIR}\" )"
138+ "source" : [
139+ " import datetime\n " ,
140+ " import os\n " ,
141+ " import sys\n " ,
142+ " import subprocess\n " ,
143+ " from pathlib import Path\n " ,
144+ " from huggingface_hub import login\n " ,
145+ " from etils import epath\n " ,
146+ " import jax\n " ,
147+ " \n " ,
148+ " from maxtext.trainers.post_train.rl.train_rl import rl_train\n " ,
149+ " from maxtext.utils.model_creation_utils import setup_configs_and_devices\n " ,
150+ " from maxtext.utils.globals import MAXTEXT_REPO_ROOT, MAXTEXT_PKG_DIR\n " ,
151+ " \n " ,
152+ " os.environ[\" TF_CPP_MIN_LOG_LEVEL\" ] = \" 0\"\n " ,
153+ " os.environ[\" SKIP_JAX_PRECOMPILE\" ] = \" 1\" # Faster startup for vLLM\n " ,
154+ " # Suppress vLLM logging with a severity level below ERROR\n " ,
155+ " os.environ[\" VLLM_LOGGING_LEVEL\" ] = \" ERROR\"\n " ,
156+ " \n " ,
157+ " \n " ,
158+ " print(f\" MaxText installation path: {MAXTEXT_PKG_DIR}\" )"
159+ ]
139160 },
140161 {
141162 "cell_type" : " code" ,
188209 "metadata" : {},
189210 "outputs" : [],
190211 "source" : [
191- " MODEL_NAME = \" llama3.1-8b\"\n " ,
192- " TOKENIZER_PATH = \" meta-llama/Llama-3.1-8B-Instruct\"\n " ,
212+ " MODEL_NAME = \" llama3.1-8b-Instruct\"\n " ,
193213 " RUN_NAME = datetime.datetime.now().strftime(\" %Y-%m-%d-%H-%M-%S\" )\n " ,
194214 " LOSS_ALGO=\" grpo\" # or \" gspo-token\" if you want to use GSPO\n " ,
195215 " \n " ,
270290 "metadata" : {},
271291 "outputs" : [],
272292 "source" : [
273- " # Load configuration for RL training\n " ,
293+ " # Configuration for RL training\n " ,
274294 " config_argv = [\n " ,
275295 " \"\" ,\n " ,
276296 " f\" {MAXTEXT_PKG_DIR}/configs/post_train/rl.yml\" ,\n " ,
277297 " f\" model_name={MODEL_NAME}\" ,\n " ,
278- " f\" tokenizer_path={TOKENIZER_PATH}\" ,\n " ,
279298 " f\" run_name={RUN_NAME}\" ,\n " ,
280299 " f\" chat_template_path={CHAT_TEMPLATE_PATH}\" ,\n " ,
281300 " f\" load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items\" ,\n " ,
282301 " f\" base_output_directory={OUTPUT_DIRECTORY}\" ,\n " ,
283302 " f\" hf_access_token={HF_TOKEN}\" ,\n " ,
284303 " \" debug.rl=False\" ,\n " ,
285304 " f\" rl.loss_algo={LOSS_ALGO}\" ,\n " ,
286- " \" use_pathways=False\"\n " ,
305+ " \" use_pathways=False\" ,\n " ,
306+ " \" log_config=False\" ,\n " ,
287307 " ]\n " ,
288308 " \n " ,
289- " trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(config_argv)\n " ,
290- " \n " ,
291- " rl_train_steps = int(\n " ,
292- " trainer_config.num_batches\n " ,
293- " * trainer_config.rl.num_iterations\n " ,
294- " * trainer_config.train_fraction\n " ,
295- " * trainer_config.num_epoch\n " ,
296- " )\n " ,
297- " \n " ,
298309 " print(\" ✓ Configuration initialized successfully\" )\n " ,
299- " print(f\" 📁 Output directory: {trainer_config.base_output_directory}\" )\n " ,
300- " print(f\" 🤖 Model: {trainer_config.model_name}\" )\n " ,
301- " print(f\" 📊 RL Train Steps: {rl_train_steps}\" )"
310+ " print(f\" 📁 Output directory: {OUTPUT_DIRECTORY}\" )\n " ,
311+ " print(f\" 🤖 Model: {MODEL_NAME}\" )"
302312 ]
303313 },
304314 {
314324 "metadata" : {},
315325 "outputs" : [],
316326 "source" : [
327+ " import traceback\n " ,
328+ " \n " ,
317329 " print(\"\\ n\" + \" =\" * 80)\n " ,
318330 " print(f\" 🚀 Starting {LOSS_ALGO} Training...\" )\n " ,
319331 " print(\" =\" * 80)\n " ,
320332 " try:\n " ,
321- " rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices )\n " ,
333+ " rl_train(argv=config_argv, kwargs={} )\n " ,
322334 " print(\"\\ n\" + \" =\" * 80)\n " ,
323335 " print(\" ✅ Training Completed Successfully!\" )\n " ,
324- " print(f\" ✍️ Note the improved evaluation accuracy metrics with just {rl_train_steps} RL training steps!\" )\n " ,
325336 " print(\" =\" * 80)\n " ,
326- " print(f\" 📁 Checkpoints saved to: {trainer_config.checkpoint_dir}\" )\n " ,
327- " print(f\" 📊 TensorBoard logs: {trainer_config.tensorboard_dir}\" )\n " ,
328- " print(f\" 🎯 Model ready for inference!\" )\n " ,
329- " except Exception as e:\n " ,
337+ " except Exception:\n " ,
330338 " print(\"\\ n\" + \" =\" * 80)\n " ,
331339 " print(\" ❌Training Failed!\" )\n " ,
332340 " print(\" =\" * 80)\n " ,
333- " print(f\" Error: {str(e)}\" )"
341+ " traceback.print_exc()\n " ,
342+ " sys.exit(1)"
334343 ]
335344 },
336345 {
347356 ],
348357 "metadata" : {
349358 "kernelspec" : {
350- "display_name" : " .venv " ,
359+ "display_name" : " Python 3 " ,
351360 "language" : " python" ,
352361 "name" : " python3"
353362 },
361370 "name" : " python" ,
362371 "nbconvert_exporter" : " python" ,
363372 "pygments_lexer" : " ipython3" ,
364- "version" : " 3.10.12 "
373+ "version" : " 3.12.11 "
365374 }
366375 },
367376 "nbformat" : 4 ,
368377 "nbformat_minor" : 4
369- }
378+ }
0 commit comments