Add Gemma 4 MLX install-path support #19065
base: main
@@ -444,26 +444,50 @@ def _make_io_slots(self):  # noqa: C901
            else:
                raise NotImplementedError(f"Support for input {arg} is not implemented")

        placeholder_nodes = {
Contributor
I don't follow this change. Why is gemma4 sensitive to this?

Author
I got here by diffing a previously working Gemma 4 export against a failing one. What changed there was the slot assignment for the two rotary constants used by sliding-window vs. full attention. This change was just to make that assignment deterministic instead of depending on raw placeholder traversal order (sketched below this diff). Gemma 4 is where I noticed it because that model exercises both constants in the same path. If you'd prefer, I can drop this change.
            node.name: node for node in self.ep.graph.nodes if node.op == "placeholder"
        }

        # Allocate placeholder-backed slots in graph-signature order instead of
        # raw FX node traversal order. This keeps lifted constant tids stable
        # across equivalent exports, which matters for models like Gemma 4 that
        # carry multiple rotary constant placeholders with similar structure.
        for name in constant_tensors:
            node = placeholder_nodes.get(name)
            if node is None or node.users == {}:
                continue
            self.make_or_get_slot(node, id_space=IdSpace.Constant)

        for name in user_inputs:
            node = placeholder_nodes.get(name)
            if node is None or node.users == {}:
                continue
            val = node.meta.get("val", None)
            if isinstance(val, torch.Tensor) and not val.is_contiguous():
                raise ValueError(
                    f"MLX backend requires contiguous input tensors, "
                    f"but input '{node.name}' has non-contiguous strides. "
                    f"shape={list(val.shape)}, stride={list(val.stride())}. "
                    f"Ensure example inputs passed to torch.export.export() "
                    f"are contiguous (call .contiguous() on them)."
                )
            self.make_or_get_slot(node, id_space=IdSpace.Input)

        for name in mutable_buffers:
            node = placeholder_nodes.get(name)
            if node is None or node.users == {}:
                continue
            self.make_or_get_slot(node, id_space=IdSpace.MutableBuffer)

        classified_placeholders = (
            set(constant_tensors) | set(user_inputs) | set(mutable_buffers)
        )

        for node in self.ep.graph.nodes:
            if node.op == "placeholder":
                if node.users == {}:
                    continue
                if node.name in constant_tensors:
                    self.make_or_get_slot(node, id_space=IdSpace.Constant)
                elif node.name in user_inputs:
                    val = node.meta.get("val", None)
                    if isinstance(val, torch.Tensor) and not val.is_contiguous():
                        raise ValueError(
                            f"MLX backend requires contiguous input tensors, "
                            f"but input '{node.name}' has non-contiguous strides. "
                            f"shape={list(val.shape)}, stride={list(val.stride())}. "
                            f"Ensure example inputs passed to torch.export.export() "
                            f"are contiguous (call .contiguous() on them)."
                        )
                    self.make_or_get_slot(node, id_space=IdSpace.Input)
                elif node.name in mutable_buffers:
                    self.make_or_get_slot(node, id_space=IdSpace.MutableBuffer)
                else:
                    if node.name not in classified_placeholders:
                        raise NotImplementedError(
                            f"Support for placeholder {node.name} is not implemented"
                        )
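A minimal sketch of the ordering idea discussed in the comment thread above: deriving the three placeholder buckets in graph-signature order rather than raw FX traversal order. This is illustrative only; the helper name is hypothetical, and the ExportGraphSignature attribute spellings (inputs_to_lifted_tensor_constants, user_inputs, inputs_to_buffers, buffers_to_mutate) are assumptions based on torch.export, not code from this PR.

from torch.export import ExportedProgram


def signature_ordered_buckets(ep: ExportedProgram):
    # Graph-signature order is stable across equivalent exports, unlike raw FX
    # placeholder traversal order, so walking these buckets first keeps the
    # lifted-constant slot/tid assignment deterministic.
    sig = ep.graph_signature
    constant_tensors = list(sig.inputs_to_lifted_tensor_constants)  # placeholder names
    user_inputs = [name for name in sig.user_inputs if isinstance(name, str)]
    mutated_buffer_fqns = set(sig.buffers_to_mutate.values())
    mutable_buffers = [
        placeholder
        for placeholder, buffer_fqn in sig.inputs_to_buffers.items()
        if buffer_fqn in mutated_buffer_fqns
    ]
    return constant_tensors, user_inputs, mutable_buffers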
@@ -50,6 +50,7 @@

def _export_with_optimum(
    model_id: str,
    revision: Optional[str],
    output_path: str,
    max_seq_len: int,
    dtype: str,

@@ -73,6 +74,7 @@ def _export_with_optimum(
    logger.info(f"Loading model using optimum-executorch: {model_id}")
    exportable = load_causal_lm_model(
        model_id,
        revision=revision,
        dtype=dtype_str,
        max_seq_len=max_seq_len,
    )

@@ -124,6 +126,7 @@ def _export_with_optimum(

def _export_with_custom_components(
    model_id: str,
    revision: Optional[str],
    output_path: str,
    max_seq_len: int,
    dtype: str,

@@ -166,20 +169,21 @@ def _export_with_custom_components(

    attn_implementation = "mlx" if use_custom_sdpa else None

    # Detect sliding window models (e.g., gemma)
    sliding_window = None

    logger.info(f"Loading HuggingFace model: {model_id}")
    load_kwargs = {
        "torch_dtype": torch_dtype,
        "low_cpu_mem_usage": True,
    }
    if revision is not None:
        load_kwargs["revision"] = revision
    if attn_implementation:
        load_kwargs["attn_implementation"] = attn_implementation
    model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)

    # Check if model uses sliding window attention
    sliding_window = getattr(model.config, "sliding_window", None)
    # Check if model uses sliding window attention. Multimodal configs like
Contributor
Does this regress gemma3?

Author
I don't expect this to regress Gemma 3. The change is just switching the sliding-window lookup to model.config.get_text_config(); for a text-only config like Gemma 3's, that should resolve to the same sliding_window value as reading model.config directly.

Contributor
Yeah, it would be great to try gemma3 as a smoke test. If you are unable to access the version from Google, try the unsloth version unsloth/gemma-3-1b-it (https://github.com/pytorch/executorch/blob/main/.github/workflows/mlx.yml#L469C18-L469C39). A smoke-test sketch follows after this file's diff.
    # Gemma 4 keep transformer attributes under text_config.
    text_config = model.config.get_text_config()
    sliding_window = getattr(text_config, "sliding_window", None)
    if sliding_window is not None:
        logger.info(f"Model has sliding_window={sliding_window}")
        # Cap max_seq_len to sliding window size for cache allocation

@@ -188,11 +192,16 @@ def _export_with_custom_components(
    else:
        effective_cache_len = max_seq_len

    # The HF ExecuTorch cache wrappers validate both generation_config.use_cache
    # and the text config's use_cache flag before constructing static caches.
    model.generation_config.use_cache = True
    model.generation_config.cache_implementation = "static"
    model.generation_config.cache_config = {
        "batch_size": 1,
        "max_cache_len": effective_cache_len,
    }
    text_config = model.config.get_text_config()
    text_config.use_cache = True
    model.eval()

    # Use HybridCache wrapper for sliding window models (stores cache as .cache),

@@ -341,6 +350,7 @@ def _save_program(executorch_program, output_path: str) -> None:

def export_llama_hf(
    model_id: str,
    revision: Optional[str],
    output_path: str,
    max_seq_len: int = 1024,
    dtype: str = "bf16",

@@ -372,6 +382,7 @@ def export_llama_hf(
    )
    _export_with_custom_components(
        model_id=model_id,
        revision=revision,
        output_path=output_path,
        max_seq_len=max_seq_len,
        dtype=dtype,

@@ -387,6 +398,7 @@ def export_llama_hf(
    logger.info("Using optimum-executorch pipeline (no custom components)")
    _export_with_optimum(
        model_id=model_id,
        revision=revision,
        output_path=output_path,
        max_seq_len=max_seq_len,
        dtype=dtype,

@@ -408,6 +420,12 @@ def main():
        default="unsloth/Llama-3.2-1B-Instruct",
        help="HuggingFace model ID",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        help="Optional HuggingFace model revision/commit to pin",
    )
    parser.add_argument(
        "--output",
        type=str,

@@ -447,6 +465,7 @@ def main():

    export_llama_hf(
        model_id=args.model_id,
        revision=args.revision,
        output_path=args.output,
        max_seq_len=args.max_seq_len,
        dtype=args.dtype,
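On the Gemma 3 smoke-test request above, a minimal sketch of how the check could be run with the entry point touched in this diff. The import path, output filename, and argument values are placeholders/assumptions; only the parameter names come from the diff.

from export_llama_hf import export_llama_hf  # placeholder module path

export_llama_hf(
    model_id="unsloth/gemma-3-1b-it",    # model suggested in the review thread
    revision=None,                        # or pin a specific HF commit for reproducibility
    output_path="gemma3-1b-it-mlx.pte",   # placeholder output name
    max_seq_len=1024,
    dtype="bf16",
)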
Why no embedding?