Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,15 +897,15 @@ def _validate_args(llm_config):
"Shared embedding is only supported with torchao quantization."
)

if llm_config.multimethod_lora.enabled:
if llm_config.multimethod.enabled:
if llm_config.base.lora_config is not None:
raise ValueError(
"Cannot use both base.lora_config and multimethod_lora.methods. "
"Use multimethod_lora.methods for all LoRA variants."
"Cannot use both base.lora_config and multimethod.methods. "
"Use multimethod.methods for all LoRA variants."
)
if llm_config.quantization.pt2e_quantize is not None:
raise ValueError(
"PT2E quantization is not supported with multimethod_lora export."
"PT2E quantization is not supported with multimethod export."
)
if (
llm_config.backend.coreml.enabled
Expand All @@ -915,7 +915,7 @@ def _validate_args(llm_config):
or llm_config.backend.openvino.enabled
):
raise ValueError(
"multimethod_lora export only supports XNNPACK backend or portable ops"
"multimethod export only supports XNNPACK backend or portable ops. "
"Please disable other backends (coreml, vulkan, qnn, mps, openvino)."
)

Expand Down Expand Up @@ -1230,7 +1230,7 @@ def _to_edge_and_lower_llama( # noqa: C901


def _get_xnnpack_partitioners(llm_config: LlmConfig) -> Optional[List[Partitioner]]:
"""Get XNNPACK partitioners for multimethod_lora export."""
"""Get XNNPACK partitioners for multimethod export."""
partitioners = []

# Order matters here, dynamic quantization should be applied first when
Expand Down Expand Up @@ -1268,20 +1268,20 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager:
"""
Export multiple methods (base + LoRA variants) to a single .pte file.

For each method in llm_config.multimethod_lora.methods:
For each method in llm_config.multimethod.methods:
- If LoraConfig is None: use base model
- If LoraConfig is provided: create model with LoRA weights

Limitations:
- Only XNNPACK backend is supported for multimethod_lora export.
- Only XNNPACK backend is supported for multimethod export.
- PT2E quantization is not supported.
- Each method is exported separately; export time scales linearly
with the number of methods.
- The final .pte file deduplicates shared weights automatically.
"""
num_methods = len(llm_config.multimethod_lora.methods)
num_methods = len(llm_config.multimethod.methods)
logging.info(
f"multimethod_lora export: exporting {num_methods} method(s). "
f"multimethod export: exporting {num_methods} method(s). "
"Each method requires separate model instantiation and export."
)

Expand All @@ -1293,14 +1293,14 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager:
method_to_program: Dict[str, ExportedProgram] = {}
first_builder = None

for method_name, lora_config in llm_config.multimethod_lora.methods.items():
logging.info(f"Exporting method: {method_name}")
for method in llm_config.multimethod.methods:
logging.info(f"Exporting method: {method.method_name}")

# Create a copy of config with this method's LoRA setting
method_config = copy.deepcopy(llm_config)
method_config.base.lora_config = lora_config
# Disable multimethod_lora to avoid infinite recursion
method_config.multimethod_lora.methods = {}
method_config.base.lora_config = method.lora_config
# Disable multimethod to avoid infinite recursion
method_config.multimethod.methods = []

# Load and prepare model for this method
builder = _prepare_for_llama_export(method_config)
Expand All @@ -1309,7 +1309,7 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager:

# Get the exported program
exported_program = builder._export(builder.pre_autograd_graph_module)
method_to_program[method_name] = exported_program
method_to_program[method.method_name] = exported_program
Comment on lines 1309 to +1312
Copy link

Copilot AI Mar 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With multimethod.methods now being a list, duplicate method_name entries are possible. The current code will silently overwrite earlier entries in method_to_program[...], but num_methods/logging still counts duplicates, which can lead to confusing or incorrect exports. Please validate that all method.method_name values are unique (and non-empty) before populating method_to_program, and raise a clear error if duplicates are found.

Copilot uses AI. Check for mistakes.

if first_builder is None:
first_builder = builder
Expand All @@ -1319,7 +1319,7 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager:
# Get partitioners based on backend config
partitioners = _get_xnnpack_partitioners(llm_config)

# Lower all methods together using multimethod_lora API
# Lower all methods together using multimethod API
edge_config = first_builder._get_edge_config()
edge_manager = to_edge_transform_and_lower(
method_to_program,
Expand All @@ -1333,7 +1333,7 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager:
first_builder.edge_manager = edge_manager
first_builder = first_builder.to_executorch(
passes=additional_passes,
share_mutable_buffers=llm_config.multimethod_lora.share_mutable_buffers,
share_mutable_buffers=llm_config.multimethod.share_mutable_buffers,
)

output_file = _get_output_filename(
Expand All @@ -1350,8 +1350,8 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager:
def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
_validate_args(llm_config)

# Check for multimethod_lora export
if llm_config.multimethod_lora.enabled:
# Check for multimethod export
if llm_config.multimethod.enabled:
return _export_llama_multimethod(llm_config)

pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
Expand Down
11 changes: 6 additions & 5 deletions examples/models/qwen3/config/qwen3_multimethod.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ quantization:
qmode: "8da4w"
group_size: 32

multimethod_lora:
multimethod:
methods:
# LoRA method - adapter paths from environment variables
lora_forward:
adapter_checkpoint: ${oc.env:LORA_ADAPTER_CHECKPOINT}
adapter_config: ${oc.env:LORA_ADAPTER_CONFIG}
- method_name: lora_forward
lora_config:
adapter_checkpoint: ${oc.env:LORA_ADAPTER_CHECKPOINT}
adapter_config: ${oc.env:LORA_ADAPTER_CONFIG}
# Base method - no LoRA
base_forward: null
- method_name: base_forward
share_mutable_buffers: True
39 changes: 26 additions & 13 deletions extension/llm/export/config/llm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import ClassVar, Dict, List, Optional

Check warning on line 25 in extension/llm/export/config/llm_config.py

View workflow job for this annotation

GitHub Actions / lintrunner

FLAKE8 F401

'typing.Dict' imported but unused See https://www.flake8rules.com/rules/F401.html.


################################################################################
Expand Down Expand Up @@ -293,37 +293,52 @@


################################################################################
############################## MultimethodLoraConfig ###########################
############################## MultimethodConfig ###########################
################################################################################


@dataclass
class MultimethodLoraConfig:
class MethodConfig:
"""Configuration for exporting a single method to a .pte file.
By default, all other fields fall back to the default configs in
the yaml file.

Attributes:
method_name: Name of the method to export.
lora_config: Optional LoRA configuration.
"""

method_name: str
lora_config: Optional[LoraConfig] = None


@dataclass
class MultimethodConfig:
"""Configuration for exporting multiple methods to a single .pte file.

Maps method names to optional LoRA configurations. A None value means
the method uses base model weights.

Comment on lines 317 to 321
Copy link

Copilot AI Mar 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The MultimethodConfig docstring still says it "Maps method names" even though methods is now a List[MethodConfig]. This is misleading for config authors; please update the docstring wording to reflect the list-based schema (and, if relevant, that each list element’s lora_config=None means base weights).

Copilot uses AI. Check for mistakes.
Attributes:
methods: Dict mapping method names to optional LoRA configs.
Empty dict disables multimethod_lora export.
methods: List of MethodConfig objects with method name and config
for each method.
share_mutable_buffers: Whether to share mutable buffers across methods.
If True, sets all mutable buffers to mem_id=2. Mutable buffers with
the same FQN (fully qualified name) will have the same offset.

Example:
MultimethodLoraConfig(methods={
"forward": None, # base model
"lora_forward": lora_config, # LoRA variant
})
MultimethodConfig(methods=[
MethodConfig("forward", lora_config=None), # base model
MethodConfig("lora_forward", lora_config=lora_config), # LoRA variant
])
"""

methods: Dict[str, Optional[LoraConfig]] = field(default_factory=dict)
methods: List[MethodConfig] = field(default_factory=list)
share_mutable_buffers: bool = False

@property
def enabled(self) -> bool:
"""Returns True if multimethod_lora export is configured."""
"""Returns True if multimethod export is configured."""
return len(self.methods) > 0


Expand Down Expand Up @@ -611,9 +626,7 @@
model: ModelConfig = field(default_factory=ModelConfig)
export: ExportConfig = field(default_factory=ExportConfig)
debug: DebugConfig = field(default_factory=DebugConfig)
multimethod_lora: MultimethodLoraConfig = field(
default_factory=MultimethodLoraConfig
)
multimethod: MultimethodConfig = field(default_factory=MultimethodConfig)
quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
backend: BackendConfig = field(default_factory=BackendConfig)

Expand Down
Loading