Skip to content

Commit fb79a9e

Browse files
Merge pull request #3959 from AI-Hypercomputer:yixuannwang-dev-demo-nb
PiperOrigin-RevId: 919200210
2 parents 868d99b + 171f932 commit fb79a9e

3 files changed

Lines changed: 106 additions & 2 deletions

File tree

src/maxtext/checkpoint_conversion/to_huggingface.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -366,7 +366,7 @@ def main(argv: Sequence[str]) -> None:
366366
argv: Command-line arguments, which are parsed by `pyconfig`.
367367
"""
368368
# Initialize maxtext config
369-
config = pyconfig.initialize(argv)
369+
config = pyconfig.initialize_pydantic(argv)
370370
assert (
371371
config.load_full_state_path == ""
372372
), "This script expects parameters, not a full state. Use generate_param_only_checkpoint first if needed."

src/maxtext/examples/rl_llama3_demo.ipynb

Lines changed: 54 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -106,6 +106,7 @@
106106
"from etils import epath\n",
107107
"import jax\n",
108108
"\n",
109+
"from maxtext.configs import pyconfig\n",
109110
"from maxtext.trainers.post_train.rl.train_rl import rl_train\n",
110111
"from maxtext.utils.model_creation_utils import setup_configs_and_devices\n",
111112
"from maxtext.utils.globals import MAXTEXT_REPO_ROOT, MAXTEXT_PKG_DIR\n",
@@ -292,6 +293,59 @@
292293
" sys.exit(1)"
293294
]
294295
},
296+
{
297+
"cell_type": "markdown",
298+
"metadata": {},
299+
"source": [
300+
"## Convert MaxText Checkpoint to Hugging Face Format"
301+
]
302+
},
303+
{
304+
"cell_type": "code",
305+
"execution_count": null,
306+
"metadata": {},
307+
"outputs": [],
308+
"source": [
309+
"config = pyconfig.initialize_pydantic(config_argv)\n",
310+
"\n",
311+
"# Define the output directory for the Hugging Face checkpoint\n",
312+
"hf_output_directory = epath.Path(BASE_OUTPUT_DIRECTORY) / \"hf_checkpoint\"\n",
313+
"\n",
314+
"# Find the latest MaxText checkpoint\n",
315+
"checkpoint_dir = epath.Path(config.checkpoint_dir) / 'actor'\n",
316+
"step_dirs = [d.name for d in checkpoint_dir.iterdir() if d.name.isdigit() and d.is_dir()]\n",
317+
"if not step_dirs:\n",
318+
" raise ValueError(f\"No checkpoint found in {checkpoint_dir}\")\n",
319+
"latest_step = max(step_dirs, key=int)\n",
320+
"maxtext_checkpoint_path = checkpoint_dir / latest_step / \"model_params\"\n",
321+
"\n",
322+
"print(f\"Converting MaxText checkpoint from: {maxtext_checkpoint_path}\")\n",
323+
"print(f\"Saving Hugging Face checkpoint to: {hf_output_directory}\")\n",
324+
"\n",
325+
"# Run the conversion script\n",
326+
"env = os.environ.copy()\n",
327+
"env[\"JAX_PLATFORMS\"] = \"cpu\"\n",
328+
"\n",
329+
"subprocess.run(\n",
330+
" [\n",
331+
" sys.executable,\n",
332+
" \"-m\", \"maxtext.checkpoint_conversion.to_huggingface\",\n",
333+
" f\"{MAXTEXT_PKG_DIR}/configs/base.yml\",\n",
334+
" f\"model_name={MODEL_NAME}\",\n",
335+
" f\"load_parameters_path={str(maxtext_checkpoint_path)}\",\n",
336+
" f\"base_output_directory={str(hf_output_directory)}\",\n",
337+
" f\"scan_layers={config.scan_layers}\",\n",
338+
" \"use_multimodal=false\",\n",
339+
" \"skip_jax_distributed_system=True\",\n",
340+
" \"weight_dtype=bfloat16\",\n",
341+
" ],\n",
342+
" check=True,\n",
343+
" env=env\n",
344+
")\n",
345+
"\n",
346+
"print(\"✓ Conversion completed successfully!\")"
347+
]
348+
},
295349
{
296350
"cell_type": "markdown",
297351
"metadata": {},

src/maxtext/examples/sft_llama3_demo_tpu.ipynb

Lines changed: 51 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -176,7 +176,6 @@
176176
"MODEL_NAME = \"llama3.1-8b-Instruct\"\n",
177177
"\n",
178178
"BASE_OUTPUT_DIRECTORY = f\"{MAXTEXT_PKG_DIR}/sft_llama3_output\"\n",
179-
"\n",
180179
"# set the path to the model checkpoint (including `/0/items`) or leave empty to download from HuggingFace\n",
181180
"MODEL_CHECKPOINT_PATH = \"\"\n",
182181
"if not MODEL_CHECKPOINT_PATH:\n",
@@ -322,6 +321,57 @@
322321
" sys.exit(1)"
323322
]
324323
},
324+
{
325+
"cell_type": "markdown",
326+
"metadata": {},
327+
"source": [
328+
"## Convert MaxText Checkpoint to Hugging Face Format"
329+
]
330+
},
331+
{
332+
"cell_type": "code",
333+
"execution_count": null,
334+
"metadata": {},
335+
"outputs": [],
336+
"source": [
337+
"# Define the output directory for the Hugging Face checkpoint\n",
338+
"hf_output_directory = epath.Path(BASE_OUTPUT_DIRECTORY) / \"hf_checkpoint\"\n",
339+
"\n",
340+
"# Find the latest MaxText checkpoint\n",
341+
"checkpoint_dir = epath.Path(config.checkpoint_dir)\n",
342+
"step_dirs = [d.name for d in checkpoint_dir.iterdir() if d.name.isdigit() and d.is_dir()]\n",
343+
"if not step_dirs:\n",
344+
" raise ValueError(f\"No checkpoint found in {checkpoint_dir}\")\n",
345+
"latest_step = max(step_dirs, key=int)\n",
346+
"maxtext_checkpoint_path = checkpoint_dir / latest_step / \"model_params\"\n",
347+
"\n",
348+
"print(f\"Converting MaxText checkpoint from: {maxtext_checkpoint_path}\")\n",
349+
"print(f\"Saving Hugging Face checkpoint to: {hf_output_directory}\")\n",
350+
"\n",
351+
"# Run the conversion script\n",
352+
"env = os.environ.copy()\n",
353+
"env[\"JAX_PLATFORMS\"] = \"cpu\"\n",
354+
"\n",
355+
"subprocess.run(\n",
356+
" [\n",
357+
" sys.executable,\n",
358+
" \"-m\", \"maxtext.checkpoint_conversion.to_huggingface\",\n",
359+
" f\"{MAXTEXT_PKG_DIR}/configs/base.yml\",\n",
360+
" f\"model_name={MODEL_NAME}\",\n",
361+
" f\"load_parameters_path={str(maxtext_checkpoint_path)}\",\n",
362+
" f\"base_output_directory={str(hf_output_directory)}\",\n",
363+
" f\"scan_layers={config.scan_layers}\",\n",
364+
" \"use_multimodal=false\",\n",
365+
" \"skip_jax_distributed_system=True\",\n",
366+
" \"weight_dtype=bfloat16\",\n",
367+
" ],\n",
368+
" check=True,\n",
369+
" env=env\n",
370+
")\n",
371+
"\n",
372+
"print(\"✓ Conversion completed successfully!\")"
373+
]
374+
},
325375
{
326376
"cell_type": "markdown",
327377
"metadata": {

0 commit comments

Comments
 (0)