Skip to content
Merged
124 changes: 124 additions & 0 deletions .ci/scripts/test_lora_multimethod.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set -exu
# shellcheck source=/dev/null
source "$(dirname "${BASH_SOURCE[0]}")/utils.sh"

# Build and install the core ExecuTorch runtime libraries from a clean tree.
# Uses the project's "llm-release" CMake workflow preset; artifacts land in
# cmake-out/ (consumed later by the llama_main runner invocations).
cmake_install_executorch_libraries() {
  echo "Installing libexecutorch.a, libextension_module.so, libportable_ops_lib.a"
  # Start from scratch so stale build configuration cannot leak in.
  rm -rf cmake-out
  cmake --workflow llm-release
}

# Build the llama runner binary.
# Fetches the tokenizers submodule first. Uses `git -C` instead of
# pushd/popd so the shell's working directory is never mutated and a
# mid-function failure cannot leave us in the wrong directory
# (ShellCheck SC2164).
cmake_build_llama_runner() {
  echo "Building llama runner"
  echo "Updating tokenizers submodule"
  git -C extension/llm/tokenizers submodule update --init
  make llama-cpu
}

# Remove everything this script downloaded or generated.
# Reads globals: HF_QWEN_PATH, HF_ADAPTER_PATH (must be non-empty).
cleanup_files() {
  echo "Deleting downloaded and generated files"
  # ${VAR:?} aborts if the variable is unset or EMPTY, so an empty path can
  # never expand to "rm -rf /". `--` stops option parsing for safety.
  rm -rf -- "${HF_QWEN_PATH:?}/"
  rm -rf -- "${HF_ADAPTER_PATH:?}/"
  rm -rf -- *.pte
  rm -f -- result*.txt
}

# Download LoRA adapter.
python -m pip install -q huggingface_hub

HF_ADAPTER_REPO="lucylq/qwen3_06B_lora_math"
HF_ADAPTER_PATH=$(
  bash "$(dirname "${BASH_SOURCE[0]}")/download_hf_hub.sh" \
    --model_id "${HF_ADAPTER_REPO}" \
    --files "adapter_config.json" "adapter_model.safetensors"
)

# Download base model (for tokenizer path).
HF_QWEN_PATH=$(python -c "from huggingface_hub import snapshot_download; print(snapshot_download('unsloth/Qwen3-0.6B'))")
echo "Model downloaded to: $HF_QWEN_PATH"

### EXPORT MULTIMETHOD PTE ###
# Set environment variables for OmegaConf interpolation in yaml.
export LORA_ADAPTER_CHECKPOINT="${HF_ADAPTER_PATH}/adapter_model.safetensors"
export LORA_ADAPTER_CONFIG="${HF_ADAPTER_PATH}/adapter_config.json"

# PYTHON_EXECUTABLE is expected to be provided by utils.sh / the CI env.
"$PYTHON_EXECUTABLE" -m extension.llm.export.export_llm \
  --config examples/models/qwen3/config/qwen3_multimethod.yaml

### BUILD LLAMA RUNNER ###
cmake_install_executorch_libraries
cmake_build_llama_runner

# Runner constants. An array keeps each flag a single argument and avoids
# relying on unquoted word-splitting (ShellCheck SC2086).
RUNTIME_ARGS=(
  --tokenizer_path="${HF_QWEN_PATH}/"
  --temperature=0
  --seq_len=100
  --warmup=1
)
PROMPT="<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant"

# Expected outputs.
EXPECTED_LORA_PREFIX="
<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant
To calculate 15% of 80"

EXPECTED_BASE_PREFIX="<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant:
<think>
Okay, so I need to calculate 15% of 80."

# Run one method of the multimethod .pte and check its output prefix.
# Arguments:
#   $1 - test number (for log messages)
#   $2 - method name passed to llama_main (--method_name)
#   $3 - file to capture the runner output in
#   $4 - expected prefix of the output
# On mismatch: cleans up downloaded/generated files and exits 1.
run_method_test() {
  local test_num=$1 method=$2 result_file=$3 expected_prefix=$4
  local now result

  now=$(date +"%H:%M:%S")
  echo "Test ${test_num}: Multimethod ${method}. Starting at ${now}"
  cmake-out/examples/models/llama/llama_main \
    --model_path=multimethod_qwen.pte \
    --method_name="${method}" \
    --prompt="${PROMPT}" \
    "${RUNTIME_ARGS[@]}" > "${result_file}"
  now=$(date +"%H:%M:%S")
  echo "Finished at ${now}"

  result=$(< "${result_file}")
  echo "Expected result prefix: ${expected_prefix}"
  echo "Actual result: ${result}"
  if [[ "${result}" == "${expected_prefix}"* ]]; then
    echo "Test ${test_num} (${method}): Success"
  else
    echo "Test ${test_num} (${method}): Failure"
    cleanup_files
    exit 1
  fi
}

### TEST 1: Run lora_forward method ###
run_method_test 1 lora_forward result_lora.txt "${EXPECTED_LORA_PREFIX}"

### TEST 2: Run base_forward method ###
run_method_test 2 base_forward result_base.txt "${EXPECTED_BASE_PREFIX}"

echo "Multimethod tests passed!"
cleanup_files
1 change: 1 addition & 0 deletions examples/models/llama/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ fbcode_target(_kind = runtime.python_library,
fbcode_target(_kind = runtime.python_library,
name = "export_library",
srcs = [
"convert_weights.py",
"export_llama.py",
"export_llama_lib.py",
"model.py",
Expand Down
138 changes: 137 additions & 1 deletion examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,17 @@
from importlib import resources as _resources
from json import JSONDecodeError
from pathlib import Path
from typing import Callable, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union

import torch
from torch.export import ExportedProgram

from executorch.devtools.backend_debug import print_delegation_info
from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func
from executorch.examples.models.llama.hf_download import (
download_and_convert_hf_checkpoint,
)
from executorch.exir import to_edge_transform_and_lower
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
from executorch.extension.llm.export.builder import DType, LLMEdgeManager
from executorch.extension.llm.export.config.llm_config import LlmConfig
Expand Down Expand Up @@ -844,6 +846,28 @@ def _validate_args(llm_config):
"Shared embedding is only supported with torchao quantization."
)

if llm_config.multimethod.enabled:
if llm_config.base.lora is not None:
raise ValueError(
"Cannot use both base.lora and multimethod.methods. "
"Use multimethod.methods for all LoRA variants."
)
if llm_config.quantization.pt2e_quantize is not None:
raise ValueError(
"PT2E quantization is not supported with multimethod export."
)
if (
llm_config.backend.coreml.enabled
or llm_config.backend.vulkan.enabled
or llm_config.backend.qnn.enabled
or llm_config.backend.mps.enabled
or llm_config.backend.openvino.enabled
):
raise ValueError(
"Multimethod export only supports XNNPACK backend or portable ops"
"Please disable other backends (coreml, vulkan, qnn, mps, openvino)."
Comment thread
lucylq marked this conversation as resolved.
Outdated
)


def _to_edge_and_lower_llama_xnnpack(
builder_exported,
Expand Down Expand Up @@ -1107,9 +1131,121 @@ def _to_edge_and_lower_llama( # noqa: C901
return builder


def _get_xnnpack_partitioners(llm_config: LlmConfig) -> Optional[List]:
    """Return the XNNPACK partitioners implied by the backend config.

    Returns None when the XNNPACK backend is disabled (portable-ops path).
    Otherwise the list always contains the dynamic-quant-only partitioner,
    plus the full partitioner when extended ops are enabled.
    """
    if not llm_config.backend.xnnpack.enabled:
        return None

    parts = [get_xnnpack_partitioner(dynamic_quant_only_partitioner=True)]
    if llm_config.backend.xnnpack.extended_ops:
        parts.append(get_xnnpack_partitioner(dynamic_quant_only_partitioner=False))
    return parts


def _get_output_filename(llm_config: LlmConfig, modelname: str, output_dir: str, dtype: DType) -> str:
    """Compute the path of the emitted .pte file.

    An explicit ``export.output_name`` wins (returned as-is when it already
    ends in ".pte", otherwise placed under ``output_dir`` with the suffix
    appended). Without one, the model name is used, suffixed with "_h" for
    fp16 exports.
    """
    base = f"{modelname}_h" if dtype == DType.fp16 else modelname

    configured = llm_config.export.output_name
    if not configured:
        return f"{output_dir}/{base}.pte"
    if configured.endswith(".pte"):
        return configured
    return f"{output_dir}/{configured}.pte"


def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager:
    """
    Export multiple methods (base + LoRA variants) into a single .pte file.

    Each entry of ``llm_config.multimethod.methods`` maps a method name to an
    optional LoRA config: ``None`` exports the plain base model, a LoraConfig
    exports the model with those adapter weights attached.

    Limitations:
    - Only the XNNPACK backend (or portable ops) is supported.
    - PT2E quantization is not supported.
    - Methods are exported one at a time, so export time grows linearly with
      the number of methods.
    - Shared weights are deduplicated automatically in the final .pte.
    """
    methods = llm_config.multimethod.methods
    logging.info(
        f"Multimethod export: exporting {len(methods)} method(s). "
        "Each method requires separate model instantiation and export."
    )

    # TorchTune-defined models need their KV-cache position buffer
    # initialized via a dedicated pass.
    additional_passes = (
        [InitializedMutableBufferPass(["kv_cache_pos"])]
        if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS
        else []
    )

    # One exported program per method name.
    programs: Dict[str, ExportedProgram] = {}
    # Builder of the first exported method; reused for lowering/serialization.
    anchor_builder = None

    for name, lora_cfg in methods.items():
        logging.info(f"Exporting method: {name}")

        # Clone the config, pin this method's LoRA setting, and clear the
        # multimethod table so the per-method export cannot recurse back here.
        cfg = copy.deepcopy(llm_config)
        cfg.base.lora = lora_cfg
        cfg.multimethod.methods = {}

        # Load, export and optimize the model for this method.
        builder = _prepare_for_llama_export(cfg).export()
        builder.run_canonical_optimizations()
        programs[name] = builder._export(builder.pre_autograd_graph_module)

        if anchor_builder is None:
            anchor_builder = builder

    assert anchor_builder is not None, "No methods to export"

    # Lower all methods together with the multimethod API so shared weights
    # are deduplicated across methods.
    edge_manager = to_edge_transform_and_lower(
        programs,
        partitioner=_get_xnnpack_partitioners(llm_config),
        compile_config=anchor_builder._get_edge_config(),
        constant_methods=anchor_builder.metadata,
    )

    # Convert to executorch and serialize the combined program.
    anchor_builder.edge_manager = edge_manager
    anchor_builder = anchor_builder.to_executorch(passes=additional_passes)

    anchor_builder.save_to_pte(
        _get_output_filename(
            llm_config,
            anchor_builder.modelname,
            anchor_builder.output_dir,
            anchor_builder.dtype,
        )
    )

    return anchor_builder


def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
_validate_args(llm_config)

# Check for multimethod export
if llm_config.multimethod.enabled:
return _export_llama_multimethod(llm_config)

pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
llm_config
)
Expand Down
1 change: 1 addition & 0 deletions examples/models/llama/runner/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def define_common_targets():
"//executorch/examples/models/llama/tokenizer:tiktoken",
"//pytorch/tokenizers:llama2c_tokenizer",
"//pytorch/tokenizers:hf_tokenizer",
"//pytorch/tokenizers:regex_lookahead",
] + (_get_operator_lib(aten)) + ([
# Vulkan API currently cannot build on some platforms (e.g. Apple, FBCODE)
# Therefore enable it explicitly for now to avoid failing tests
Expand Down
28 changes: 28 additions & 0 deletions examples/models/qwen3/config/qwen3_multimethod.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Export config for a multimethod Qwen3-0.6B .pte containing a LoRA forward
# method and a base (no-adapter) forward method.
# Consumed by: python -m extension.llm.export.export_llm --config <this file>
# Requires env vars LORA_ADAPTER_CHECKPOINT and LORA_ADAPTER_CONFIG
# (resolved below via OmegaConf ${oc.env:...} interpolation).
base:
  model_class: "qwen3_0_6b"
  params: "examples/models/qwen3/config/0_6b_config.json"
  metadata: '{"get_bos_id": 151644, "get_eos_ids":[151645]}'

model:
  use_kv_cache: true
  use_sdpa_with_kv_cache: true

export:
  # Produces multimethod_qwen.pte.
  output_name: multimethod_qwen

backend:
  xnnpack:
    enabled: true

quantization:
  # 8-bit dynamic activations, 4-bit grouped weights.
  qmode: "8da4w"
  group_size: 32

multimethod:
  methods:
    # LoRA method - adapter paths from environment variables
    lora_forward:
      adapter_checkpoint: ${oc.env:LORA_ADAPTER_CHECKPOINT}
      adapter_config: ${oc.env:LORA_ADAPTER_CONFIG}
    # Base method - no LoRA
    base_forward: null
Loading