Add Gemma 4 MLX install-path support #19065
base: main
**`backends/mlx/examples/llm/export_llm_hf.py`**

```diff
@@ -166,9 +166,6 @@ def _export_with_custom_components(
     attn_implementation = "mlx" if use_custom_sdpa else None
 
-    # Detect sliding window models (e.g., gemma)
-    sliding_window = None
-
     logger.info(f"Loading HuggingFace model: {model_id}")
     load_kwargs = {
         "torch_dtype": torch_dtype,
```
```diff
@@ -178,8 +175,10 @@ def _export_with_custom_components(
     load_kwargs["attn_implementation"] = attn_implementation
     model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
 
-    # Check if model uses sliding window attention
-    sliding_window = getattr(model.config, "sliding_window", None)
+    # Check if model uses sliding window attention. Multimodal configs like
+    # Gemma 4 keep transformer attributes under text_config.
+    text_config = model.config.get_text_config()
+    sliding_window = getattr(text_config, "sliding_window", None)
     if sliding_window is not None:
         logger.info(f"Model has sliding_window={sliding_window}")
         # Cap max_seq_len to sliding window size for cache allocation
```

**Contributor:** Does this regress gemma3?

**Author:** I don't expect this to regress Gemma 3. The change just switches the sliding-window lookup to `model.config.get_text_config()`, which returns the config itself when there is no nested `text_config`, so text-only Gemma 3 models resolve the same `sliding_window` value as before.

**Contributor:** Yeah, it would be great to try gemma3 as a smoke test. If you are unable to access the version from Google, try the unsloth version unsloth/gemma-3-1b-it (https://github.com/pytorch/executorch/blob/main/.github/workflows/mlx.yml#L469C18-L469C39).
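For reviewers unfamiliar with the helper: in recent transformers versions, `PretrainedConfig.get_text_config()` returns the nested `text_config` when one exists and the config object itself otherwise, which is why a single lookup serves both model families. A minimal sketch (the model ID is illustrative, and Gemma checkpoints may be gated):

```python
from transformers import AutoConfig

# get_text_config() resolves to the nested text_config for multimodal configs
# and falls back to the top-level config for text-only ones, so the same
# sliding_window lookup works for both. The model ID below is an example.
config = AutoConfig.from_pretrained("unsloth/gemma-3-1b-it")
text_config = config.get_text_config()
print(getattr(text_config, "sliding_window", None))
```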
```diff
@@ -188,11 +187,16 @@ def _export_with_custom_components(
     else:
         effective_cache_len = max_seq_len
 
+    # The HF ExecuTorch cache wrappers validate both generation_config.use_cache
+    # and the text config's use_cache flag before constructing static caches.
+    model.generation_config.use_cache = True
     model.generation_config.cache_implementation = "static"
     model.generation_config.cache_config = {
         "batch_size": 1,
         "max_cache_len": effective_cache_len,
     }
+    text_config = model.config.get_text_config()
+    text_config.use_cache = True
     model.eval()
 
     # Use HybridCache wrapper for sliding window models (stores cache as .cache),
```
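The new comment describes a dual flag check; here is a hypothetical sketch of that precondition, only to make the comment concrete (the function name and error are assumptions, not the wrappers' actual API):

```python
def _assert_cache_enabled(model) -> None:
    # Hypothetical mirror of the validation described above: both the
    # generation config and the (possibly nested) text config must opt in
    # before a static cache can be built.
    text_config = model.config.get_text_config()
    if not (model.generation_config.use_cache and getattr(text_config, "use_cache", False)):
        raise ValueError("Static cache export requires use_cache=True on both configs")
```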
**`backends/mlx/examples/llm/run_llm_hf.py`**
```diff
@@ -7,10 +7,11 @@
 # LICENSE file in the root directory of this source tree.
 
 """
-Run exported Llama model (from HuggingFace) using ExecuTorch pybindings.
+Run exported HuggingFace LLM using ExecuTorch pybindings.
 
 This script runs models exported using export_llm_hf.py. It loads the tokenizer
-directly from HuggingFace using the same model ID used during export.
+or processor directly from HuggingFace using the same model ID used during
+export.
 
 Usage:
     python -m executorch.backends.mlx.examples.llm.run_llm_hf \
```
```diff
@@ -20,18 +21,89 @@
 """
 
 import argparse
+import ctypes
 import logging
+import os
+import shutil
 import time
+from pathlib import Path
 
 import torch
 from executorch.runtime import Runtime, Verification
-from transformers import AutoTokenizer
+from transformers import AutoProcessor, AutoTokenizer
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
 logger = logging.getLogger(__name__)
 
 
+def _iter_mlx_backend_candidates():
+    env_path = os.environ.get("ET_MLX_BACKEND_DYLIB")
+    if env_path:
+        yield Path(env_path)
+
+    for parent in Path(__file__).resolve().parents:
+        pip_out = parent / "pip-out"
+        if pip_out.exists():
+            yield from sorted(
+                pip_out.glob(
+                    "temp.*/cmake-out/backends/mlx/libmlxdelegate_runtime.dylib"
+                )
+            )
+            yield from sorted(
+                pip_out.glob("temp.*/cmake-out/backends/mlx/libmlxdelegate.dylib")
+            )
+            break
+
+
+def _ensure_mlx_metallib(dylib_path: Path) -> None:
+    metallib_path = dylib_path.with_name("mlx.metallib")
+    if metallib_path.exists():
+        return
+
+    for parent in Path(__file__).resolve().parents:
+        pip_out = parent / "pip-out"
+        if not pip_out.exists():
+            continue
+        matches = sorted(
+            pip_out.glob(
+                "temp.*/cmake-out/backends/mlx/mlx/mlx/backend/metal/kernels/mlx.metallib"
+            )
+        )
+        if not matches:
+            continue
+        shutil.copyfile(matches[0], metallib_path)
+        logger.info(f"Copied MLX metallib next to runtime library: {metallib_path}")
+        return
+
+
+def _ensure_mlx_backend_registered() -> Runtime:
+    runtime = Runtime.get()
+    if runtime.backend_registry.is_available("MLXBackend"):
+        return runtime
+
+    for candidate in _iter_mlx_backend_candidates():
+        if not candidate.is_file():
+            continue
+        try:
+            _ensure_mlx_metallib(candidate)
+            ctypes.CDLL(str(candidate), mode=ctypes.RTLD_GLOBAL)
+        except OSError as exc:
+            logger.info(f"Failed to load MLX backend library {candidate}: {exc}")
+            continue
+
+        runtime = Runtime.get()
+        if runtime.backend_registry.is_available("MLXBackend"):
+            logger.info(f"Loaded MLX backend runtime library: {candidate}")
+            return runtime
+
+    logger.warning(
+        "MLXBackend is not registered. If you built mlxdelegate locally, "
+        "set ET_MLX_BACKEND_DYLIB to the path of libmlxdelegate_runtime.dylib."
+    )
+    return runtime
+
+
 def _get_max_input_seq_len(program) -> int:
     """Inspect the .pte program metadata to determine the max input_ids seq len.
```

**Contributor:** This code should not be needed. Did you run the install on a Mac machine with Xcode installed? If so, in the install logs, did you see a comment about the MLX installation being skipped for some reason?
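As a usage note, the `ET_MLX_BACKEND_DYLIB` escape hatch in `_iter_mlx_backend_candidates` can be exercised before invoking the runner (the dylib path below is a placeholder):

```python
import os

# Placeholder path: point this at a locally built libmlxdelegate_runtime.dylib
# so _ensure_mlx_backend_registered() tries it first.
os.environ["ET_MLX_BACKEND_DYLIB"] = "/path/to/cmake-out/backends/mlx/libmlxdelegate_runtime.dylib"
```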
```diff
@@ -46,18 +118,61 @@ def _get_max_input_seq_len(program) -> int:
     return sizes[1] if len(sizes) >= 2 else 1
 
 
+def _load_text_processor(model_id: str):
+    """
+    Load a text processor for the model.
+
+    Prefer AutoProcessor for multimodal/text-hybrid models like Gemma 4, and
+    fall back to AutoTokenizer for text-only checkpoints.
+    """
+    try:
+        processor = AutoProcessor.from_pretrained(model_id)
+        if hasattr(processor, "apply_chat_template") and hasattr(processor, "decode"):
+            logger.info(f"Loaded processor from HuggingFace: {model_id}")
+            return processor, True
+    except Exception as exc:
+        logger.info(f"AutoProcessor unavailable for {model_id}: {exc}")
+
+    logger.info(f"Loading tokenizer from HuggingFace: {model_id}...")
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    return tokenizer, False
+
+
+def _apply_chat_template(text_processor, messages) -> str:
+    try:
+        return text_processor.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+            enable_thinking=False,
+        )
+    except TypeError:
+        return text_processor.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+
+
+def _get_eos_token_id(text_processor):
+    eos_token_id = getattr(text_processor, "eos_token_id", None)
+    if eos_token_id is not None:
+        return eos_token_id
+    tokenizer = getattr(text_processor, "tokenizer", None)
+    return getattr(tokenizer, "eos_token_id", None)
+
+
 def run_inference(
     pte_path: str,
     model_id: str,
     prompt: str,
     max_new_tokens: int = 50,
 ) -> str:
     """Run inference on the exported HuggingFace model."""
-    logger.info(f"Loading tokenizer from HuggingFace: {model_id}...")
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    text_processor, uses_processor = _load_text_processor(model_id)
 
     logger.info(f"Loading model from {pte_path}...")
-    et_runtime = Runtime.get()
+    et_runtime = _ensure_mlx_backend_registered()
 
     program = et_runtime.load_program(pte_path, verification=Verification.Minimal)
 
     max_seq_len = _get_max_input_seq_len(program)
```

**Contributor:** This shouldn't be needed, see comment on the install process.

**Author:** That's fair. I added this while debugging the installed-package path locally because `MLXBackend` was not registered after installing the package.
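To smoke-test `_load_text_processor` in isolation, something like this works (model IDs are illustrative; which branch is taken can vary with the transformers version, since newer `AutoProcessor` releases also resolve tokenizer-only checkpoints):

```python
# Illustrative probe: print which loader handled each checkpoint.
for model_id in ["unsloth/Llama-3.2-1B-Instruct", "unsloth/gemma-3-1b-it"]:
    text_processor, uses_processor = _load_text_processor(model_id)
    print(model_id, type(text_processor).__name__, uses_processor)
```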
```diff
@@ -67,14 +182,18 @@ def run_inference(
     logger.info(f"Encoding prompt: {prompt!r}")
     messages = [{"role": "user", "content": prompt}]
-    formatted_prompt = tokenizer.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
-    input_ids = tokenizer.encode(formatted_prompt, return_tensors="pt")
+    formatted_prompt = _apply_chat_template(text_processor, messages)
+    if uses_processor:
+        input_ids = text_processor(text=formatted_prompt, return_tensors="pt")[
+            "input_ids"
+        ]
+    else:
+        input_ids = text_processor.encode(formatted_prompt, return_tensors="pt")
     logger.info(f"Input shape: {input_ids.shape}")
 
     generated_tokens = input_ids[0].tolist()
     seq_len = input_ids.shape[1]
+    eos_token_id = _get_eos_token_id(text_processor)
 
     start_time = time.time()
```
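The branch exists because processors typically expose `__call__` rather than `encode`; both paths yield a `[1, seq_len]` tensor of token IDs. A minimal sketch of the two call shapes inside `run_inference` (assumes `formatted_prompt` is plain text):

```python
# Processor path: __call__ returns a dict-like batch containing input_ids.
ids_from_processor = text_processor(text=formatted_prompt, return_tensors="pt")["input_ids"]

# Tokenizer path: encode returns the id tensor directly.
# ids_from_tokenizer = text_processor.encode(formatted_prompt, return_tensors="pt")
```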
```diff
@@ -120,7 +239,7 @@ def run_inference(
         next_token = torch.argmax(next_token_logits).item()
         generated_tokens.append(next_token)
 
-        if next_token == tokenizer.eos_token_id:
+        if eos_token_id is not None and next_token == eos_token_id:
             logger.info(f"EOS token reached at position {i + 1}")
             break
```
```diff
@@ -135,12 +254,12 @@ def run_inference(
     # Decode only the newly generated tokens (not the input prompt)
     new_tokens = generated_tokens[seq_len:]
-    generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
+    generated_text = text_processor.decode(new_tokens, skip_special_tokens=True)
 
     return generated_text
 
 
 def main():
-    parser = argparse.ArgumentParser(description="Run exported HuggingFace Llama model")
+    parser = argparse.ArgumentParser(description="Run exported HuggingFace LLM")
     parser.add_argument(
         "--pte",
         type=str,
```

**Contributor:** Does this break the path where uses_processor=False? Can we unify these two paths somehow?

**Author:** I ended up unifying this path.
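The unified decode works because `_load_text_processor` only returns a processor when it exposes `decode`, and processor `decode` delegates to the wrapped tokenizer, so the tokenizer-only path sees no behavior change. A tiny illustration (token IDs are hypothetical placeholders):

```python
# Same call regardless of whether text_processor is a tokenizer or a
# processor wrapping one.
new_tokens = [9906, 1917]  # hypothetical token IDs
print(text_processor.decode(new_tokens, skip_special_tokens=True))
```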
```diff
@@ -151,7 +270,7 @@ def main():
         "--model-id",
         type=str,
         default="unsloth/Llama-3.2-1B-Instruct",
-        help="HuggingFace model ID (used to load tokenizer)",
+        help="HuggingFace model ID (used to load tokenizer or processor)",
     )
     parser.add_argument(
         "--prompt",
```
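For reference, a representative invocation assembled from the module docstring and the flags above (the `.pte` path and prompt are placeholders):

```
python -m executorch.backends.mlx.examples.llm.run_llm_hf \
    --pte gemma.pte \
    --model-id unsloth/Llama-3.2-1B-Instruct \
    --prompt "What is the capital of France?"
```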
**Reviewer:** Why no embedding?