diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index a40198ea36f..e995e5e88f0 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -489,6 +489,12 @@ jobs: name: "gemma3-1b" use-custom: [false, true] qconfig: ["4w", "nvfp4"] + include: + - model: + id: "google/gemma-4-E2B-it" + name: "gemma4-e2b" + use-custom: true + qconfig: "4w" uses: pytorch/test-infra/.github/workflows/macos_job.yml@main secrets: inherit with: @@ -506,12 +512,21 @@ jobs: MODEL_NAME="${{ matrix.model.name }}" USE_CUSTOM="${{ matrix.use-custom }}" QCONFIG="${{ matrix.qconfig }}" + MODEL_REVISION="" + if [ "${MODEL_ID}" = "google/gemma-4-E2B-it" ]; then + MODEL_REVISION="b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf" + fi CUSTOM_ARGS="" if [ "${USE_CUSTOM}" = "true" ]; then CUSTOM_ARGS="--use-custom-sdpa --use-custom-kv-cache" fi + QEMBEDDING_ARGS="--qembedding ${QCONFIG}" + if [ "${MODEL_ID}" = "google/gemma-4-E2B-it" ]; then + QEMBEDDING_ARGS="" + fi + echo "::group::Install ExecuTorch and configure MLX build" ${CONDA_RUN} python install_executorch.py > /dev/null ${CONDA_RUN} cmake --preset mlx-release @@ -522,6 +537,13 @@ jobs: ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) ${CONDA_RUN} pip install transformers "optimum-executorch @ git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}" + if [ "${MODEL_ID}" = "google/gemma-4-E2B-it" ]; then + # Gemma 4 requires a newer Transformers build than the CI-wide + # optimum-executorch pin currently brings in. Keep this pinned to the + # locally validated commit instead of floating on Transformers HEAD. + GEMMA4_TRANSFORMERS_COMMIT=61461a7bcb458db7cf6eeea49678b9ab776a7821 + ${CONDA_RUN} pip install -U "transformers @ git+https://github.com/huggingface/transformers.git@${GEMMA4_TRANSFORMERS_COMMIT}" + fi echo "::endgroup::" ${CONDA_RUN} pip list @@ -529,9 +551,10 @@ jobs: echo "::group::Export ${MODEL_NAME}" ${CONDA_RUN} python -m executorch.backends.mlx.examples.llm.export_llm_hf \ --model-id "${MODEL_ID}" \ + ${MODEL_REVISION:+--revision "${MODEL_REVISION}"} \ --output /tmp/${MODEL_NAME}.pte \ --qlinear ${QCONFIG} \ - --qembedding ${QCONFIG} \ + ${QEMBEDDING_ARGS} \ ${CUSTOM_ARGS} echo "::endgroup::" @@ -539,6 +562,7 @@ jobs: OUTPUT=$(${CONDA_RUN} python -m executorch.backends.mlx.examples.llm.run_llm_hf \ --pte /tmp/${MODEL_NAME}.pte \ --model-id "${MODEL_ID}" \ + ${MODEL_REVISION:+--revision "${MODEL_REVISION}"} \ --prompt "What is the capital of France?" \ --max-new-tokens 50 2>&1) echo "$OUTPUT" diff --git a/backends/mlx/builder/program_builder.py b/backends/mlx/builder/program_builder.py index 0892476fedd..21ae8a3b6fa 100644 --- a/backends/mlx/builder/program_builder.py +++ b/backends/mlx/builder/program_builder.py @@ -444,26 +444,50 @@ def _make_io_slots(self): # noqa: C901 else: raise NotImplementedError(f"Support for input {arg} is not implemented") + placeholder_nodes = { + node.name: node for node in self.ep.graph.nodes if node.op == "placeholder" + } + + # Allocate placeholder-backed slots in graph-signature order instead of + # raw FX node traversal order. This keeps lifted constant tids stable + # across equivalent exports, which matters for models like Gemma 4 that + # carry multiple rotary constant placeholders with similar structure. 
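# A minimal sketch (not part of this change): one way to recover the
# signature-ordered placeholder names that the comment above relies on.
# The mapping onto this builder's constant_tensors / user_inputs /
# mutable_buffers lists is an assumption; those lists are populated
# elsewhere in the builder, and _signature_ordered_names is a hypothetical
# helper name used only for illustration.
def _signature_ordered_names(ep):
    from torch.export.graph_signature import InputKind

    constants, inputs, buffers = [], [], []
    for spec in ep.graph_signature.input_specs:
        name = getattr(spec.arg, "name", None)
        if name is None:
            continue
        if spec.kind in (InputKind.CONSTANT_TENSOR, InputKind.PARAMETER):
            constants.append(name)
        elif spec.kind == InputKind.USER_INPUT:
            inputs.append(name)
        elif spec.kind == InputKind.BUFFER:
            buffers.append(name)
    return constants, inputs, buffers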
+ for name in constant_tensors: + node = placeholder_nodes.get(name) + if node is None or node.users == {}: + continue + self.make_or_get_slot(node, id_space=IdSpace.Constant) + + for name in user_inputs: + node = placeholder_nodes.get(name) + if node is None or node.users == {}: + continue + val = node.meta.get("val", None) + if isinstance(val, torch.Tensor) and not val.is_contiguous(): + raise ValueError( + f"MLX backend requires contiguous input tensors, " + f"but input '{node.name}' has non-contiguous strides. " + f"shape={list(val.shape)}, stride={list(val.stride())}. " + f"Ensure example inputs passed to torch.export.export() " + f"are contiguous (call .contiguous() on them)." + ) + self.make_or_get_slot(node, id_space=IdSpace.Input) + + for name in mutable_buffers: + node = placeholder_nodes.get(name) + if node is None or node.users == {}: + continue + self.make_or_get_slot(node, id_space=IdSpace.MutableBuffer) + + classified_placeholders = ( + set(constant_tensors) | set(user_inputs) | set(mutable_buffers) + ) + for node in self.ep.graph.nodes: if node.op == "placeholder": if node.users == {}: continue - if node.name in constant_tensors: - self.make_or_get_slot(node, id_space=IdSpace.Constant) - elif node.name in user_inputs: - val = node.meta.get("val", None) - if isinstance(val, torch.Tensor) and not val.is_contiguous(): - raise ValueError( - f"MLX backend requires contiguous input tensors, " - f"but input '{node.name}' has non-contiguous strides. " - f"shape={list(val.shape)}, stride={list(val.stride())}. " - f"Ensure example inputs passed to torch.export.export() " - f"are contiguous (call .contiguous() on them)." - ) - self.make_or_get_slot(node, id_space=IdSpace.Input) - elif node.name in mutable_buffers: - self.make_or_get_slot(node, id_space=IdSpace.MutableBuffer) - else: + if node.name not in classified_placeholders: raise NotImplementedError( f"Support for placeholder {node.name} is not implemented" ) diff --git a/backends/mlx/examples/llm/README.md b/backends/mlx/examples/llm/README.md index f860c4f1ce0..8def8c1f06a 100644 --- a/backends/mlx/examples/llm/README.md +++ b/backends/mlx/examples/llm/README.md @@ -9,6 +9,7 @@ This example demonstrates how to export and run LLMs using the MLX delegate for - **KV Cache**: Efficient KV cache implementation for autoregressive generation - **Custom Ops**: Uses `mlx::custom_sdpa` and `mlx::kv_cache_update` for optimal execution on MLX - **Pybindings**: Run inference using ExecuTorch Python bindings +- **Gemma 4**: Text-only export and run flow supports processor-backed checkpoints such as `google/gemma-4-E2B-it` ## Requirements @@ -52,6 +53,25 @@ python -m executorch.backends.mlx.examples.llm.export_llm_hf \ --use-custom-kv-cache \ --qlinear 4w \ --qembedding 4w + +# Gemma 4 text-only export +python -m executorch.backends.mlx.examples.llm.export_llm_hf \ + --model-id "google/gemma-4-E2B-it" \ + --revision "b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf" \ + --output gemma4_hf_int4.pte \ + --use-custom-sdpa \ + --use-custom-kv-cache \ + --qlinear 4w +``` + +Gemma 4 support is currently validated for the text-only path using +`--use-custom-sdpa --use-custom-kv-cache --qlinear 4w`. 
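The export can also be driven from Python instead of the CLI. A minimal sketch, assuming `export_llama_hf` accepts keyword arguments that mirror the flags above; only `model_id`, `revision`, and `output_path` are confirmed by the signature shown in this change, the remaining names are assumptions:

```python
from executorch.backends.mlx.examples.llm.export_llm_hf import export_llama_hf

# The flag-style keyword names below are assumed; check main() in
# export_llm_hf.py for the parameters it actually forwards.
export_llama_hf(
    model_id="google/gemma-4-E2B-it",
    revision="b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf",
    output_path="gemma4_hf_int4.pte",
    use_custom_sdpa=True,
    use_custom_kv_cache=True,
    qlinear="4w",
)
```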
+ +Validated with `transformers` commit +`61461a7bcb458db7cf6eeea49678b9ab776a7821`: + +```bash +pip install -U "transformers @ git+https://github.com/huggingface/transformers.git@61461a7bcb458db7cf6eeea49678b9ab776a7821" ``` ### Options @@ -81,12 +101,25 @@ python -m executorch.backends.mlx.examples.llm.run_llm_hf \ --prompt "Explain quantum computing in simple terms" ``` +Gemma 4 checkpoints may use `AutoProcessor` instead of `AutoTokenizer`; `run_llm_hf` now supports both paths automatically for text-only prompts. + +Validated Gemma 4 run command: + +```bash +python -m executorch.backends.mlx.examples.llm.run_llm_hf \ + --pte gemma4_hf_int4.pte \ + --model-id google/gemma-4-E2B-it \ + --revision b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf \ + --prompt "What is the capital of France?" \ + --max-new-tokens 50 +``` + ### Options | Option | Default | Description | |--------|---------|-------------| | `--pte` | `llama_hf.pte` | Path to .pte file | -| `--model-id` | `unsloth/Llama-3.2-1B-Instruct` | HuggingFace model ID (for tokenizer) | +| `--model-id` | `unsloth/Llama-3.2-1B-Instruct` | HuggingFace model ID (for tokenizer or processor) | | `--prompt` | `The quick brown fox` | Input prompt | | `--max-new-tokens` | `50` | Maximum tokens to generate | diff --git a/backends/mlx/examples/llm/export_llm_hf.py b/backends/mlx/examples/llm/export_llm_hf.py index 39f13e434be..3ba483142c5 100644 --- a/backends/mlx/examples/llm/export_llm_hf.py +++ b/backends/mlx/examples/llm/export_llm_hf.py @@ -50,6 +50,7 @@ def _export_with_optimum( model_id: str, + revision: Optional[str], output_path: str, max_seq_len: int, dtype: str, @@ -73,6 +74,7 @@ def _export_with_optimum( logger.info(f"Loading model using optimum-executorch: {model_id}") exportable = load_causal_lm_model( model_id, + revision=revision, dtype=dtype_str, max_seq_len=max_seq_len, ) @@ -124,6 +126,7 @@ def _export_with_optimum( def _export_with_custom_components( model_id: str, + revision: Optional[str], output_path: str, max_seq_len: int, dtype: str, @@ -166,20 +169,21 @@ def _export_with_custom_components( attn_implementation = "mlx" if use_custom_sdpa else None - # Detect sliding window models (e.g., gemma) - sliding_window = None - logger.info(f"Loading HuggingFace model: {model_id}") load_kwargs = { "torch_dtype": torch_dtype, "low_cpu_mem_usage": True, } + if revision is not None: + load_kwargs["revision"] = revision if attn_implementation: load_kwargs["attn_implementation"] = attn_implementation model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs) - # Check if model uses sliding window attention - sliding_window = getattr(model.config, "sliding_window", None) + # Check if model uses sliding window attention. Multimodal configs like + # Gemma 4 keep transformer attributes under text_config. + text_config = model.config.get_text_config() + sliding_window = getattr(text_config, "sliding_window", None) if sliding_window is not None: logger.info(f"Model has sliding_window={sliding_window}") # Cap max_seq_len to sliding window size for cache allocation @@ -188,11 +192,16 @@ def _export_with_custom_components( else: effective_cache_len = max_seq_len + # The HF ExecuTorch cache wrappers validate both generation_config.use_cache + # and the text config's use_cache flag before constructing static caches. 
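# The guard described above is roughly of this shape (a paraphrase for
# illustration only; the authoritative check lives in the HF /
# optimum-executorch cache wrappers, not here):
#
#     text_config = model.config.get_text_config()
#     if not (model.generation_config.use_cache and text_config.use_cache):
#         raise ValueError("static cache export requires use_cache=True on "
#                          "both generation_config and the text config")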
+ model.generation_config.use_cache = True model.generation_config.cache_implementation = "static" model.generation_config.cache_config = { "batch_size": 1, "max_cache_len": effective_cache_len, } + text_config = model.config.get_text_config() + text_config.use_cache = True model.eval() # Use HybridCache wrapper for sliding window models (stores cache as .cache), @@ -341,6 +350,7 @@ def _save_program(executorch_program, output_path: str) -> None: def export_llama_hf( model_id: str, + revision: Optional[str], output_path: str, max_seq_len: int = 1024, dtype: str = "bf16", @@ -372,6 +382,7 @@ def export_llama_hf( ) _export_with_custom_components( model_id=model_id, + revision=revision, output_path=output_path, max_seq_len=max_seq_len, dtype=dtype, @@ -387,6 +398,7 @@ def export_llama_hf( logger.info("Using optimum-executorch pipeline (no custom components)") _export_with_optimum( model_id=model_id, + revision=revision, output_path=output_path, max_seq_len=max_seq_len, dtype=dtype, @@ -408,6 +420,12 @@ def main(): default="unsloth/Llama-3.2-1B-Instruct", help="HuggingFace model ID", ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="Optional HuggingFace model revision/commit to pin", + ) parser.add_argument( "--output", type=str, @@ -447,6 +465,7 @@ def main(): export_llama_hf( model_id=args.model_id, + revision=args.revision, output_path=args.output, max_seq_len=args.max_seq_len, dtype=args.dtype, diff --git a/backends/mlx/examples/llm/run_llm_hf.py b/backends/mlx/examples/llm/run_llm_hf.py index ca3d0468114..c15bcd89c46 100644 --- a/backends/mlx/examples/llm/run_llm_hf.py +++ b/backends/mlx/examples/llm/run_llm_hf.py @@ -7,10 +7,11 @@ # LICENSE file in the root directory of this source tree. """ -Run exported Llama model (from HuggingFace) using ExecuTorch pybindings. +Run exported HuggingFace LLM using ExecuTorch pybindings. This script runs models exported using export_llm_hf.py. It loads the tokenizer -directly from HuggingFace using the same model ID used during export. +or processor directly from HuggingFace using the same model ID used during +export. Usage: python -m executorch.backends.mlx.examples.llm.run_llm_hf \ @@ -25,7 +26,7 @@ import torch from executorch.runtime import Runtime, Verification -from transformers import AutoTokenizer +from transformers import AutoProcessor, AutoTokenizer FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) @@ -46,15 +47,66 @@ def _get_max_input_seq_len(program) -> int: return sizes[1] if len(sizes) >= 2 else 1 +def _load_text_processor(model_id: str, revision: str | None): + """ + Load a text processor for the model. + + Prefer AutoTokenizer for text-only prompting, even for checkpoints that + also ship an AutoProcessor. Some hybrid checkpoints (for example Gemma 4) + expose both, but the tokenizer path is the more stable interface for the + plain text generation flow exercised by this runner. 
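    Sketch of how the two return paths are consumed downstream (this mirrors
    run_inference below):

        text_processor, uses_processor = _load_text_processor(model_id, revision)
        if uses_processor:
            input_ids = text_processor(text=prompt, return_tensors="pt")["input_ids"]
        else:
            input_ids = text_processor.encode(prompt, return_tensors="pt")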
+ """ + logger.info(f"Loading tokenizer from HuggingFace: {model_id}...") + try: + tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision) + return tokenizer, False + except Exception as exc: + logger.info(f"AutoTokenizer unavailable for {model_id}: {exc}") + + try: + processor = AutoProcessor.from_pretrained(model_id, revision=revision) + if hasattr(processor, "apply_chat_template") and hasattr(processor, "decode"): + logger.info(f"Loaded processor from HuggingFace: {model_id}") + return processor, True + except Exception as exc: + logger.info(f"AutoProcessor unavailable for {model_id}: {exc}") + + raise RuntimeError(f"Could not load tokenizer or processor for {model_id}") + + +def _apply_chat_template(text_processor, messages) -> str: + try: + return text_processor.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + except TypeError: + return text_processor.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + + +def _get_eos_token_id(text_processor): + eos_token_id = getattr(text_processor, "eos_token_id", None) + if eos_token_id is not None: + return eos_token_id + tokenizer = getattr(text_processor, "tokenizer", None) + return getattr(tokenizer, "eos_token_id", None) + + def run_inference( pte_path: str, model_id: str, + revision: str | None, prompt: str, max_new_tokens: int = 50, ) -> str: """Run inference on the exported HuggingFace model.""" - logger.info(f"Loading tokenizer from HuggingFace: {model_id}...") - tokenizer = AutoTokenizer.from_pretrained(model_id) + text_processor, uses_processor = _load_text_processor(model_id, revision) logger.info(f"Loading model from {pte_path}...") et_runtime = Runtime.get() @@ -67,14 +119,18 @@ def run_inference( logger.info(f"Encoding prompt: {prompt!r}") messages = [{"role": "user", "content": prompt}] - formatted_prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - input_ids = tokenizer.encode(formatted_prompt, return_tensors="pt") + formatted_prompt = _apply_chat_template(text_processor, messages) + if uses_processor: + input_ids = text_processor(text=formatted_prompt, return_tensors="pt")[ + "input_ids" + ] + else: + input_ids = text_processor.encode(formatted_prompt, return_tensors="pt") logger.info(f"Input shape: {input_ids.shape}") generated_tokens = input_ids[0].tolist() seq_len = input_ids.shape[1] + eos_token_id = _get_eos_token_id(text_processor) start_time = time.time() @@ -120,7 +176,7 @@ def run_inference( next_token = torch.argmax(next_token_logits).item() generated_tokens.append(next_token) - if next_token == tokenizer.eos_token_id: + if eos_token_id is not None and next_token == eos_token_id: logger.info(f"EOS token reached at position {i + 1}") break @@ -135,12 +191,12 @@ def run_inference( # Decode only the newly generated tokens (not the input prompt) new_tokens = generated_tokens[seq_len:] - generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True) + generated_text = text_processor.decode(new_tokens, skip_special_tokens=True) return generated_text def main(): - parser = argparse.ArgumentParser(description="Run exported HuggingFace Llama model") + parser = argparse.ArgumentParser(description="Run exported HuggingFace LLM") parser.add_argument( "--pte", type=str, @@ -151,7 +207,13 @@ def main(): "--model-id", type=str, default="unsloth/Llama-3.2-1B-Instruct", - help="HuggingFace model ID (used to load tokenizer)", + help="HuggingFace model ID (used to load 
tokenizer or processor)", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="Optional HuggingFace model revision/commit to pin", ) parser.add_argument( "--prompt", @@ -171,6 +233,7 @@ def main(): generated_text = run_inference( pte_path=args.pte, model_id=args.model_id, + revision=args.revision, prompt=args.prompt, max_new_tokens=args.max_new_tokens, ) diff --git a/backends/mlx/llm/cache.py b/backends/mlx/llm/cache.py index 9709980689b..2890e823499 100644 --- a/backends/mlx/llm/cache.py +++ b/backends/mlx/llm/cache.py @@ -12,6 +12,7 @@ Provides reusable KV cache implementations optimized for the MLX backend: """ +import inspect from typing import Tuple import torch @@ -21,6 +22,62 @@ from executorch.backends.mlx import custom_ops as _mlx_custom_ops # noqa: F401 +def resolve_hf_text_config(config): + """Return the text config for multimodal HF models, or the config itself.""" + if hasattr(config, "get_text_config"): + return config.get_text_config() + return getattr(config, "text_config", config) + + +def resolve_hf_cache_layout(config): + """ + Return per-cache-layer metadata for HuggingFace hybrid/static caches. + + Some models such as Gemma 4 use different KV geometries depending on the + attention layer type. Match the upstream `transformers` hybrid cache layout + so our replacement cache allocates the same number of layers with the same + `(num_heads, head_dim)` for each backing cache entry. + """ + text_config = resolve_hf_text_config(config) + layer_types = getattr(text_config, "layer_types", None) + + if layer_types is None: + if getattr(text_config, "sliding_window", None) is not None: + layer_types = ["sliding_attention" for _ in range(text_config.num_hidden_layers)] + else: + layer_types = ["full_attention" for _ in range(text_config.num_hidden_layers)] + else: + layer_types = list(layer_types) + + if hasattr(text_config, "num_kv_shared_layers"): + layer_types = layer_types[: -text_config.num_kv_shared_layers] + + if hasattr(text_config, "global_head_dim"): + head_dims = [ + text_config.global_head_dim if layer_type == "full_attention" else text_config.head_dim + for layer_type in layer_types + ] + num_heads = [ + text_config.num_global_key_value_heads + if layer_type == "full_attention" and getattr(text_config, "attention_k_eq_v", False) + else text_config.num_key_value_heads + for layer_type in layer_types + ] + else: + head_dim = getattr( + text_config, + "head_dim", + text_config.hidden_size // text_config.num_attention_heads, + ) + num_head = getattr( + text_config, "num_key_value_heads", text_config.num_attention_heads + ) + head_dims = [head_dim for _ in layer_types] + num_heads = [num_head for _ in layer_types] + + return layer_types, num_heads, head_dims + + class KVCache(nn.Module): """ MLX-optimized KV cache with ExecutorTorch llama KVCache interface. @@ -326,14 +383,13 @@ def __init__( device: Device for cache tensors (default: None = CPU) dtype: Data type for cache tensors (default: torch.float32) """ - # Resolve dimensions from config BEFORE calling parent - num_layers = config.num_hidden_layers - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr( - config, "head_dim", config.hidden_size // config.num_attention_heads - ) + # Resolve dimensions from the text config before calling parent. Multimodal + # configs like Gemma 4 expose transformer dims under text_config. 
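# For a plain, non-hybrid text config the layout resolution used below
# collapses to one uniform entry per layer. Toy illustration (LlamaConfig is
# just a convenient stand-in and the numbers are made up):
#
#     from transformers import LlamaConfig
#     cfg = LlamaConfig(num_hidden_layers=2, hidden_size=256,
#                       num_attention_heads=8, num_key_value_heads=2)
#     resolve_hf_cache_layout(cfg)
#     # -> (["full_attention", "full_attention"], [2, 2], [32, 32])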
+ text_config = resolve_hf_text_config(config) + layer_types, num_heads, head_dims = resolve_hf_cache_layout(config) + num_model_layers = text_config.num_hidden_layers actual_max_cache_len = max_cache_len or getattr( - config, "max_position_embeddings", 2048 + text_config, "max_position_embeddings", 2048 ) # Initialize parent StaticCache with required arguments @@ -344,19 +400,50 @@ def __init__( device=device, dtype=dtype, ) - # Call early_initialization to ensure parent's layers are fully initialized - self.early_initialization( - batch_size=max_batch_size, - num_heads=num_heads, - head_dim=head_dim, - dtype=dtype, - device=device, - ) + # Newer HF cache implementations already support per-layer layouts in + # early_initialization(). Keep that path for Gemma 4, and only fall + # back to manual layer initialization for the older CI-pinned API. + try: + self.early_initialization( + batch_size=max_batch_size, + num_heads=num_heads, + head_dim=head_dims, + dtype=dtype, + device=device, + ) + except TypeError: + for layer, layer_num_heads, layer_head_dim in zip( + self.layers, num_heads, head_dims + ): + fake_keys_tensor = torch.zeros( + (max_batch_size, layer_num_heads, 0, layer_head_dim), + dtype=dtype, + device=device, + ) + lazy_init_sig = inspect.signature(layer.lazy_initialization) + # Older pinned HF caches take a single fake tensor, while newer + # versions expect both key_states and value_states separately. + if len(lazy_init_sig.parameters) == 1: + layer.lazy_initialization(fake_keys_tensor) + else: + fake_values_tensor = torch.zeros( + (max_batch_size, layer_num_heads, 0, layer_head_dim), + dtype=dtype, + device=device, + ) + layer.lazy_initialization(fake_keys_tensor, fake_values_tensor) + + # Some models (for example Gemma 4) only allocate cache entries for the + # non-shared KV layers. Mirror the parent StaticCache layout exactly so + # layer_idx values passed to update() line up with our backing cache. + num_cache_layers = len(self.layers) # Store dimensions as instance attributes - self.num_layers = num_layers + self.num_model_layers = num_model_layers + self.num_layers = num_cache_layers + self.layer_types = layer_types self.num_heads = num_heads - self.head_dim = head_dim + self.head_dim = head_dims # Create KVCache wrappers for each layer - these use mlx::kv_cache_update # Named 'kv_cache' to match optimum-executorch's ETCustomStaticCache pattern @@ -365,12 +452,12 @@ def __init__( KVCache( max_batch_size=max_batch_size, max_context_length=actual_max_cache_len, - n_heads=num_heads, - head_dim=head_dim, + n_heads=layer_num_heads, + head_dim=layer_head_dim, enable_dynamic_shape=True, dtype=dtype, ) - for _ in range(num_layers) + for layer_num_heads, layer_head_dim in zip(num_heads, head_dims) ] ) @@ -394,18 +481,31 @@ def update( key_states: New key states [batch_size, num_heads, seq_len, head_dim] value_states: New value states [batch_size, num_heads, seq_len, head_dim] layer_idx: Index of the layer to update - cache_kwargs: Dictionary containing 'cache_position' tensor with start position + cache_kwargs: Optional dictionary containing 'cache_position' tensor + with start position. Newer HF StaticCache callers seed + `self.layers[layer_idx].cumulative_length` directly and do not + pass cache_kwargs. 
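        Sketch of the two caller conventions handled here (shapes and
        positions are illustrative):

            # Older pinned transformers: position passed explicitly.
            cache.update(k, v, layer_idx=0,
                         cache_kwargs={"cache_position": torch.tensor([0])})

            # Newer wrappers: position pre-seeded on the cache layer, and no
            # cache_kwargs passed.
            cache.layers[0].cumulative_length = torch.tensor([0])
            cache.update(k, v, layer_idx=0)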
Returns: Tuple of (key_cache, value_cache) for the full cache after update """ - assert ( - cache_kwargs is not None - ), "cache_kwargs must be provided with 'cache_position'" - cache_position = cache_kwargs.get("cache_position") - assert ( - cache_position is not None - ), "cache_position must be provided in cache_kwargs" + if cache_kwargs is not None: + cache_position = cache_kwargs.get("cache_position") + else: + cache_position = None + + if cache_position is None: + # Current HF ExecuTorch wrappers copy the requested cache position + # into each StaticCache layer's cumulative_length before forward(). + if hasattr(self.layers[layer_idx], "cumulative_length"): + cache_position = self.layers[layer_idx].cumulative_length + else: + raise RuntimeError( + "cache_position was not provided and the pinned " + "transformers StaticCache layer does not expose " + "cumulative_length" + ) + assert isinstance( cache_position, torch.Tensor ), "cache_position must be a tensor" diff --git a/backends/mlx/llm/hf_attention.py b/backends/mlx/llm/hf_attention.py index 9e3c864dce6..f2a01c9e653 100644 --- a/backends/mlx/llm/hf_attention.py +++ b/backends/mlx/llm/hf_attention.py @@ -89,8 +89,10 @@ def mlx_sdpa_with_start_pos_forward( def sdpa_mask_passthrough( batch_size: int, - cache_position: torch.Tensor, - kv_length: int, + cache_position: Optional[torch.Tensor] = None, + q_length: Optional[int] = None, + kv_length: Optional[int] = None, + q_offset: Optional[Union[int, torch.Tensor]] = None, kv_offset: int = 0, mask_function: Optional[Callable] = None, attention_mask: Optional[torch.Tensor] = None, @@ -139,6 +141,27 @@ def get_mlx_sliding_window_sdpa(exportable_module) -> Callable: Attention function compatible with HuggingFace's attention interface. """ + def _resolve_cache_layer_idx(module: torch.nn.Module, cache) -> Optional[int]: + """ + Map a transformer layer index to the backing cache slot index. + + Hybrid/shared-KV models like Gemma 4 only allocate cache entries for the + non-shared KV layers. Shared layers expose `kv_shared_layer_index`, which + points at the earlier cache-producing layer they reuse. + """ + layer_idx = getattr(module, "layer_idx", None) + if layer_idx is None: + return None + + if layer_idx < len(cache.kv_cache): + return layer_idx + + shared_layer_idx = getattr(module, "kv_shared_layer_index", None) + if shared_layer_idx is not None and shared_layer_idx < len(cache.kv_cache): + return shared_layer_idx + + return None + def _sliding_window_sdpa_forward( module: torch.nn.Module, query: torch.Tensor, # [B, num_heads, seq_len, head_dim] - BHSD @@ -165,6 +188,7 @@ def _sliding_window_sdpa_forward( attn_mask = None start_pos = 0 + layer_cache = None if layer_idx is not None and position_ids is not None: start_pos = position_ids[0][0].item() @@ -173,7 +197,9 @@ def _sliding_window_sdpa_forward( cache = getattr(exportable_module, "cache", None) if cache is not None: - layer_cache = cache.kv_cache[layer_idx] + cache_layer_idx = _resolve_cache_layer_idx(module, cache) + if cache_layer_idx is not None: + layer_cache = cache.kv_cache[cache_layer_idx] if isinstance(layer_cache, RingBufferKVCache): attn_mask = layer_cache.create_sliding_window_mask( start_pos, seq_len @@ -182,11 +208,19 @@ def _sliding_window_sdpa_forward( # stop_pos = start_pos + seq_len = buffer_size start_pos = layer_cache.buffer_size - seq_len + # Hybrid models use one global HF attention implementation. 
Sliding + # layers need the ring-buffer mask path, while full-attention layers + # should keep the regular causal SDPA path even under the same hook. if attn_mask is None: - raise RuntimeError( - f"Sliding window attention at layer {layer_idx} requires a " - f"RingBufferKVCache, but none was found. Ensure the model's " - f"cache is set up with RingBufferKVCache for sliding window layers." + return mlx_sdpa_with_start_pos_forward( + module, + query, + key, + value, + attention_mask, + position_ids=position_ids, + scaling=scaling, + **kwargs, ) output = torch.ops.mlx.custom_sdpa( diff --git a/backends/mlx/llm/source_transformation.py b/backends/mlx/llm/source_transformation.py index d90073c633e..06a45b9e22b 100644 --- a/backends/mlx/llm/source_transformation.py +++ b/backends/mlx/llm/source_transformation.py @@ -19,7 +19,13 @@ import torch import torch.nn as nn -from executorch.backends.mlx.llm.cache import HFStaticCache, KVCache, RingBufferKVCache +from executorch.backends.mlx.llm.cache import ( + HFStaticCache, + KVCache, + RingBufferKVCache, + resolve_hf_cache_layout, + resolve_hf_text_config, +) logger = logging.getLogger(__name__) @@ -123,9 +129,17 @@ def replace_hf_cache_with_mlx( def _install_cache(attr_name): setattr(module, attr_name, mlx_cache) - for i, layer_cache in enumerate(mlx_cache.kv_cache): + for i, (cache_layer, layer_cache) in enumerate( + zip(mlx_cache.layers, mlx_cache.kv_cache) + ): setattr(module, f"key_cache_{i}", layer_cache.k_cache) setattr(module, f"value_cache_{i}", layer_cache.v_cache) + if hasattr(cache_layer, "cumulative_length"): + setattr( + module, + f"cumulative_length_{i}", + cache_layer.cumulative_length, + ) if hasattr(module, "static_cache"): assert isinstance( @@ -171,12 +185,6 @@ def replace_hf_cache_with_mlx_ring_buffer( """ from transformers.cache_utils import StaticCache - num_layers = config.num_hidden_layers - num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr( - config, "head_dim", config.hidden_size // config.num_attention_heads - ) - # Create HFStaticCache with ring buffer layers mlx_cache = HFStaticCache( config=config, @@ -185,22 +193,39 @@ def replace_hf_cache_with_mlx_ring_buffer( dtype=dtype, ) - # Replace each layer's KVCache with RingBufferKVCache - for i in range(num_layers): - ring_cache = RingBufferKVCache( + # Replace only the sliding-window cache entries with ring buffers, while + # preserving full-attention entries as linear caches. Hybrid models like + # Gemma 4 mix both layouts and can also vary head_dim per cache layer. 
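# Illustrative example (hypothetical three-layer hybrid layout, not taken
# from a real checkpoint):
#
#     layer_types = ["sliding_attention", "sliding_attention", "full_attention"]
#     # After the loop below, kv_cache[0] and kv_cache[1] are RingBufferKVCache
#     # instances bounded by window_size, while kv_cache[2] keeps the linear
#     # KVCache allocated by HFStaticCache for the full cache length.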
+ layer_types, num_heads, head_dims = resolve_hf_cache_layout(config) + num_cache_layers = len(mlx_cache.layers) + num_ring_layers = 0 + for i, (layer_type, layer_num_heads, layer_head_dim) in enumerate( + zip(layer_types, num_heads, head_dims) + ): + if layer_type != "sliding_attention": + continue + mlx_cache.kv_cache[i] = RingBufferKVCache( max_batch_size=max_batch_size, max_context_length=window_size, - n_heads=num_kv_heads, - head_dim=head_dim, + n_heads=layer_num_heads, + head_dim=layer_head_dim, dtype=dtype, ) - mlx_cache.kv_cache[i] = ring_cache + num_ring_layers += 1 def _install_cache(attr_name): setattr(module, attr_name, mlx_cache) - for i, layer_cache in enumerate(mlx_cache.kv_cache): + for i, (cache_layer, layer_cache) in enumerate( + zip(mlx_cache.layers, mlx_cache.kv_cache) + ): setattr(module, f"key_cache_{i}", layer_cache.k_cache) setattr(module, f"value_cache_{i}", layer_cache.v_cache) + if hasattr(cache_layer, "cumulative_length"): + setattr( + module, + f"cumulative_length_{i}", + cache_layer.cumulative_length, + ) if hasattr(module, "static_cache"): assert isinstance( @@ -218,8 +243,8 @@ def _install_cache(attr_name): raise ValueError("Module must have 'static_cache' or 'cache' attribute") logger.info( - f"Installed RingBufferKVCache: {num_layers} layers, " - f"window_size={window_size}, heads={num_kv_heads}, head_dim={head_dim}" + f"Installed hybrid MLX cache: {num_ring_layers} ring-buffer layers / " + f"{num_cache_layers} total cache layers, window_size={window_size}" ) return module diff --git a/backends/mlx/runtime/MLXBackend.cpp b/backends/mlx/runtime/MLXBackend.cpp index 99e20114ea7..5bd3bf263d1 100644 --- a/backends/mlx/runtime/MLXBackend.cpp +++ b/backends/mlx/runtime/MLXBackend.cpp @@ -19,7 +19,6 @@ #include #include -#include #include #include #include @@ -285,6 +284,13 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { processed->Free(); } return Error::InvalidProgram; + } catch (...) { + ET_LOG(Error, "Failed to load MLX program: unknown non-std exception"); + handle->~MLXHandle(); + if (processed != nullptr) { + processed->Free(); + } + return Error::InvalidProgram; } return handle; @@ -416,6 +422,9 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { } catch (const std::exception& e) { ET_LOG(Error, "MLX execute failed: %s", e.what()); return Error::Internal; + } catch (...) { + ET_LOG(Error, "MLX execute failed: unknown non-std exception"); + return Error::Internal; } }