|
106 | 106 | "from etils import epath\n", |
107 | 107 | "import jax\n", |
108 | 108 | "\n", |
| 109 | + "from maxtext.configs import pyconfig\n", |
109 | 110 | "from maxtext.trainers.post_train.rl.train_rl import rl_train\n", |
110 | 111 | "from maxtext.utils.model_creation_utils import setup_configs_and_devices\n", |
111 | 112 | "from maxtext.utils.globals import MAXTEXT_REPO_ROOT, MAXTEXT_PKG_DIR\n", |
|
292 | 293 | " sys.exit(1)" |
293 | 294 | ] |
294 | 295 | }, |
| 296 | + { |
| 297 | + "cell_type": "markdown", |
| 298 | + "metadata": {}, |
| 299 | + "source": [ |
| 300 | + "## Convert MaxText Checkpoint to Hugging Face Format" |
| 301 | + ] |
| 302 | + }, |
| 303 | + { |
| 304 | + "cell_type": "code", |
| 305 | + "execution_count": null, |
| 306 | + "metadata": {}, |
| 307 | + "outputs": [], |
| 308 | + "source": [ |
| 309 | + "config = pyconfig.initialize_pydantic(config_argv)\n", |
| 310 | + "\n", |
| 311 | + "# Define the output directory for the Hugging Face checkpoint\n", |
| 312 | + "hf_output_directory = epath.Path(BASE_OUTPUT_DIRECTORY) / \"hf_checkpoint\"\n", |
| 313 | + "\n", |
| 314 | + "# Find the latest MaxText checkpoint\n", |
| 315 | + "checkpoint_dir = epath.Path(config.checkpoint_dir) / 'actor'\n", |
| 316 | + "step_dirs = [d.name for d in checkpoint_dir.iterdir() if d.name.isdigit() and d.is_dir()]\n", |
| 317 | + "if not step_dirs:\n", |
| 318 | + " raise ValueError(f\"No checkpoint found in {checkpoint_dir}\")\n", |
| 319 | + "latest_step = max(step_dirs, key=int)\n", |
| 320 | + "maxtext_checkpoint_path = checkpoint_dir / latest_step / \"model_params\"\n", |
| 321 | + "\n", |
| 322 | + "print(f\"Converting MaxText checkpoint from: {maxtext_checkpoint_path}\")\n", |
| 323 | + "print(f\"Saving Hugging Face checkpoint to: {hf_output_directory}\")\n", |
| 324 | + "\n", |
| 325 | + "# Run the conversion script\n", |
| 326 | + "env = os.environ.copy()\n", |
| 327 | + "env[\"JAX_PLATFORMS\"] = \"cpu\"\n", |
| 328 | + "\n", |
| 329 | + "subprocess.run(\n", |
| 330 | + " [\n", |
| 331 | + " sys.executable,\n", |
| 332 | + " \"-m\", \"maxtext.checkpoint_conversion.to_huggingface\",\n", |
| 333 | + " f\"{MAXTEXT_PKG_DIR}/configs/base.yml\",\n", |
| 334 | + " f\"model_name={MODEL_NAME}\",\n", |
| 335 | + " f\"load_parameters_path={str(maxtext_checkpoint_path)}\",\n", |
| 336 | + " f\"base_output_directory={str(hf_output_directory)}\",\n", |
| 337 | + " f\"scan_layers={config.scan_layers}\",\n", |
| 338 | + " \"use_multimodal=false\",\n", |
| 339 | + " \"skip_jax_distributed_system=True\",\n", |
| 340 | + " \"weight_dtype=bfloat16\",\n", |
| 341 | + " ],\n", |
| 342 | + " check=True,\n", |
| 343 | + " env=env\n", |
| 344 | + ")\n", |
| 345 | + "\n", |
| 346 | + "print(\"✓ Conversion completed successfully!\")" |
| 347 | + ] |
| 348 | + }, |
295 | 349 | { |
296 | 350 | "cell_type": "markdown", |
297 | 351 | "metadata": {}, |
|
0 commit comments