diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 504d043811e..7d6371add44 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -897,15 +897,15 @@ def _validate_args(llm_config): "Shared embedding is only supported with torchao quantization." ) - if llm_config.multimethod_lora.enabled: + if llm_config.multimethod.enabled: if llm_config.base.lora_config is not None: raise ValueError( - "Cannot use both base.lora_config and multimethod_lora.methods. " - "Use multimethod_lora.methods for all LoRA variants." + "Cannot use both base.lora_config and multimethod.methods. " + "Use multimethod.methods for all LoRA variants." ) if llm_config.quantization.pt2e_quantize is not None: raise ValueError( - "PT2E quantization is not supported with multimethod_lora export." + "PT2E quantization is not supported with multimethod export." ) if ( llm_config.backend.coreml.enabled @@ -915,7 +915,7 @@ def _validate_args(llm_config): or llm_config.backend.openvino.enabled ): raise ValueError( - "multimethod_lora export only supports XNNPACK backend or portable ops" + "multimethod export only supports XNNPACK backend or portable ops. " "Please disable other backends (coreml, vulkan, qnn, mps, openvino)." ) @@ -1230,7 +1230,7 @@ def _to_edge_and_lower_llama( # noqa: C901 def _get_xnnpack_partitioners(llm_config: LlmConfig) -> Optional[List[Partitioner]]: - """Get XNNPACK partitioners for multimethod_lora export.""" + """Get XNNPACK partitioners for multimethod export.""" partitioners = [] # Order matters here, dynamic quantization should be applied first when @@ -1268,20 +1268,20 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager: """ Export multiple methods (base + LoRA variants) to a single .pte file. - For each method in llm_config.multimethod_lora.methods: + For each method in llm_config.multimethod.methods: - If LoraConfig is None: use base model - If LoraConfig is provided: create model with LoRA weights Limitations: - - Only XNNPACK backend is supported for multimethod_lora export. + - Only XNNPACK backend is supported for multimethod export. - PT2E quantization is not supported. - Each method is exported separately; export time scales linearly with the number of methods. - The final .pte file deduplicates shared weights automatically. """ - num_methods = len(llm_config.multimethod_lora.methods) + num_methods = len(llm_config.multimethod.methods) logging.info( - f"multimethod_lora export: exporting {num_methods} method(s). " + f"multimethod export: exporting {num_methods} method(s). " "Each method requires separate model instantiation and export." ) @@ -1293,14 +1293,14 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager: method_to_program: Dict[str, ExportedProgram] = {} first_builder = None - for method_name, lora_config in llm_config.multimethod_lora.methods.items(): - logging.info(f"Exporting method: {method_name}") + for method in llm_config.multimethod.methods: + logging.info(f"Exporting method: {method.method_name}") # Create a copy of config with this method's LoRA setting method_config = copy.deepcopy(llm_config) - method_config.base.lora_config = lora_config - # Disable multimethod_lora to avoid infinite recursion - method_config.multimethod_lora.methods = {} + method_config.base.lora_config = method.lora_config + # Disable multimethod to avoid infinite recursion + method_config.multimethod.methods = [] # Load and prepare model for this method builder = _prepare_for_llama_export(method_config) @@ -1309,7 +1309,7 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager: # Get the exported program exported_program = builder._export(builder.pre_autograd_graph_module) - method_to_program[method_name] = exported_program + method_to_program[method.method_name] = exported_program if first_builder is None: first_builder = builder @@ -1319,7 +1319,7 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager: # Get partitioners based on backend config partitioners = _get_xnnpack_partitioners(llm_config) - # Lower all methods together using multimethod_lora API + # Lower all methods together using multimethod API edge_config = first_builder._get_edge_config() edge_manager = to_edge_transform_and_lower( method_to_program, @@ -1333,7 +1333,7 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager: first_builder.edge_manager = edge_manager first_builder = first_builder.to_executorch( passes=additional_passes, - share_mutable_buffers=llm_config.multimethod_lora.share_mutable_buffers, + share_mutable_buffers=llm_config.multimethod.share_mutable_buffers, ) output_file = _get_output_filename( @@ -1350,8 +1350,8 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager: def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 _validate_args(llm_config) - # Check for multimethod_lora export - if llm_config.multimethod_lora.enabled: + # Check for multimethod export + if llm_config.multimethod.enabled: return _export_llama_multimethod(llm_config) pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params( diff --git a/examples/models/qwen3/config/qwen3_multimethod.yaml b/examples/models/qwen3/config/qwen3_multimethod.yaml index 066738aead3..7325919b762 100644 --- a/examples/models/qwen3/config/qwen3_multimethod.yaml +++ b/examples/models/qwen3/config/qwen3_multimethod.yaml @@ -22,12 +22,13 @@ quantization: qmode: "8da4w" group_size: 32 -multimethod_lora: +multimethod: methods: # LoRA method - adapter paths from environment variables - lora_forward: - adapter_checkpoint: ${oc.env:LORA_ADAPTER_CHECKPOINT} - adapter_config: ${oc.env:LORA_ADAPTER_CONFIG} + - method_name: lora_forward + lora_config: + adapter_checkpoint: ${oc.env:LORA_ADAPTER_CHECKPOINT} + adapter_config: ${oc.env:LORA_ADAPTER_CONFIG} # Base method - no LoRA - base_forward: null + - method_name: base_forward share_mutable_buffers: True diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index 47ad2f4374a..2cc14b07761 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -22,7 +22,7 @@ import re from dataclasses import dataclass, field from enum import Enum -from typing import ClassVar, Dict, List, Optional +from typing import ClassVar, List, Optional ################################################################################ @@ -293,37 +293,52 @@ class DebugConfig: ################################################################################ -############################## MultimethodLoraConfig ########################### +############################## MultimethodConfig ########################### ################################################################################ @dataclass -class MultimethodLoraConfig: +class MethodConfig: + """Configuration for exporting a single method to a .pte file. + By default, all other fields fall back to the default configs in + the yaml file. + + Attributes: + method_name: Name of the method to export. + lora_config: Optional LoRA configuration. + """ + + method_name: str + lora_config: Optional[LoraConfig] = None + + +@dataclass +class MultimethodConfig: """Configuration for exporting multiple methods to a single .pte file. - Maps method names to optional LoRA configurations. A None value means - the method uses base model weights. + Holds a list of method configs, as well as global options that apply + across all methods. Attributes: - methods: Dict mapping method names to optional LoRA configs. - Empty dict disables multimethod_lora export. + methods: List of MethodConfig objects with method name and config + for each method. share_mutable_buffers: Whether to share mutable buffers across methods. If True, sets all mutable buffers to mem_id=2. Mutable buffers with the same FQN (fully qualified name) will have the same offset. Example: - MultimethodLoraConfig(methods={ - "forward": None, # base model - "lora_forward": lora_config, # LoRA variant - }) + MultimethodConfig(methods=[ + MethodConfig("forward", lora_config=None), # base model + MethodConfig("lora_forward", lora_config=lora_config), # LoRA variant + ]) """ - methods: Dict[str, Optional[LoraConfig]] = field(default_factory=dict) + methods: List[MethodConfig] = field(default_factory=list) share_mutable_buffers: bool = False @property def enabled(self) -> bool: - """Returns True if multimethod_lora export is configured.""" + """Returns True if multimethod export is configured.""" return len(self.methods) > 0 @@ -611,9 +626,7 @@ class LlmConfig: model: ModelConfig = field(default_factory=ModelConfig) export: ExportConfig = field(default_factory=ExportConfig) debug: DebugConfig = field(default_factory=DebugConfig) - multimethod_lora: MultimethodLoraConfig = field( - default_factory=MultimethodLoraConfig - ) + multimethod: MultimethodConfig = field(default_factory=MultimethodConfig) quantization: QuantizationConfig = field(default_factory=QuantizationConfig) backend: BackendConfig = field(default_factory=BackendConfig)