diff --git a/docs/tutorials/posttraining/lora.md b/docs/tutorials/posttraining/lora.md new file mode 100644 index 0000000000..6d6aa0d304 --- /dev/null +++ b/docs/tutorials/posttraining/lora.md @@ -0,0 +1,215 @@ + + +# LoRA Fine-tuning on single-host TPUs + +**Low-Rank Adaptation (LoRA)** is a Parameter-Efficient Fine-Tuning (PEFT) technique designed to optimize large language models while minimizing resource consumption. + +Unlike traditional full-parameter fine-tuning, LoRA: + +- **Freezes the pre-trained model weights**, preserving the original knowledge. +- **Injects trainable rank decomposition matrices** into the Transformer layers. + +This approach **greatly reduces the number of trainable parameters** required for downstream tasks, making the process faster and more memory-efficient. + +This tutorial provides step-by-step instructions for setting up the environment and performing LoRA fine-tuning on a Hugging Face dataset using MaxText. + +We use [Tunix](https://github.com/google/tunix), a JAX-based library, to power these post-training tasks. + +In this tutorial we use a single host TPU VM such as `v6e-8/v5p-8`. Let's get started! + +**Note:** Since **qwix** support has been recently integrated into the **main branch**, you must **clone** the latest source code and install it in **editable mode** to ensure all dependencies are correctly linked. + +```sh +# Install Qwix from source +git clone https://github.com/google/qwix.git +cd qwix +uv pip install -e . +``` + +## Setup environment variables + +Set the following environment variables before running LoRA Fine-tuning. 
+ +```sh +# -- Model configuration -- +export PRE_TRAINED_MODEL= # e.g., 'gemma3-4b' + +# -- MaxText configuration -- +export BASE_OUTPUT_DIRECTORY= # e.g., gs://my-bucket/my-output-directory +export RUN_NAME= # e.g., $(date +%Y-%m-%d-%H-%M-%S) +export STEPS= # e.g., 1000 +export PER_DEVICE_BATCH_SIZE= # e.g., 1 +export HF_TOKEN= +export LORA_RANK= # e.g., 16 +export LORA_ALPHA= # e.g., 32.0 +export LEARNING_RATE= # e.g., 3e-6 +export MAX_TARGET_LENGTH= # e.g., 1024 +export WEIGHT_DTYPE= # e.g., bfloat16 +export DTYPE= # e.g., bfloat16 + +# -- Dataset configuration -- +export DATASET_NAME= # e.g., openai/gsm8k +export TRAIN_SPLIT= # e.g., train +export HF_DATA_DIR= # e.g., main +export TRAIN_DATA_COLUMNS= # e.g., ['question','answer'] + +# -- LoRA Conversion configuration (Optional) -- +export HF_LORA_ADAPTER_PATH= # e.g., 'username/adapter-name' +``` + +## Customizing Trainable Layers (Optional) + +By default, MaxText determines which layers to apply LoRA to based on the model's architecture by reading `src/maxtext/configs/post_train/lora_module_path.yml`. + +If you need to fine-tune specific components (e.g., targeting only Attention layers to optimize memory usage), you can override these defaults through the following hierarchy: + +### Configuration Hierarchy + +1. **Command Line Argument**: Pass the `lora_module_path` argument directly in your training command. This is the most flexible way for experimental iterations. +2. **Task-Specific Config (`sft.yml`)**: Define the `lora_module_path` parameter in `src/maxtext/configs/post_train/sft.yml` to set a persistent configuration for your SFT runs. +3. **Global Defaults**: Automatic detection via the model-to-regex mapping defined in `lora_module_path.yml`. + +## Get your model checkpoint + +This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint. 
+ +### Option 1: Using an existing MaxText checkpoint + +If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section. + +```sh +export PRE_TRAINED_MODEL_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items +``` + +### Option 2: Converting a Hugging Face checkpoint + +Refer to the steps in [Hugging Face to MaxText](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/guides/checkpointing_solutions/convert_checkpoint.html#hugging-face-to-maxtext) to convert a Hugging Face checkpoint to MaxText format. Make sure the checkpoint files are converted and saved correctly. As in Option 1, set the following environment variable and move on to the next section. + +```sh +export PRE_TRAINED_MODEL_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items +``` + +## Run a Fresh LoRA Fine-Tuning on Hugging Face Dataset + +Once your environment variables and checkpoints are ready, you can start the LoRA fine-tuning process. + +Execute the following command to begin training: + +```sh +python3 -m maxtext.trainers.post_train.sft.train_sft \ + run_name="${RUN_NAME?}" \ + base_output_directory="${BASE_OUTPUT_DIRECTORY?}" \ + model_name="${PRE_TRAINED_MODEL?}" \ + load_parameters_path="${PRE_TRAINED_MODEL_CKPT_PATH?}" \ + hf_access_token="${HF_TOKEN?}" \ + hf_path="${DATASET_NAME?}" \ + train_split="${TRAIN_SPLIT?}" \ + hf_data_dir="${HF_DATA_DIR?}" \ + train_data_columns="${TRAIN_DATA_COLUMNS?}" \ + steps="${STEPS?}" \ + per_device_batch_size="${PER_DEVICE_BATCH_SIZE?}" \ + max_target_length="${MAX_TARGET_LENGTH?}" \ + learning_rate="${LEARNING_RATE?}" \ + weight_dtype="${WEIGHT_DTYPE?}" \ + dtype="${DTYPE?}" \ + enable_nnx=True \ + pure_nnx_decoder=True \ + enable_lora=True \ + lora_rank="${LORA_RANK?}" \ + lora_alpha="${LORA_ALPHA?}" \ + scan_layers=True +``` + +Your fine-tuned model checkpoints will be saved here: `$BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints`.
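As the introduction notes, LoRA freezes the base weights and trains only the injected rank-decomposition matrices. To build intuition for the `lora_rank` value passed above, here is a small self-contained Python sketch comparing trainable parameter counts for a single d×d projection (the hidden size of 4096 is an illustrative assumption, not a value from this tutorial):

```python
# Trainable-parameter count for one d x d projection matrix:
# full fine-tuning updates W itself; LoRA freezes W and trains only
# the low-rank factors A (d x r) and B (r x d).

def trainable_params(d: int, rank: int = 0) -> int:
    if rank == 0:        # full fine-tuning: the whole weight matrix
        return d * d
    return 2 * d * rank  # LoRA: A plus B

d, rank = 4096, 16       # hidden size chosen only for illustration
full = trainable_params(d)
lora = trainable_params(d, rank)
print(f"full: {full:,}  lora: {lora:,}  ratio: {lora / full:.2%}")
# -> full: 16,777,216  lora: 131,072  ratio: 0.78%
```

At rank 16, the adapter trains well under 1% of that projection's parameters, which is where LoRA's speed and memory savings come from.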
+ +## (Optional) Resume from a previous LoRA checkpoint + +If you want to resume training from a previous run or further fine-tune an existing LoRA adapter, you can specify the LoRA checkpoint path. + +### Step 1: Convert HF LoRA adapter to MaxText format + +If your LoRA adapter is currently in Hugging Face format, you must convert it to MaxText format before it can be loaded. Use the provided conversion script: + +```sh +python3 -m maxtext.checkpoint_conversion.hf_lora_to_maxtext \ + model_name="${PRE_TRAINED_MODEL?}" \ + hf_lora_adapter_path="${HF_LORA_ADAPTER_PATH?}" \ + base_output_directory="${BASE_OUTPUT_DIRECTORY?}" \ + scan_layers=True +``` + +### Step 2: Set the restore path + +Point `LORA_RESTORE_PATH` to the converted MaxText adapter directory (the directory containing the `0/items` or Orbax files). + +- **load_parameters_path**: Points to the frozen base model weights (the original model). +- **lora_restore_path**: Points to the previous LoRA adapter weights you wish to load. + +```sh +export LORA_RESTORE_PATH= # e.g., gs://my-bucket/run-1/checkpoints/0/items +``` + +### Step 3: Run LoRA Fine-Tuning with the Restore Path + +Once your environment variables and checkpoints are ready, you can start the LoRA fine-tuning process. 
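Before launching the run, it may help to see how the two checkpoint paths interact conceptually: `load_parameters_path` supplies the frozen base weight W, while `lora_restore_path` supplies the adapter matrices A and B. The toy, framework-free sketch below uses the standard LoRA convention `W_eff = W + (alpha / rank) * (B @ A)`; the shapes and values are illustrative, and the real MaxText/Orbax loading path keeps A and B as separate trainable parameters rather than merging them:

```python
# Toy illustration of how the two checkpoint paths combine:
#   load_parameters_path -> frozen base weight W
#   lora_restore_path    -> adapter matrices A (rank x d) and B (d x rank)
# Standard LoRA convention: W_eff = W + (alpha / rank) * (B @ A).

def matmul(x, y):
    """Minimal dense matrix multiply over nested lists."""
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*y)] for row in x]

rank, alpha = 1, 2.0                       # illustrative hyperparameters
w = [[1.0, 0.0], [0.0, 1.0]]               # frozen base weight, d = 2
lora_b = [[1.0], [0.0]]                    # B: d x rank
lora_a = [[0.0, 1.0]]                      # A: rank x d

scale = alpha / rank
delta = matmul(lora_b, lora_a)             # low-rank update B @ A
w_eff = [
    [wij + scale * dij for wij, dij in zip(w_row, d_row)]
    for w_row, d_row in zip(w, delta)
]
print(w_eff)  # [[1.0, 2.0], [0.0, 1.0]]
```

Because the adapter is restored separately, resuming only changes A and B; the base weights loaded from `load_parameters_path` remain untouched.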
+ +Execute the following command to begin training: + +```sh +python3 -m maxtext.trainers.post_train.sft.train_sft \ + run_name="${RUN_NAME?}" \ + base_output_directory="${BASE_OUTPUT_DIRECTORY?}" \ + model_name="${PRE_TRAINED_MODEL?}" \ + load_parameters_path="${PRE_TRAINED_MODEL_CKPT_PATH?}" \ + lora_restore_path="${LORA_RESTORE_PATH}" \ + hf_access_token="${HF_TOKEN?}" \ + hf_path="${DATASET_NAME?}" \ + train_split="${TRAIN_SPLIT?}" \ + hf_data_dir="${HF_DATA_DIR?}" \ + train_data_columns="${TRAIN_DATA_COLUMNS?}" \ + steps="${STEPS?}" \ + per_device_batch_size="${PER_DEVICE_BATCH_SIZE?}" \ + max_target_length="${MAX_TARGET_LENGTH?}" \ + learning_rate="${LEARNING_RATE?}" \ + weight_dtype="${WEIGHT_DTYPE?}" \ + dtype="${DTYPE?}" \ + enable_nnx=True \ + pure_nnx_decoder=True \ + enable_lora=True \ + lora_rank="${LORA_RANK?}" \ + lora_alpha="${LORA_ALPHA?}" \ + scan_layers=True +``` + +Your fine-tuned model checkpoints will be saved here: `$BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints`. + +## (Optional) Convert Fine-tuned LoRA to Hugging Face Format + +After completing the fine-tuning process, your LoRA weights are stored in MaxText/Orbax format. To use these weights with the Hugging Face ecosystem (e.g., for inference or sharing), convert them back using the `maxtext_to_hf_lora.py` script. + +```sh +python3 -m maxtext.checkpoint_conversion.maxtext_to_hf_lora \ + model_name="${PRE_TRAINED_MODEL?}" \ + load_parameters_path="${BASE_OUTPUT_DIRECTORY?}/${RUN_NAME?}/checkpoints//model_params" \ + base_output_directory="${BASE_OUTPUT_DIRECTORY?}/hf_lora_adapter" \ + lora_rank="${LORA_RANK?}" \ + lora_alpha="${LORA_ALPHA?}" +``` + +- `load_parameters_path`: Point this to the specific checkpoint directory (e.g., `.../checkpoints/1000/items`) that you want to export. +- `base_output_directory`: The local or GCS directory where the Hugging Face `adapter_model.safetensors` and `adapter_config.json` will be saved.
+- `lora_rank` / `lora_alpha`: Must match the values used during the training phase to ensure the `adapter_config.json` is generated correctly. diff --git a/src/maxtext/checkpoint_conversion/hf_lora_to_maxtext.py b/src/maxtext/checkpoint_conversion/hf_lora_to_maxtext.py new file mode 100644 index 0000000000..37548e72ec --- /dev/null +++ b/src/maxtext/checkpoint_conversion/hf_lora_to_maxtext.py @@ -0,0 +1,284 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script converts a HuggingFace LoRA adapter to MaxText LoRA adapter format. + +Key Parameters (to be set in the config file or as command-line overrides): + model_name: (Required) The name of the model (e.g., "llama3.1-8b"). + base_output_directory: (Required) The directory where the MaxText LoRA adapter + will be saved. Can be set in config file or as command-line override. + hf_lora_adapter_path: (Required) Path to the HF LoRA adapter directory or HuggingFace repo ID. + scan_layers: (bool) Whether the MaxText model uses scanned layers. + This must match the training configuration. + hf_access_token: (Optional) HuggingFace authentication token if needed for adapter. 
+ +Example Usage: + To convert HF LoRA to MaxText adapter: + + python src/maxtext/checkpoint_conversion/hf_lora_to_maxtext.py \ + MaxText/configs/sft.yml model_name="llama3.1-8b" \ + hf_lora_adapter_path="username/lora-adapter-repo" \ + base_output_directory="/path/to/output/directory" \ + scan_layers=False +""" + +import argparse +import json +import os +import shutil +import sys +from typing import Sequence + +import jax +import jax.numpy as jnp +from etils import epath +from huggingface_hub import hf_hub_download +from huggingface_hub import list_repo_files +from safetensors import safe_open +from transformers import AutoConfig + +from orbax import checkpoint as ocp +from maxtext.checkpoint_conversion.utils.param_mapping import PARAM_MAPPING +from maxtext.utils.globals import HF_IDS +from maxtext.configs import pyconfig +from maxtext.utils import max_logging +from absl import logging + + +def load_hf_lora_adapter(adapter_path: str, hf_model_id: str, hf_access_token: str | None = None) -> dict: + """Load HF LoRA adapter weights directly from safetensors files.""" + max_logging.log(f"Loading HF LoRA adapter from {adapter_path}") + + # Check adapter compatibility + adapter_config = None + if os.path.isdir(adapter_path): + # Local directory + adapter_dir = epath.Path(adapter_path) + config_file = adapter_dir / "adapter_config.json" + if config_file.exists(): + with open(config_file, "r", encoding="utf-8") as f: + adapter_config = json.load(f) + else: + # HF Hub repo + try: + config_file = hf_hub_download(adapter_path, "adapter_config.json", token=hf_access_token) + with open(config_file, "r", encoding="utf-8") as f: + adapter_config = json.load(f) + except Exception as exc: # pylint: disable=broad-exception-caught + max_logging.log(f"Warning: Could not load adapter_config.json from HF Hub: {exc}") + + if adapter_config: + if adapter_config.get("base_model_name_or_path"): + max_logging.log(f"Adapter base model: {adapter_config['base_model_name_or_path']}") + 
max_logging.log(f"Adapter compatible with model {hf_model_id}") + + # Handle both local paths and HF Hub paths + if os.path.isdir(adapter_path): + # Local directory + adapter_dir = epath.Path(adapter_path) + adapter_files = list(adapter_dir.glob("*.safetensors")) + if not adapter_files: + adapter_files = list(adapter_dir.glob("*.bin")) + if not adapter_files: + raise ValueError(f"No LoRA adapter files found in {adapter_path}") + adapter_file = adapter_files[0] + else: + # Assume it's a HF Hub repo ID + try: + files = list_repo_files(adapter_path, token=hf_access_token) + safetensor_files = [f for f in files if f.endswith(".safetensors")] + if not safetensor_files: + bin_files = [f for f in files if f.endswith(".bin")] + if not bin_files: + raise ValueError(f"No LoRA adapter files found in {adapter_path}") + adapter_file = bin_files[0] + else: + adapter_file = safetensor_files[0] + + # Download the adapter file + adapter_file = hf_hub_download(adapter_path, adapter_file, token=hf_access_token) + except Exception as e: + raise ValueError(f"Failed to load LoRA adapter from {adapter_path}: {e}") from e + + adapter_file_path = str(adapter_file) + + # Load the adapter weights + if adapter_file_path.endswith(".safetensors"): + with safe_open(adapter_file_path, framework="numpy") as f: + lora_weights = {k: f.get_tensor(k) for k in f.keys()} + else: + raise ValueError(f"Unsupported adapter file format: {adapter_file_path}") + + max_logging.log(f"Loaded {len(lora_weights)} LoRA parameters from adapter") + return lora_weights + + +def convert_hf_lora_key_to_maxtext(hf_key: str, param_mapping: dict) -> str: + """Convert HF LoRA key to MaxText parameter path using the mapping from to_maxtext.py.""" + # HF LoRA keys: base_model.model.layers.{layer}.{module}.lora_A/B.weight + + # Clean up LoRA suffixes to get the base module path + # e.g. 
...q_proj.lora_A.weight -> ...q_proj + hf_param_key = hf_key.replace(".lora_A.weight", "").replace(".lora_B.weight", "") + hf_param_key = hf_param_key.replace(".lora_A", "").replace(".lora_B", "") + + # Handle prefix. Expected target is usually "model.layers..." + # Input could be "base_model.model.model.layers..." or "base_model.model.layers..." + if hf_param_key.startswith("base_model.model."): + hf_param_key = hf_param_key[len("base_model.model.") :] + + for mt_key, hf_keys in param_mapping.items(): + if isinstance(hf_keys, list): + for hf_k in hf_keys: + if hf_k.replace(".weight", "") == hf_param_key: + return mt_key + elif isinstance(hf_keys, str): + if hf_keys.replace(".weight", "") == hf_param_key: + return mt_key + + return None + + +def convert_lora_to_maxtext_adapter( + config, + lora_weights: dict, + output_path: str, + hf_model_id: str, + hf_access_token: str | None = None, +): + """Converts HF LoRA weights to MaxText adapter format without merging.""" + + # Get the parameter mapping (MT -> HF) + model_key = config.model_name + if "-Instruct" in model_key: + max_logging.log("Warning: You want an Instruct version, so we are using the base model architecture instead.") + model_key = model_key.replace("-Instruct", "") + hf_config_obj = AutoConfig.from_pretrained(hf_model_id, token=hf_access_token) + param_map_mt_to_hf = PARAM_MAPPING[model_key](hf_config_obj.to_dict(), config, config.scan_layers) + + mt_adapter_tree = {} + mapped_count = 0 + + # Map HF LoRA weights to MaxText keys + for hf_key, weight in lora_weights.items(): + mt_key = convert_hf_lora_key_to_maxtext(hf_key, param_map_mt_to_hf) + + if mt_key: + suffix = "lora_a" if "lora_A" in hf_key or "lora_a" in hf_key else "lora_b" + + # NNX expects LoRA parameters to be direct children of the module, + # but PARAM_MAPPING keys often end in "-kernel". 
+ mt_key = mt_key.replace("-kernel", "") + + # Construct a nested dictionary path in mt_adapter_tree + parts = mt_key.split("-") + if parts[0] == "params": + parts = parts[1:] + current = mt_adapter_tree + for part in parts: + if part not in current: + current[part] = {} + current = current[part] + + current[suffix] = jnp.array(weight) + mapped_count += 1 + else: + max_logging.log(f"Warning: Could not map HF LoRA key {hf_key} to MaxText key") + + max_logging.log(f"Successfully mapped {mapped_count} out of {len(lora_weights)} LoRA parameters") + + # Save as a standalone adapter checkpoint + max_logging.log(f"Saving MaxText LoRA adapter to {output_path}") + ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler()) + ckptr.save(epath.Path(output_path), mt_adapter_tree) + + max_logging.log("LoRA adapter conversion completed successfully") + + +def main(args: Sequence[str]) -> None: + # Set logging to INFO level to see max_logging.log messages + logging.set_verbosity(logging.INFO) + + # Check if the user is using an Instruct version. 
If so, use the base model architecture + original_model_name = None + for i, arg in enumerate(args): + if arg.startswith("model_name="): + model_name_arg = args[i].split("=")[1] + # Remove quotes if present + model_name_arg = model_name_arg.strip("'").strip('"') + original_model_name = model_name_arg + + if "-Instruct" in model_name_arg: + max_logging.log("Warning: You want an Instruct version, so we are using the base model architecture instead.") + model_name_arg = model_name_arg.replace("-Instruct", "") + args[i] = f"model_name={model_name_arg}" + break + + # Initialize maxtext config + config = pyconfig.initialize_pydantic(args) + + if not hasattr(config, "hf_lora_adapter_path") or not config.hf_lora_adapter_path: + raise ValueError("hf_lora_adapter_path must be specified") + + # Determine HF model ID and check if supported + hf_model_id = HF_IDS.get(config.model_name) + if hf_model_id is None: + raise ValueError(f"Model '{config.model_name}' is not supported. Use a supported model_name from HF_IDS.") + + if not hasattr(config, "base_output_directory") or not config.base_output_directory: + raise ValueError("base_output_directory must be specified (in config file or as command-line argument)") + + output_dir = config.base_output_directory + + # Use original model name for output path + model_name_for_path = original_model_name or config.model_name + adapter_name = os.path.basename(os.path.normpath(config.hf_lora_adapter_path)) + full_output_path = os.path.join(output_dir, model_name_for_path, adapter_name) + + os.makedirs(os.path.dirname(full_output_path), exist_ok=True) + + if os.path.exists(full_output_path): + max_logging.log(f"Output directory {full_output_path} exists. 
Removing it to allow Orbax to save.") + shutil.rmtree(full_output_path) + + # Load LoRA adapter and check compatibility + hf_access_token = config.hf_access_token + lora_weights = load_hf_lora_adapter(config.hf_lora_adapter_path, hf_model_id, hf_access_token) + + # Convert LoRA to MaxText adapter format and save + convert_lora_to_maxtext_adapter(config, lora_weights, full_output_path, hf_model_id, hf_access_token) + + # Verify output was created + outputpath = epath.Path(full_output_path) + if not outputpath.exists(): + raise RuntimeError(f"Failed to create output directory {full_output_path}") + + +if __name__ == "__main__": + # Argument parsing similar to to_maxtext.py + parser = argparse.ArgumentParser() + parser.add_argument("--simulated_cpu_devices_count", type=int, required=False, default=16) + + # Parse local arguments + local_args, remaining_args = parser.parse_known_args() + + # Reconstruct model_args (script name + the args MaxText needs) + model_args = [sys.argv[0]] + remaining_args + + # Set jax environment + jax.config.update("jax_platforms", "cpu") + os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}" + + main(model_args) diff --git a/src/maxtext/checkpoint_conversion/maxtext_to_hf_lora.py b/src/maxtext/checkpoint_conversion/maxtext_to_hf_lora.py new file mode 100644 index 0000000000..ab7043c841 --- /dev/null +++ b/src/maxtext/checkpoint_conversion/maxtext_to_hf_lora.py @@ -0,0 +1,173 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script converts a MaxText LoRA adapter (checkpoint) back to HuggingFace PEFT format. + +Key Parameters (to be set in the config file or as command-line overrides): + model_name: (Required) The name of the model (e.g., "llama3.1-8b"). + load_parameters_path: (Required) Path to the MaxText checkpoint directory. + base_output_directory: (Required) The directory where the HuggingFace adapter will be saved. + lora_rank: The rank of the LoRA adapter. + lora_alpha: The alpha parameter for LoRA. + +Example Usage: + python src/maxtext/checkpoint_conversion/maxtext_to_hf_lora.py \ + src/maxtext/configs/base.yml \ + model_name="llama3.1-8b" \ + load_parameters_path="maxtext/lora/ckpt_path/" \ + base_output_directory="output/path/" \ + lora_rank=16 \ + lora_alpha=32 +""" + +import argparse +import jax +import os +import json +import numpy as np +import sys +from safetensors.numpy import save_file +from orbax import checkpoint as ocp +from etils import epath +from transformers import AutoConfig +from maxtext.configs import pyconfig +from maxtext.checkpoint_conversion.utils.param_mapping import PARAM_MAPPING +from maxtext.utils.globals import HF_IDS +from maxtext.utils import max_logging + + +def convert(argv): + """Converts a MaxText LoRA adapter checkpoint to HuggingFace PEFT format.""" + config = pyconfig.initialize_pydantic(argv) + + if not hasattr(config, "load_parameters_path") or not config.load_parameters_path: + raise ValueError("load_parameters_path must be specified") + + if not hasattr(config, "base_output_directory") or not config.base_output_directory: + raise ValueError("base_output_directory must be specified") + + maxtext_ckpt_path = config.load_parameters_path + output_dir = config.base_output_directory + model_name = config.model_name + lora_r = config.lora.lora_rank + lora_alpha = config.lora.lora_alpha + + hf_model_id = 
HF_IDS.get(config.model_name) + if hf_model_id is None: + raise ValueError(f"Model '{config.model_name}' is not supported. Use a supported model_name from HF_IDS.") + + max_logging.log(f"Starting conversion for model: {model_name}") + max_logging.log(f"Path: {maxtext_ckpt_path}") + + mapping_model_name = config.model_name.replace("-Instruct", "") + + # Initialize Orbax Checkpointer + mngr = ocp.PyTreeCheckpointer() + mt_params = mngr.restore(epath.Path(maxtext_ckpt_path)) + + # Load HF Config for mapping + hf_config = AutoConfig.from_pretrained(hf_model_id).to_dict() + + # Get the parameter mapping for the specific model + if mapping_model_name not in PARAM_MAPPING: + raise ValueError(f"Model {mapping_model_name} not found in PARAM_MAPPING") + + mapping = PARAM_MAPPING[mapping_model_name](hf_config, config, config.scan_layers) + + final_hf_weights = {} + found_hf_modules = set() + + def process_data(current_dict, parent_path="decoder/layers"): + """Recursive function to traverse MaxText params and map to HF.""" + for module_name, content in current_dict.items(): + path = f"{parent_path}/{module_name}" + + # Identify LoRA layers + if isinstance(content, dict) and "kernel_lora_a" in content: + lookup_key = "params-" + path.replace("/", "-") + "-kernel" + + if lookup_key in mapping: + # Get the JAX values (as numpy) + data_a = np.array(content["kernel_lora_a"]["value"]) + data_b = np.array(content["kernel_lora_b"]["value"]) + hf_paths = mapping[lookup_key] + + if not isinstance(hf_paths, list): + hf_paths = [hf_paths] + + # MaxText stacks multiple heads/projections, iterate through them + for i in range(min(data_a.shape[1], len(hf_paths))): + full_hf_path = hf_paths[i] + + module_type = full_hf_path.split(".")[-2] + found_hf_modules.add(module_type) + + name = hf_paths[i].replace(".weight", "") + # Apply Transpose (.T) to match PyTorch dimension logic + final_hf_weights[f"base_model.model.{name}.lora_A.weight"] = data_a[:, i, :].T + 
final_hf_weights[f"base_model.model.{name}.lora_B.weight"] = data_b[:, i, :].T + + max_logging.log(f"Mapped {lookup_key} to {len(hf_paths)} HF layers") + + elif isinstance(content, dict): + process_data(content, path) + + # Start recursion + start_node = mt_params.get("decoder", {}).get("layers", mt_params) + process_data(start_node) + + # Save Safetensors + os.makedirs(output_dir, exist_ok=True) + adapter_file = os.path.join(output_dir, "adapter_model.safetensors") + save_file(final_hf_weights, adapter_file) + + # Create PEFT adapter_config.json + config_json = { + "base_model_name_or_path": hf_model_id, + "peft_type": "LORA", + "task_type": "CAUSAL_LM", + "r": int(lora_r), + "lora_alpha": int(lora_alpha), + "target_modules": list(found_hf_modules), + "lora_dropout": 0.0, + "bias": "none", + "inference_mode": True, + } + + config_file = os.path.join(output_dir, "adapter_config.json") + with open(config_file, "w", encoding="utf-8") as f: + json.dump(config_json, f, indent=4) + + max_logging.log("Conversion Complete!") + max_logging.log(f"Saved weights to: {adapter_file}") + max_logging.log(f"Saved config to: {config_file}") + max_logging.log(f"Target modules detected: {list(found_hf_modules)}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--simulated_cpu_devices_count", type=int, required=False, default=16) + + # Parse local arguments + local_args, remaining_args = parser.parse_known_args() + + # Reconstruct model_args (script name + the args MaxText needs) + model_args = [sys.argv[0]] + remaining_args + + # Set jax environment + jax.config.update("jax_platforms", "cpu") + os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}" + + convert(model_args) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 4671d76d4c..48b7613562 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -43,6 +43,7 @@ load_parameters_path: 
"" # LoRA adapter support configs lora_input_adapters_path: "" # Input GCS path for a parent directory which has all the LoRA adapters (lora_id as subdir) +hf_lora_adapter_path: "" # Input HF repo ID or local path for HF LoRA adapter # Loads a full checkpoint including optimizer state and step count from a specific directory # e.g. gs://my-base-output-directory/my-previous-run-name/checkpoints/items/NUMBER or NUMBER/items diff --git a/src/maxtext/configs/post_train/lora_module_path.yml b/src/maxtext/configs/post_train/lora_module_path.yml new file mode 100644 index 0000000000..11f81d52c5 --- /dev/null +++ b/src/maxtext/configs/post_train/lora_module_path.yml @@ -0,0 +1,28 @@ +# Copyright 2023-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Recommended LoRA module paths by model architecture prefix. +# These models have been explicitly tested and verified for LoRA. 
+ +llama3.1: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))" +qwen3: "decoder/layers/self_attention/(query|key|value|out)|decoder/layers/mlp/(wi_0|wi_1|wo)" +mistral: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))" +deepseek2: "decoder/(dense_layers|moe_stack)/self_attention/(query|out|wkv_a|wkv_b)|decoder/(dense_layers|moe_stack)/(mlp|shared_experts)/(wi_0|wi_1|wo)" +gemma2: "decoder/layers/(self_attention_local|self_attention_global)/(query|key|value|out)|decoder/layers/(mlp_local|mlp_global)/(wi_0|wi_1|wo)" +gemma3: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo|gate|up|down))" +olmo3: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))" +gpt3: "decoder/layers/(self_attention/(qkv_proj|out)|mlp/(wi|wo))" + +# Fallback for unverified models +default: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))" diff --git a/src/maxtext/configs/post_train/sft.yml b/src/maxtext/configs/post_train/sft.yml index 32c86ddb31..10c71324e9 100644 --- a/src/maxtext/configs/post_train/sft.yml +++ b/src/maxtext/configs/post_train/sft.yml @@ -21,6 +21,18 @@ sft_train_on_completion_only: True packing: True learning_rate: 2.e-5 +# -------------- LoRA / QLoRA -------------- +lora: + enable_lora: False + lora_rank: 0 + lora_alpha: 0.0 + lora_module_path: "" + # For QLoRA (TODO: not working), set lora_weight_qtype (e.g., "nf4") and optionally lora_tile_size. + lora_weight_qtype: null # TODO: not working + lora_tile_size: null # TODO: not working + # Optional path to LoRA weights to load before training. Ignored if the current run is resumed. 
+ lora_restore_path: "" + # -------------- HF pipeline -------------- dataset_type: hf hf_path: 'HuggingFaceH4/ultrachat_200k' diff --git a/src/maxtext/configs/pyconfig.py b/src/maxtext/configs/pyconfig.py index e300367eab..f8f29edf27 100644 --- a/src/maxtext/configs/pyconfig.py +++ b/src/maxtext/configs/pyconfig.py @@ -64,6 +64,8 @@ "maxtext.inference.vllm_decode": "base.yml", "maxtext.checkpoint_conversion.to_maxtext": "base.yml", "maxtext.checkpoint_conversion.to_huggingface": "base.yml", + "maxtext.checkpoint_conversion.maxtext_to_hf_lora": "base.yml", + "maxtext.checkpoint_conversion.hf_lora_to_maxtext": "base.yml", } diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 2bfe9d6b31..d3b0f4382b 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -304,6 +304,13 @@ class Checkpointing(BaseModel): load_parameters_path: PathStr = Field("", description="Loads only model parameters from a specific checkpoint path.") lora_input_adapters_path: PathStr = Field("", description="Input GCS path for LoRA adapters.") + hf_lora_adapter_path: PathStr = Field( + "", + description=( + "HuggingFace LoRA adapter repo ID (e.g., 'username/adapter-repo') or local " + "path to directory containing adapter_model.safetensors." + ), + ) load_full_state_path: PathStr = Field("", description="Loads the complete training state from a checkpoint path.") enable_checkpointing: bool = Field(True, description="If True, enables saving checkpoints during training.") load_checkpoint_only_once: bool = Field(False, description="If True, deep copy the reference model to the actor model.") @@ -341,7 +348,8 @@ class Checkpointing(BaseModel): description="If True, enables checkpointing from remote TPU VMs instead of head node on pathways.", ) enable_autocheckpoint: bool = Field( - False, description="If True, enables autocheckpoint or preemption induced checkpointing." 
+      False,
+      description="If True, enables autocheckpoint or preemption induced checkpointing.",
   )
@@ -469,7 +477,8 @@ class ModelArchitecture(BaseModel):
   )
   fused_mlp: bool = Field(False, description="If supported, fuse the MLP layers.")
   qk_norm_with_scale: bool = Field(
-      True, description="Whether to apply scale on query and key normalizations (default True)."
+      True,
+      description="Whether to apply scale on query and key normalizations (default True).",
   )
   v_norm_with_scale: bool = Field(True, description="Whether to apply scale on value normalization (default True).")
@@ -516,9 +525,13 @@ class Attention(BaseModel):
       "global", description="The variant of attention to use."
   )
   share_kv_projections: bool = Field(
-      False, description="If True, for global attention, Key and Value projections share the same weights."
+      False,
+      description="If True, for global attention, Key and Value projections share the same weights.",
+  )
+  global_num_kv_heads: int = Field(
+      0,
+      description="If greater than 0, sets the number of KV heads for global attention.",
   )
-  global_num_kv_heads: int = Field(0, description="If greater than 0, sets the number of KV heads for global attention.")
   attention_sink: bool = Field(False, description="If True, enables attention sinks.")
   float32_qk_product: bool = Field(False, description="In dot-product attention, cast query-key product to fp32.")
   float32_logits: bool = Field(
@@ -1006,7 +1019,8 @@ class Tokenizer(BaseModel):
   use_chat_template: bool = Field(False, description="Whether to use the chat template for tokenization.")
   chat_template_path: str = Field("", description="Path to chat template json file.")
   chat_template: str = Field(
-      "", description="Chat template to use with HF tokenizers. It should be a valid Jinja2-formatted template."
+      "",
+      description="Chat template to use with HF tokenizers. It should be a valid Jinja2-formatted template.",
   )
   tokenize_train_data: bool = Field(True, description="If False, assumes the training dataset is pre-tokenized.")
   tokenize_eval_data: bool = Field(True, description="If False, assumes the evaluation dataset is pre-tokenized.")
@@ -1098,7 +1112,8 @@ class GrainDataset(BaseModel):
       description="Path to a JSON file specifying the mixture weights for Grain training data.",
   )
   grain_file_type: str = Field(
-      "arrayrecord", description="File type for Grain data. Supported: arrayrecord, tfrecord, parquet."
+      "arrayrecord",
+      description="File type for Grain data. Supported: arrayrecord, tfrecord, parquet.",
   )
   grain_worker_count: int = Field(1, description="Number of workers for Grain data loading.")
   grain_per_worker_buffer_size: int = Field(1, description="Per-worker buffer size for Grain train data loading.")
@@ -1131,6 +1146,32 @@ class FineTuning(BaseModel):
   use_grpo: None | bool = Field(None, description="If True, enables Group Relative Policy Optimization.")
 
 
+class LoRA(BaseModel):
+  """Configuration for LoRA / QLoRA adapters."""
+
+  enable_lora: bool = Field(False, description="If True, enables LoRA/QLoRA during fine-tuning.")
+  lora_rank: NonNegativeInt = Field(0, description="LoRA rank. Set >0 when LoRA is enabled.")
+  lora_alpha: NonNegativeFloat = Field(0.0, description="LoRA alpha scaling factor.")
+  lora_module_path: str = Field(
+      "",
+      description=(
+          "Regex identifying target modules for LoRA, e.g."
+          " '.*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj'."
+      ),
+  )
+  lora_weight_qtype: str | None = Field(
+      None,
+      description=("TODO: not working. Optional quantization type for QLoRA (e.g., 'nf4'). If set, QLoRA is applied."),
+  )
+  lora_tile_size: NonNegativeInt | None = Field(
+      None,
+      description="TODO: not working. Optional tile size for QLoRA (e.g., 128 or 256).",
+  )
+  lora_restore_path: PathStr = Field(
+      "",
+      description=("Optional path to LoRA weights to load before training. Ignored if the current run is resumed."),
+  )
+
+
 class Distillation(BaseModel):
   """Configuration for Knowledge Distillation."""
@@ -1148,7 +1189,8 @@ class Distillation(BaseModel):
   # --- Offline Distillation Field ---
   offline_data_dir: Optional[str] = Field(
-      None, description="GCS or local path to the pre-generated ArrayRecord teacher data."
+      None,
+      description="GCS or local path to the pre-generated ArrayRecord teacher data.",
   )
 
   # --- Loss Params ---
@@ -1156,7 +1198,8 @@ class Distillation(BaseModel):
   distill_temperature: float = Field(1.0, description="Temperature for distillation softening.")
   distill_beta: float = Field(0.0, description="Weight for the feature loss component. Use 0.0 to disable")
   distill_feature_loss_type: Literal["cosine", "l2"] = Field(
-      "cosine", description="The type of loss to use for feature distillation ('cosine' or 'l2')."
+      "cosine",
+      description="The type of loss to use for feature distillation ('cosine' or 'l2').",
   )
   distill_layer_indices: None | list = Field(None, description="Feature indices for feature loss.")
@@ -1219,10 +1262,12 @@ class Optimizer(BaseModel):
   opt_type: OptimizerType = Field(OptimizerType.ADAMW, description="The type of optimizer to use.")
   skip_step_on_spikes: bool = Field(
-      False, description="If True, skip the training step when loss or gradient spike is detected."
+      False,
+      description="If True, skip the training step when loss or gradient spike is detected.",
   )
   skip_step_interval: PositiveInt = Field(
-      128, description="The rolling interval to calculate the mean and standard deviation."
+      128,
+      description="The rolling interval to calculate the mean and standard deviation.",
   )
   skip_step_scaling_factor: float = Field(6.0, description="The scaling factor to determine if a spike occurred.")
   gradient_accumulation_steps: PositiveInt = Field(
@@ -1661,7 +1706,10 @@ class VisionTower(BaseModel):
   temporal_patch_size_for_vit: int = Field(2, description="Temporal patch size for video inputs.")
   num_position_embeddings_for_vit: int = Field(1024, description="Number of position embeddings for ViT.")
   deepstack_visual_indexes_for_vit: list[int] = Field([], description="Layer indices to extract deep visual features.")
-  vision_output_length: int = Field(-1, description="The output length (number of soft tokens) from the vision encoder.")
+  vision_output_length: int = Field(
+      -1,
+      description="The output length (number of soft tokens) from the vision encoder.",
+  )
 
 
 class VisionProjector(BaseModel):
@@ -1751,18 +1799,28 @@ class RL(BaseModel):
   grpo_epsilon: float = Field(0.2, description="Epsilon value for clipping in the GRPO loss.")
   loss_algo: Literal["grpo", "gspo-token"] = Field("grpo", description="Loss algorithm, i.e., 'grpo' or 'gspo-token'.")
   use_agentic_rollout: bool = Field(
-      False, description="If True, uses the asynchronous AgenticGRPOLearner for online vLLM rollouts."
+      False,
+      description="If True, uses the asynchronous AgenticGRPOLearner for online vLLM rollouts.",
+  )
+  max_concurrency: int = Field(
+      256,
+      description="Maximum number of concurrent rollout requests (agentic rollout only).",
   )
-  max_concurrency: int = Field(256, description="Maximum number of concurrent rollout requests (agentic rollout only).")
   off_policy_steps: int = Field(
-      0, description="Number of off-policy steps tolerated before requiring a policy update (agentic only)."
+      0,
+      description="Number of off-policy steps tolerated before requiring a policy update (agentic only).",
+  )
+  system_prompt: str = Field(
+      "",
+      description="System prompt injected into the agent at rollout time (agentic only).",
   )
-  system_prompt: str = Field("", description="System prompt injected into the agent at rollout time (agentic only).")
   degenerate_group_masking: bool = Field(
-      True, description="Mask degenerate groups (all-zero advantages) from contributing to loss (agentic only)."
+      True,
+      description="Mask degenerate groups (all-zero advantages) from contributing to loss (agentic only).",
   )
   epsilon_high: Optional[float] = Field(
-      None, description="Upper-bound clipping epsilon for GRPO loss. Defaults to epsilon when None (agentic only)."
+      None,
+      description="Upper-bound clipping epsilon for GRPO loss. Defaults to epsilon when None (agentic only).",
   )
@@ -2084,6 +2142,10 @@ class MaxTextConfig(
       default_factory=RL,
       description="Configuration for RL algorithms like Group Relative Policy Optimization (GRPO).",
   )
+  lora: LoRA = Field(
+      default_factory=LoRA,
+      description="Configuration for LoRA / QLoRA adapters.",
+  )
 
   model_config = ConfigDict(extra="forbid", protected_namespaces=())
 
   @model_validator(mode="before")
diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py
index 3b0de8e0da..f402a1a0e0 100644
--- a/src/maxtext/layers/nnx_decoders.py
+++ b/src/maxtext/layers/nnx_decoders.py
@@ -35,6 +35,7 @@
     MODEL_MODE_TRAIN,
     Config,
     DecoderBlockType,
+    MultimodalInput,
     ShardMode,
 )
 from maxtext.inference import page_manager
@@ -432,54 +433,236 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs):
     """Runs the layer stack using nnx.scan."""
     policy = self.get_remat_policy()
     prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config)
-    graphdef, params, state = nnx.split(
-        layers, nnx.Param, ...
-    )  # state: the mutable state we carry (KV cache, RNGs, etc.)
-
     scan_axis = self.config.param_scan_axis
+
+    graphdef, params, state = nnx.split(layers, nnx.Param, ...)
+
     if scan_axis != 0:
-      # Move scan_axis to 0 so scan can iterate over it
       params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), params)
 
-    layer_cls = layers.__class__
-    sig = inspect.signature(layer_cls.__call__)
+    sig = inspect.signature(layers.__class__.__call__)
     valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}
-    layer_cls = layers.__class__  # Access the underlying class
-    sig = inspect.signature(layer_cls.__call__)
-    # Filter kwargs to only include keys that exist in the layer's signature
-    valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}
+    dynamic_graph_init = bool(getattr(self, "disable_quant_stats_update", False))
+    updated_graphdef = [graphdef]
 
     def layer_fn(carry, scanned_vars):
-      # Unpack the sliced variables for THIS layer
       current_params, current_state = scanned_vars
 
+      def rank_consistent_spec(spec, shape):
+        if spec is None:
+          return None
+        spec_list = list(spec)
+
+        # 1. Remove scanning axes if rank reduction is needed
+        if len(spec_list) > len(shape):
+          for axis_name in ["layers", "stage"]:
+            if axis_name in spec_list:
+              spec_list.remove(axis_name)
+              if len(spec_list) == len(shape):
+                break
+
+        # 2. If still mismatched, strip from the left (standard JAX rank reduction)
+        while len(spec_list) > len(shape):
+          spec_list.pop(0)
+
+        # 3. If rank is too small, pad with None
+        while len(spec_list) < len(shape):
+          spec_list.insert(0, None)
+
+        return jax.sharding.PartitionSpec(*spec_list)
+
+      def fix_node_rank(x):
+        if hasattr(x, "get_metadata") and hasattr(x, "replace") and hasattr(x, "value"):
+          metadata = x.get_metadata()
+          updates = {}
+          for k, axes in metadata.items():
+            if isinstance(axes, (jax.sharding.PartitionSpec, tuple, list)):
+              # Convert tuple/list to spec for check
+              spec_obj = jax.sharding.PartitionSpec(*axes) if isinstance(axes, (tuple, list)) else axes
+              if len(spec_obj) != x.value.ndim:
+                new_spec = rank_consistent_spec(spec_obj, x.value.shape)
+                # Keep original type (tuple vs spec)
+                updates[k] = tuple(new_spec) if isinstance(axes, (tuple, list)) else new_spec
+                # print(f"[DEBUG] Normalizing metadata key '{k}' from rank {len(spec_obj)} to {len(new_spec)}")
+          if updates:
+            return x.replace(**updates)
+        return x
+
+      is_nnx_var = lambda x: hasattr(x, "get_metadata") and hasattr(x, "replace")
+      current_params = jax.tree.map(fix_node_rank, current_params, is_leaf=is_nnx_var)
+      current_state = jax.tree.map(fix_node_rank, current_state, is_leaf=is_nnx_var)
+
       if self.config.parameter_memory_host_offload:
         current_params = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), current_params)
-      # Merge using the SLICED state
       layer = nnx.merge(graphdef, current_params, current_state)
-      # Run the layer (Filter kwargs if using the solution from previous turn)
       layer_out = layer(carry, *args, **valid_kwargs)
-
       new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out
-      # Extract the updated state to return it
-      # _, new_current_state = nnx.split(layer, nnx.Param, ...)
-      new_current_state = nnx.state(layer)
-      return new_carry, new_current_state
+
+      # Extract EVERYTHING to capture new parameters
+      new_graphdef, updated_params, updated_state = nnx.split(layer, nnx.Param, ...)
+
+      if dynamic_graph_init:
+        updated_graphdef[0] = new_graphdef
+        returned_params = updated_params
+      else:
+        returned_params = current_params
+
+      return new_carry, (returned_params, updated_state)
+
+    if dynamic_graph_init:
+      print(f"[DEBUG] Starting Dynamic Graph Init Loop (length={length})")
+      curr_carry = x_in
+      out_params_list = []
+      out_other_list = []
+
+      def _slice_and_unpromote(x, i):
+        # Resolve physical value and shape
+        is_var = hasattr(x, "get_metadata") and hasattr(x, "replace")
+        val = x.value if is_var else x
+
+        if not hasattr(val, "shape") or len(val.shape) == 0 or val.shape[0] != length:
+          return x
+
+        # 1. Slice value
+        sliced_val = val[i]
+
+        # 2. Slice logical metadata if it's an NNX variable
+        if is_var:
+          metadata = x.get_metadata()
+          updates = {}
+          for sharding_key in ["sharding", "out_sharding", "sharding_names"]:
+            axes = metadata.get(sharding_key)
+            if isinstance(axes, jax.sharding.PartitionSpec):
+              spec_list = list(axes)
+
+              # Aggressively reduce rank to match sliced_val.ndim
+              for axis_to_remove in ["layers", "stage"]:
+                if axis_to_remove in spec_list and len(spec_list) > sliced_val.ndim:
+                  spec_list.remove(axis_to_remove)
+
+              while len(spec_list) > sliced_val.ndim:
+                spec_list.pop(0)
+
+              while len(spec_list) < sliced_val.ndim:
+                spec_list.insert(0, None)
+
+              new_spec = jax.sharding.PartitionSpec(*spec_list)
+              updates[sharding_key] = new_spec
+
+          return x.replace(value=sliced_val, **updates)
+
+        return sliced_val
+
+      def _promote_to_scanned(x):
+        """Adds 'layers' axis back to newly created parameters if scanning is enabled."""
+        if not self.config.scan_layers:
+          return x
+
+        is_nnx_leaf = lambda x: hasattr(x, "get_metadata") and hasattr(x, "replace")
+        if is_nnx_leaf(x):
+          metadata = x.get_metadata()
+          updates = {}
+          # Determine which axis to insert 'layers' into based on config
+          scan_axis = self.config.param_scan_axis
+
+          for sharding_key in ["sharding", "out_sharding", "sharding_names"]:
+            axes = metadata.get(sharding_key)
+            if isinstance(axes, jax.sharding.PartitionSpec):
+              spec_list = list(axes)
+              if "layers" not in spec_list:
+                # Insert 'layers' at the correct scan axis position
+                # Cap at current length to avoid index out of bounds
+                insert_pos = min(scan_axis, len(spec_list))
+                spec_list.insert(insert_pos, "layers")
+                updates[sharding_key] = jax.sharding.PartitionSpec(*spec_list)
+
+          if updates:
+            return x.replace(**updates)
+        return x
+
+      for i in range(length):
+        # Slice both values AND logical metadata!
+        is_nnx_leaf = lambda x: hasattr(x, "get_metadata") and hasattr(x, "replace")
+        curr_params = jax.tree.map(lambda x: _slice_and_unpromote(x, i), params, is_leaf=is_nnx_leaf)
+        curr_state = jax.tree.map(lambda x: _slice_and_unpromote(x, i), state, is_leaf=is_nnx_leaf)
+
+        curr_carry, (out_p, out_o) = layer_fn(curr_carry, (curr_params, curr_state))
+
+        # Promote ALL parameters back to rank-3 metadata immediately
+        # This ensures they are ready to be stacked correctly.
+        out_p = jax.tree.map(_promote_to_scanned, out_p, is_leaf=is_nnx_leaf)
+        out_o = jax.tree.map(_promote_to_scanned, out_o, is_leaf=is_nnx_leaf)
+
+        out_params_list.append(out_p)
+        out_other_list.append(out_o)
+
+      final_carry = curr_carry
+      scanned_params = jax.tree.map(lambda *args: jnp.stack(args), *out_params_list)
+      scanned_other = jax.tree.map(lambda *args: jnp.stack(args), *out_other_list)
+
-    layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
+    else:
+      layer_fn_wrapped = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
 
-    final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state))
+      def _ensure_scan_leading_axis(x):
+        if not hasattr(x, "shape") or len(x.shape) == 0:
+          return jnp.broadcast_to(x, (length,))
+        return x
+
+      params = jax.tree.map(_ensure_scan_leading_axis, params)
+      state = jax.tree.map(_ensure_scan_leading_axis, state)
+
+      final_carry, (scanned_params, scanned_other) = jax.lax.scan(layer_fn_wrapped, x_in, (params, state))
 
     if scan_axis != 0:
-      scanned_params, scanned_other = scanned_state.split(nnx.Param, ...)
       scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params)
-    scanned_state = nnx.State.merge(scanned_params, scanned_other)
-    return final_carry, nnx.merge(graphdef, scanned_state)
+
+    scan_axis = self.config.param_scan_axis
+
+    def _force_promote(x):
+      is_nnx_leaf = hasattr(x, "get_metadata") and hasattr(x, "replace")
+      if is_nnx_leaf:
+        metadata = x.get_metadata()
+        updates = {}
+        val_ndim = x.value.ndim
+        for sharding_key in ["sharding", "out_sharding", "sharding_names"]:
+          axes = metadata.get(sharding_key)
+          if isinstance(axes, (jax.sharding.PartitionSpec, tuple, list)):
+            l = list(axes)
+            if len(l) < val_ndim and "layers" not in l:
+              pos = min(scan_axis, len(l))
+              l.insert(pos, "layers")
+              updates[sharding_key] = jax.sharding.PartitionSpec(*l) if isinstance(axes, jax.sharding.PartitionSpec) else tuple(l)
+        if updates:
+          return x.replace(**updates)
+      return x
+
+    is_leaf_with_metadata = lambda x: hasattr(x, "get_metadata") and hasattr(x, "replace")
+    scanned_params = jax.tree.map(_force_promote, scanned_params, is_leaf=is_leaf_with_metadata)
+    scanned_other = jax.tree.map(_force_promote, scanned_other, is_leaf=is_leaf_with_metadata)
+
+    if dynamic_graph_init:
+      # Perform a structural update: merge the new structure with the stacked arrays
+      out_layers = nnx.merge(updated_graphdef[0], scanned_params, scanned_other)
+
+      # We must update the PARENT (self) to point to the new structure.
+      for attr_name, attr_val in self.__dict__.items():
+        if attr_val is layers:
+          setattr(self, attr_name, out_layers)
+          print(f"[DEBUG] Materialization complete: updated self.{attr_name}")
+          break
+
+      # FORCE NNX to recognize new structural changes by splitting/merging the PARENT
+      # This updates the underlying GraphDef for the entire Decoder.
+      g, s = nnx.split(self)
+      new_self = nnx.merge(g, s)
+      nnx.update(self, nnx.state(new_self))
+    else:
+      nnx.update(layers, nnx.State.merge(scanned_params, scanned_other))
+      out_layers = layers
+
+    return final_carry, out_layers
 
   def get_decoder_layers(self):
     """Retrieves decoder layer classes based on config using a dictionary lookup."""
@@ -904,8 +1087,17 @@ def __call__(
       audio_embeddings: None | jnp.ndarray = None,
       audio_masks: None | jnp.ndarray = None,
       deepstack_visual_embeds: None | list[jnp.ndarray] = None,
+      multimodal_input: MultimodalInput | None = None,
   ):
     cfg = self.config
+
+    if multimodal_input is not None:
+      image_embeddings = image_embeddings or multimodal_input.image_embeddings
+      image_masks = image_masks or multimodal_input.image_masks
+      audio_embeddings = audio_embeddings or multimodal_input.audio_embeddings
+      audio_masks = audio_masks or multimodal_input.audio_masks
+      bidirectional_mask = bidirectional_mask or multimodal_input.bidirectional_mask
+
     assert decoder_input_tokens.ndim == 2  # [batch, len]
     policy = self.get_remat_policy()
diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py
index 90595a05fd..05c43166b8 100644
--- a/src/maxtext/trainers/post_train/sft/train_sft.py
+++ b/src/maxtext/trainers/post_train/sft/train_sft.py
@@ -49,6 +49,7 @@
 from tunix.sft import metrics_logger, peft_trainer, profiler
 
+from maxtext.optimizers import optimizers
 from maxtext.configs import pyconfig
 from maxtext.trainers.pre_train.train import loss_fn
 from maxtext.common.goodput import (
@@ -60,8 +61,8 @@
     maybe_record_goodput,
     record_goodput,
 )
-from maxtext.optimizers import optimizers
 from maxtext.trainers.post_train.sft import hooks
+from maxtext.utils import lora_utils
 from maxtext.utils import max_utils
 from maxtext.utils import max_logging
 from maxtext.utils import maxtext_utils
@@ -126,7 +127,15 @@ def use_maxtext_loss_function(trainer, mt_config):
     The trainer configured with the MaxText loss function.
   """
 
-  def loss_func(model, inputs, inputs_position, inputs_segmentation, targets, targets_position, targets_segmentation):
+  def loss_func(
+      model,
+      inputs,
+      inputs_position,
+      inputs_segmentation,
+      targets,
+      targets_position,
+      targets_segmentation,
+  ):
     data = {
         "inputs": inputs,
         "inputs_position": inputs_position,
@@ -147,6 +156,8 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
   with maybe_record_goodput(goodput_recorder, GoodputEvent.TPU_INIT):
     model, mesh = model_creation_utils.create_nnx_model(mt_config)
+    if mt_config.lora.enable_lora:
+      model = lora_utils.apply_lora_to_model(model, mesh, mt_config)
 
   learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(mt_config)
   # pass in model for muon
   optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule, model)
@@ -162,6 +173,9 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
   data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder)
 
   trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config)
+  if mt_config.lora.lora_restore_path:
+    trainer = lora_utils.restore_lora_from_path(trainer, mt_config)
+
   trainer.with_training_hooks(training_hooks)
   trainer.with_data_hooks(data_hooks)
   trainer = use_maxtext_loss_function(trainer, mt_config)
@@ -172,7 +186,10 @@ def train_model(mt_config, trainer, mesh):
   """Runs the SFT training loop in Tunix."""
   with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
-    trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator)
+    trainer.train(
+        trainer.data_hooks.train_data_iterator,
+        trainer.data_hooks.eval_data_iterator,
+    )
   return trainer
@@ -203,7 +220,7 @@ def main(argv: Sequence[str]) -> None:
     pathwaysutils.initialize()
   os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
-  mt_config = pyconfig.initialize(argv)
+  mt_config = pyconfig.initialize_pydantic(argv)
   max_utils.print_system_information()
 
   goodput_recorder = create_goodput_recorder(mt_config)
diff --git a/src/maxtext/utils/lora_utils.py b/src/maxtext/utils/lora_utils.py
index 24099ef22a..9536c26d05 100644
--- a/src/maxtext/utils/lora_utils.py
+++ b/src/maxtext/utils/lora_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2023–2025 Google LLC
+# Copyright 2023–2026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,21 +13,26 @@
 # limitations under the License.
 
 """ Common LoRA utils needed to support LoRA adapters."""
-
 from functools import partial
 import json
+import os
+import re
+from typing import Any, Optional
 
+from flax import nnx
+from flax.linen import partitioning as nn_partitioning
+from flax.training import train_state
 import jax
 import jax.numpy as jnp
+import omegaconf
+from orbax import checkpoint as ocp
+import qwix
 
-from flax.training import train_state
-from flax.linen import partitioning as nn_partitioning
-
-from maxtext.common import checkpointing
+from maxtext.configs import pyconfig
 from maxtext.utils import gcs_utils
+from maxtext.utils import max_logging
 from maxtext.utils import max_utils
 from maxtext.utils import maxtext_utils
-from maxtext.utils import max_logging
 
 
 def apply_lora_on_base_params(base_params, lora_params, lora_scale_factor=1.0):
@@ -243,7 +248,11 @@ def get_lora_param_shape(base_array_shape, lora_rank, lora_module):
         f"Encountered unexpected shape={base_array_shape} of array in base params. Array dimensions > 4 not supported."
     )
 
-  if lora_module in ["self_attention.query", "self_attention.key", "self_attention.value"]:
+  if lora_module in [
+      "self_attention.query",
+      "self_attention.key",
+      "self_attention.value",
+  ]:
     lora_a_shape = base_array_shape[:-2] + (lora_rank,)
     lora_b_shape = (lora_rank,) + base_array_shape[1:]
   elif lora_module in ["self_attention.out"]:
@@ -270,7 +279,11 @@ def get_lora_param_sharding(base_param_sharding, lora_module):
   base_memory_kind = base_param_sharding.memory_kind
   base_pspec = base_param_sharding.spec
 
-  if lora_module in ["self_attention.query", "self_attention.key", "self_attention.value"]:
+  if lora_module in [
+      "self_attention.query",
+      "self_attention.key",
+      "self_attention.value",
+  ]:
     lora_a_pspec_tuple = base_pspec[:-2] + ((),)
     lora_a_pspec = jax.sharding.PartitionSpec(*lora_a_pspec_tuple)
@@ -311,7 +324,13 @@ def add_lora_params(lora_params, module_name, base_params, lora_rank, lora_target_modules):
   for name, param in base_params.items():
     if isinstance(param, dict):
       lora_params[name] = {}
-      add_lora_params(lora_params[name], f"{module_name}.{name}", param, lora_rank, lora_target_modules)
+      add_lora_params(
+          lora_params[name],
+          f"{module_name}.{name}",
+          param,
+          lora_rank,
+          lora_target_modules,
+      )
     else:
       if name not in ["kernel", "scale", "embedding"]:
         raise ValueError(f"Unexpected key={name} exists in the abstract params of base model.")
@@ -349,3 +368,231 @@ def get_lora_annotations(lora_abstract_params):
   )
 
   return unboxed_abstract_lora_state, lora_state_mesh_annotations
+
+
+# --- Qwix LoRA Utils ---
+
+
+def _get_lora_module_path(mt_config: pyconfig.HyperParameters) -> str:
+  """Gets the regex for modules to apply LoRA on based on the model name."""
+  lora_cfg = getattr(mt_config, "lora", mt_config)
+  if getattr(lora_cfg, "lora_module_path", ""):
+    return lora_cfg.lora_module_path
+
+  config_path = os.path.join(
+      os.path.dirname(os.path.dirname(__file__)),
+      "configs",
+      "post_train",
+      "lora_module_path.yml",
+  )
+  lora_configs = omegaconf.OmegaConf.load(config_path)
+  model_name = mt_config.model_name.lower()
+
+  for key, module_path in lora_configs.items():
+    if key != "default" and model_name.startswith(key):
+      # Make the layer index optional to support both scanned and non-scanned paths
+      # e.g., 'decoder/layers/0/mlp' vs 'decoder/layers/mlp'
+      flexible_path = str(module_path).replace("layers/", "layers/(?:[0-9]+/)?")
+      max_logging.log(f"Auto-detected lora_module_path for model '{model_name}': {flexible_path}")
+      return flexible_path
+
+  default_path = lora_configs.get(
+      "default",
+      "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))",
+  )
+  flexible_default = str(default_path).replace("layers/", "layers/(?:[0-9]+/)?")
+  max_logging.log(
+      f"Warning: Model '{model_name}' is not in the list of verified LoRA models. "
+      "Auto-detection might not work. Please provide an explicit `lora_module_path` in your config if training fails."
+  )
+  max_logging.log(f"Falling back to flexible default lora_module_path: {flexible_default}")
+  return flexible_default
+
+
+def _build_lora_provider(mt_config: pyconfig.HyperParameters) -> qwix.LoraProvider:
+  """Builds a Qwix LoRA provider from MaxText LoRA settings."""
+  lora_cfg = getattr(mt_config, "lora", mt_config)
+  lora_module_path = _get_lora_module_path(mt_config)
+  lora_kwargs = {
+      "module_path": lora_module_path,
+      "rank": lora_cfg.lora_rank,
+      "alpha": lora_cfg.lora_alpha,
+      "dropout": 0.0,
+  }
+  if lora_cfg.lora_tile_size is not None:
+    lora_kwargs["tile_size"] = lora_cfg.lora_tile_size
+  if lora_cfg.lora_weight_qtype is not None:
+    # TODO(jackyf): QLoRA is currently not working.
+    lora_kwargs["weight_qtype"] = lora_cfg.lora_weight_qtype
+    max_logging.log(
+        f"QLoRA configured: module_path={lora_module_path} "
+        f"rank={lora_cfg.lora_rank} alpha={lora_cfg.lora_alpha} "
+        f"weight_qtype={lora_cfg.lora_weight_qtype} "
+        f"tile_size={lora_cfg.lora_tile_size}"
+    )
+  else:
+    max_logging.log(
+        f"LoRA configured: module_path={lora_module_path} "
+        f"rank={lora_cfg.lora_rank} alpha={lora_cfg.lora_alpha} "
+        f"tile_size={lora_cfg.lora_tile_size}"
+    )
+
+  return qwix.LoraProvider(**lora_kwargs)
+
+
+def _prepare_dummy_inputs() -> tuple[jnp.ndarray, jnp.ndarray]:
+  """Builds dummy decoder inputs used to materialize LoRA parameters."""
+  # Keep LoRA warmup as small as possible to minimize compile/memory overhead.
+  dummy_bs = 1
+  seq_len = 1
+  decoder_input_tokens = jnp.zeros((dummy_bs, seq_len), dtype=jnp.int32)
+  decoder_positions = jnp.zeros((dummy_bs, seq_len), dtype=jnp.int32)
+  return decoder_input_tokens, decoder_positions
+
+
+def is_lora_enabled(model: nnx.Module) -> bool:
+  """Checks if the model has LoRA parameters."""
+  for _, value in nnx.iter_graph(model):
+    if isinstance(value, nnx.LoRAParam):
+      return True
+  return False
+
+
+def _verify_lora_parameters(lora_model: nnx.Module, mt_config: pyconfig.HyperParameters):
+  """Validates that LoRA is active or that target modules were matched."""
+
+  if is_lora_enabled(lora_model):
+    return
+
+  lora_module_path = _get_lora_module_path(mt_config)
+  compiled_module_path = re.compile(lora_module_path)
+  matched_module_paths = []
+  sample_module_paths = []
+
+  for path, _ in nnx.iter_graph(lora_model):
+    module_path = "/".join(str(p) for p in path)
+    if len(sample_module_paths) < 100:
+      sample_module_paths.append(module_path)
+    if compiled_module_path.search(module_path):
+      matched_module_paths.append(module_path)
+
+  if not matched_module_paths:
+    max_logging.log(
+        f"LoRA module_path='{lora_module_path}' did not match any weights. "
+        f"Sample module paths: {sample_module_paths}"
+    )
+    raise ValueError("LoRA enabled but no LoRA parameters found in decoder/model state.")
+
+  raise ValueError(
+      "LoRA module path matched target modules, but nnx.LoRAParam is still "
+      "missing. For Tunix PeftTrainer, LoRA params must be materialized before "
+      "trainer initialization, otherwise it falls back to full-model training. "
+      f"Sample matches: {matched_module_paths[:10]}"
+  )
+
+
+def _patch_qwix_for_maxtext(mesh, mt_config):
+  import qwix._src.flax_util as flax_util
+  import qwix._src.providers.ptq as ptq
+  import jax.numpy as jnp
+  from flax import nnx
+
+  # 1. PTQ patch
+  original_get_intercept_map = ptq.PtqProvider.get_intercept_map
+
+  def patched_get_intercept_map(self):
+    mapping = original_get_intercept_map(self)
+
+    def intercept_asarray(a, dtype=None, order=None, **kwargs):
+      if isinstance(a, nnx.State) and 'array' in a:
+        a = a['array']
+      if isinstance(a, nnx.State) and 'qvalue' in a and 'scale' in a:
+        a = ptq.QArray(qvalue=a['qvalue'].value, scale=a['scale'].value)
+
+      if type(a).__name__ in ("WithAux", "QArray"):
+        return a
+      return jnp.asarray(a, dtype=dtype, order=order, **kwargs)
+
+    mapping["jax.numpy.asarray"] = intercept_asarray
+    return mapping
+
+  ptq.PtqProvider.get_intercept_map = patched_get_intercept_map
+
+
+def apply_lora_to_model(
+    model: nnx.Module,
+    mesh: Optional[jax.sharding.Mesh],
+    mt_config: pyconfig.HyperParameters,
+) -> nnx.Module:
+  """Optionally applies LoRA/QLoRA to a MaxText model using Qwix."""
+  lora_cfg = getattr(mt_config, "lora", mt_config)
+  # Skip Qwix LoRA if MaxText LoRA adapters are loaded
+  if getattr(mt_config, "lora_input_adapters_path", None):
+    max_logging.log("MaxText LoRA adapters loaded, skipping Qwix LoRA application")
+    return model
+
+  if not getattr(lora_cfg, "enable_lora", False):
+    return model
+
+  _patch_qwix_for_maxtext(mesh, mt_config)
+
+  lora_provider = _build_lora_provider(mt_config)
+
+  model_rngs = getattr(model.decoder, "rngs", None)
+  decoder_input_tokens, decoder_positions = _prepare_dummy_inputs()
+
+  lora_model = qwix.apply_lora_to_model(
+      model,
+      lora_provider,
+      decoder_input_tokens=decoder_input_tokens,
+      decoder_positions=decoder_positions,
+      rngs=model_rngs,
+  )
+
+  if mesh is not None:
+    with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
+      graph_def, state = nnx.split(lora_model)
+      default_memory_kind = jax.devices()[0].default_memory().kind
+      dst_shardings = jax.tree.map(
+          lambda x: jax.sharding.NamedSharding(mesh, x, memory_kind=default_memory_kind) if x is not None else None,
+          nnx.get_partition_spec(state),
+      )
+      from tunix.rl import reshard  # pylint: disable=import-outside-toplevel
+
+      state = reshard.reshard_pytree(state, dst_shardings)
+      lora_model = nnx.merge(graph_def, state)
+
+  _verify_lora_parameters(lora_model, mt_config)
+
+  return lora_model
+
+
+def restore_lora_from_path(trainer: Any, mt_config: pyconfig.HyperParameters) -> Any:
+  """Restores LoRA params from an external checkpoint item path.
+
+  This method assumes mt_config.lora.lora_restore_path is set and that the run
+  is a fresh start (trainer.train_steps == 0).
+  """
+  lora_restore_path = mt_config.lora.lora_restore_path
+
+  train_steps = getattr(trainer, "train_steps", 0)
+  if train_steps > 0:
+    max_logging.log(
+        f"PeftTrainer restored current run at step {train_steps}; "
+        f"ignoring lora_restore_path '{lora_restore_path}'."
+    )
+    return trainer
+
+  if not is_lora_enabled(trainer.model):
+    lora_module_path = _get_lora_module_path(mt_config)
+    raise ValueError(
+        "lora_restore_path is set but LoRA is not enabled on the model. "
+        f"Set lora.enable_lora=True and verify lora_module_path ('{lora_module_path}') matches model modules."
+    )
+
+  abstract_lora_params = nnx.state(trainer.model, nnx.LoRAParam)
+  restored_lora_params = ocp.StandardCheckpointer().restore(
+      lora_restore_path,
+      target=abstract_lora_params,
+  )
+  nnx.update(trainer.model, restored_lora_params)
+  max_logging.log(f"LoRA restore complete from '{lora_restore_path}'. " "Trainer step remains at 0 for this run.")
+  return trainer
diff --git a/tests/post_training/unit/lora_utils_test.py b/tests/post_training/unit/lora_utils_test.py
new file mode 100644
index 0000000000..af3ff1978f
--- /dev/null
+++ b/tests/post_training/unit/lora_utils_test.py
@@ -0,0 +1,231 @@
+# Copyright 2023–2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for Qwix LoRA utils in lora_utils.py"""
+import sys
+import unittest
+from unittest import mock
+import jax
+import optax
+import pytest
+from flax import nnx
+
+# Run this module only as part of the post-training test suite.
+pytestmark = [pytest.mark.post_training]
+
+# Post-training imports (tunix, maxtext utils) exercised by the tests below.
+from tunix.sft import peft_trainer
+from maxtext.utils import lora_utils
+from maxtext.utils import model_creation_utils
+from maxtext.configs import pyconfig
+from tests.utils.test_helpers import get_decoupled_parallelism_overrides, get_test_config_path  # pylint: disable=no-name-in-module
+
+# ---------------------------------------------------------------------------
+# Shared minimal config overrides
+# ---------------------------------------------------------------------------
+_BASE_CONFIG = {
+    "per_device_batch_size": 1.0,
+    "run_name": "lora_utils_test",
+    "enable_checkpointing": False,
+    "base_num_decoder_layers": 1,
+    "attention": "dot_product",
+    "max_target_length": 8,
+    "base_emb_dim": 128,
+    "base_num_query_heads": 2,
+    "base_num_kv_heads": 2,
+    "base_mlp_dim": 256,
+    "max_prefill_predict_length": 4,
+    "model_name": "llama2-7b",
+    "enable_nnx": True,
+    "pure_nnx_decoder": True,
+}
+
+
+def _make_config(**overrides):
+  """Return a MaxTextConfig object suitable for unit tests."""
+  extra_args = get_decoupled_parallelism_overrides()
+  # Use initialize_pydantic to get nested models as objects (attribute access)
+  return pyconfig.initialize_pydantic(
+      [sys.argv[0], get_test_config_path()],
+      **_BASE_CONFIG,
+      **extra_args,
+      **overrides,
+  )
+
+
+class LoraUtilsTest(unittest.TestCase):
+  """Tests for lora_utils.py (Qwix LoRA Utils)"""
+
+  # pylint: disable=protected-access
+
+  def test_get_lora_module_path(self):
+    """Test retrieving LoRA module path from config."""
+    mock_config = mock.MagicMock(spec=pyconfig.HyperParameters)
+    mock_config.lora = mock.MagicMock()
+    mock_config.lora.lora_module_path = ""
+
+    mock_config.model_name = "llama3.1-8b"
+    path = lora_utils._get_lora_module_path(mock_config)
+    self.assertEqual(
+        path,
+        "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))",
+    )
+
+    mock_config.model_name = "unknown_model"
+    # Ensure lora.lora_module_path is still empty string to trigger fallback
+    mock_config.lora.lora_module_path = ""
+    path = lora_utils._get_lora_module_path(mock_config)
+    # Fallback to default
+    self.assertEqual(
+        path,
+        "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))",
+    )
+
+    mock_config.lora.lora_module_path = "custom/path"
+    path = lora_utils._get_lora_module_path(mock_config)
+    self.assertEqual(path, "custom/path")
+
+  def test_build_lora_provider(self):
+    """Test building Qwix LoraProvider from config."""
+    mock_config = mock.MagicMock(spec=pyconfig.HyperParameters)
+    mock_config.lora = mock.MagicMock()
+    mock_config.lora.lora_module_path = "custom/path"
+    mock_config.lora.lora_rank = 8
+    mock_config.lora.lora_alpha = 16.0
+    mock_config.lora.lora_tile_size = None
+    mock_config.lora.lora_weight_qtype = None
+
+    with mock.patch("qwix.LoraProvider") as mock_provider:
+      lora_utils._build_lora_provider(mock_config)
+      mock_provider.assert_called_once_with(module_path="custom/path", rank=8, alpha=16.0, dropout=0.0)
+
+  def test_prepare_dummy_inputs(self):
+    """Test preparation of dummy inputs for LoRA verification."""
+    tokens, positions = lora_utils._prepare_dummy_inputs()
+    self.assertEqual(tokens.shape, (1, 1))
+    self.assertEqual(positions.shape, (1, 1))
+
+  def test_verify_lora_parameters_enabled(self):
+    """Test verification of LoRA parameters when enabled."""
+    mock_model = mock.MagicMock()
+    mock_config = mock.MagicMock(spec=pyconfig.HyperParameters)
+
+    # Patch is_lora_enabled so the check reports LoRA as active.
+    with mock.patch("maxtext.utils.lora_utils.is_lora_enabled", return_value=True):
+      # Should not raise
+      lora_utils._verify_lora_parameters(mock_model, mock_config)
+
+  def test_verify_lora_parameters_not_enabled_no_match(self):
+    """Test verification fails when LoRA parameters are expected but not found."""
+    mock_model = mock.MagicMock()
+    mock_config = mock.MagicMock(spec=pyconfig.HyperParameters)
+    mock_config.lora = mock.MagicMock()
+    mock_config.model_name = "llama"
+    mock_config.lora.lora_module_path = "non_existent"
+
+    with mock.patch("maxtext.utils.lora_utils.is_lora_enabled", return_value=False):
+      mock_model.iter_modules.return_value = []
+      with self.assertRaisesRegex(ValueError, "no LoRA parameters found"):
+        lora_utils._verify_lora_parameters(mock_model, mock_config)
+
+  def test_apply_lora_to_model_disabled(self):
+    """Test applying LoRA when it is disabled in config."""
+    cfg = _make_config(lora={"enable_lora": False})
+    model, _ = model_creation_utils.create_nnx_model(cfg, model_mode=model_creation_utils.MODEL_MODE_TRAIN)
+    # Pydantic MaxTextConfig supports direct attribute access
+    self.assertFalse(cfg.lora.enable_lora)
+    result = lora_utils.apply_lora_to_model(model, None, cfg)
+    self.assertEqual(result, model)
+    self.assertFalse(lora_utils.is_lora_enabled(result))
+
+  def test_apply_lora_to_model_adapters_loaded(self):
+    """Test applying LoRA when adapters are already provided."""
+    cfg = _make_config(lora_input_adapters_path="some/path")
+    model, _ = model_creation_utils.create_nnx_model(cfg, model_mode=model_creation_utils.MODEL_MODE_TRAIN)
+    result = lora_utils.apply_lora_to_model(model, None, cfg)
+    self.assertEqual(result, model)
+    # is_lora_enabled checks for LoRAParam, which Qwix adds.
+    # If we skip Qwix, it should stay False.
+    self.assertFalse(lora_utils.is_lora_enabled(result))
+
+  def _run_apply_lora_test(self, scan_layers: bool):
+    """Helper to run LoRA application test with/without scanned layers."""
+    # Pass the nested 'lora' dict through _make_config overrides.
+    cfg = _make_config(
+        lora={
+            "enable_lora": True,
+            "lora_rank": 4,
+            "lora_alpha": 8.0,
+            "lora_module_path": ".*mlp/wi_.*",
+        },
+        scan_layers=scan_layers,
+    )
+
+    # Create a real small model using standard creation utils
+    model, _ = model_creation_utils.create_nnx_model(cfg, model_mode=model_creation_utils.MODEL_MODE_TRAIN)
+
+    # Verify the model is NOT LoRA-enabled initially
+    self.assertFalse(lora_utils.is_lora_enabled(model))
+
+    # Apply LoRA
+    lora_model = lora_utils.apply_lora_to_model(model, model.mesh, cfg)
+
+    # Verify we can find LoRAParam in the state
+    _, state = nnx.split(lora_model)
+    lora_params = nnx.filter_state(state, nnx.LoRAParam)
+    self.assertGreater(len(jax.tree_util.tree_leaves(lora_params)), 0)
+
+    # Verify it IS now LoRA-enabled
+    self.assertTrue(lora_utils.is_lora_enabled(lora_model))
+
+    # Verify the LoRA model is compatible with PeftTrainer
+    trainer_cfg = peft_trainer.TrainingConfig(eval_every_n_steps=10)
+    optimizer = optax.adam(1e-4)
+
+    # This instantiation will fail if wrt=nnx.LoRAParam cannot find any matching params
+    trainer = peft_trainer.PeftTrainer(model=lora_model, optimizer=optimizer, training_config=trainer_cfg)
+
+    # Verify the optimizer is indeed targeting LoRAParams
+    opt_state = nnx.state(trainer.optimizer)
+    self.assertGreater(len(jax.tree_util.tree_leaves(opt_state)), 0)
+
+  def test_apply_lora_to_model_scan_layers_false(self):
+    """Test applying LoRA to model with scan_layers=False."""
+    self._run_apply_lora_test(scan_layers=False)
+
+  def test_apply_lora_to_model_scan_layers_true(self):
+    """Test applying LoRA to model with scan_layers=True."""
+    self._run_apply_lora_test(scan_layers=True)
+
+  def test_restore_lora_from_path(self):
+    """Test restoration of LoRA parameters from a path."""
+    cfg = _make_config(lora={"enable_lora": True, "lora_restore_path": "some/path"})
+    model, _ = model_creation_utils.create_nnx_model(cfg, model_mode=model_creation_utils.MODEL_MODE_TRAIN)
+    model = lora_utils.apply_lora_to_model(model, None, cfg)
+
+    trainer = mock.MagicMock()
+    trainer.model = model
+    trainer.train_steps = 0
+
+    restored_state = nnx.state(model, nnx.LoRAParam)
+
+    with mock.patch("orbax.checkpoint.StandardCheckpointer.restore", return_value=restored_state) as mock_restore:
+      with mock.patch("flax.nnx.update") as mock_update:
+        lora_utils.restore_lora_from_path(trainer, cfg)
+        mock_restore.assert_called_once_with("some/path", target=mock.ANY)
+        mock_update.assert_called_once()
+
+
+if __name__ == "__main__":
+  unittest.main()
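
For reviewers less familiar with LoRA: the update these utilities delegate to Qwix can be summarized with plain NumPy. This is a minimal sketch of the low-rank math only, independent of Qwix/Tunix; all names (`lora_forward`, the dimensions, the init scheme) are illustrative, not part of this PR's API.

```python
import numpy as np

rng = np.random.default_rng(0)

# Frozen pre-trained weight (d_in x d_out); LoRA never updates it.
d_in, d_out, rank, alpha = 16, 16, 2, 4.0
W = rng.normal(size=(d_in, d_out))

# Trainable low-rank factors. B starts at zero, so at step 0 the adapted
# layer computes exactly the same output as the frozen layer.
A = rng.normal(size=(d_in, rank)) * 0.01
B = np.zeros((rank, d_out))

def lora_forward(x, W, A, B, alpha, rank):
  # y = x @ W + (alpha / rank) * x @ A @ B
  return x @ W + (alpha / rank) * (x @ A) @ B

x = rng.normal(size=(2, d_in))
# With B == 0, the adapted output matches the frozen model's output.
assert np.allclose(lora_forward(x, W, A, B, alpha, rank), x @ W)

# Parameter savings: rank * (d_in + d_out) trainables vs. d_in * d_out frozen.
print(rank * (d_in + d_out), "trainable vs", d_in * d_out, "frozen")
```

Only `A` and `B` receive gradients, which is why the test above checks that the optimizer state built with `wrt=nnx.LoRAParam` is non-empty while the base weights stay frozen.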