logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)

-_GEMMA4_MODEL_ID = "google/gemma-4-E2B-it"
-_GEMMA4_PROBLEM_LAYER_FQN = "model.language_model.layers.22.mlp.down_proj"
-
-
-def _get_submodule_by_fqn(root: torch.nn.Module, fqn: str) -> torch.nn.Module:
-    cur = root
-    for part in fqn.split("."):
-        if part.isdigit():
-            cur = cur[int(part)]  # type: ignore[index]
-        else:
-            cur = getattr(cur, part)
-    return cur
-
-
-def _capture_gemma4_float_fallback_weight(
-    model_id: str,
-    qlinear: Optional[str],
-    model: torch.nn.Module,
-) -> Optional[torch.Tensor]:
-    if model_id != _GEMMA4_MODEL_ID or qlinear != "4w":
-        return None
-
-    layer = _get_submodule_by_fqn(model, _GEMMA4_PROBLEM_LAYER_FQN)
-    weight = layer.weight.detach().clone()
-    logger.info(
-        "Saving %s in floating point to avoid the current Gemma 4 4w mismatch",
-        _GEMMA4_PROBLEM_LAYER_FQN,
-    )
-    return weight
-
-
-def _restore_gemma4_float_fallback_weight(
-    model_id: str,
-    qlinear: Optional[str],
-    model: torch.nn.Module,
-    weight: Optional[torch.Tensor],
-) -> None:
-    if weight is None or model_id != _GEMMA4_MODEL_ID or qlinear != "4w":
-        return
-
-    layer = _get_submodule_by_fqn(model, _GEMMA4_PROBLEM_LAYER_FQN)
-    layer.weight = torch.nn.Parameter(weight, requires_grad=False)
-    logger.info(
-        "Restored %s in floating point after quantization",
-        _GEMMA4_PROBLEM_LAYER_FQN,
-    )
-

def _export_with_optimum(
    model_id: str,
@@ -128,10 +81,6 @@ def _export_with_optimum(

    from executorch.backends.mlx.llm.quantization import quantize_model_

-    gemma4_float_weight = _capture_gemma4_float_fallback_weight(
-        model_id, qlinear, exportable.model
-    )
-
    quantize_model_(
        exportable.model,
        qlinear_config=qlinear,
@@ -143,9 +92,6 @@
        )
        and not no_tie_word_embeddings,
    )
-    _restore_gemma4_float_fallback_weight(
-        model_id, qlinear, exportable.model, gemma4_float_weight
-    )

    logger.info("Exporting model with torch.export...")
    exported_progs = exportable.export()
@@ -215,24 +161,13 @@ def _export_with_custom_components(
    }
    torch_dtype = torch_dtype_map.get(dtype, torch.bfloat16)

-    effective_use_custom_sdpa = use_custom_sdpa
-    effective_use_custom_kv_cache = use_custom_kv_cache
-    if model_id == _GEMMA4_MODEL_ID and use_custom_sdpa:
-        logger.info(
-            "Disabling custom SDPA for Gemma 4 while keeping the custom cache path"
-        )
-        effective_use_custom_sdpa = False
-    if model_id == _GEMMA4_MODEL_ID and use_custom_kv_cache:
-        logger.info("Disabling custom KV cache for Gemma 4")
-        effective_use_custom_kv_cache = False
-
-    if effective_use_custom_sdpa:
+    if use_custom_sdpa:
        from executorch.backends.mlx.llm.hf_attention import register_mlx_attention

        register_mlx_attention()
        logger.info("Registered MLX custom SDPA attention")

-    attn_implementation = "mlx" if effective_use_custom_sdpa else None
+    attn_implementation = "mlx" if use_custom_sdpa else None

    logger.info(f"Loading HuggingFace model: {model_id}")
    load_kwargs = {
@@ -292,7 +227,7 @@ def _export_with_custom_components(
        max_cache_len=effective_cache_len,
    )

-    if effective_use_custom_kv_cache:
+    if use_custom_kv_cache:
        from executorch.backends.mlx.llm.source_transformation import (
            replace_hf_cache_with_mlx,
        )
@@ -316,10 +251,6 @@

    from executorch.backends.mlx.llm.quantization import quantize_model_

-    gemma4_float_weight = _capture_gemma4_float_fallback_weight(
-        model_id, qlinear, exportable.model
-    )
-
    quantize_model_(
        exportable.model,
        qlinear_config=qlinear,
@@ -329,9 +260,6 @@
        tie_word_embeddings=getattr(model.config, "tie_word_embeddings", False)
        and not no_tie_word_embeddings,
    )
-    _restore_gemma4_float_fallback_weight(
-        model_id, qlinear, exportable.model, gemma4_float_weight
-    )

    logger.info("Exporting model with torch.export...")
    seq_length = 3
@@ -421,24 +349,10 @@ def export_llama_hf(
        use_custom_sdpa: Use MLX custom SDPA (mlx::custom_sdpa)
        use_custom_kv_cache: Use MLX custom KV cache (mlx::kv_cache_update)
    """
-    effective_use_custom_sdpa = use_custom_sdpa
-    effective_use_custom_kv_cache = use_custom_kv_cache
-    if model_id == _GEMMA4_MODEL_ID:
-        if effective_use_custom_sdpa:
-            logger.info(
-                "Disabling custom SDPA for Gemma 4 and falling back to the baseline export path"
-            )
-            effective_use_custom_sdpa = False
-        if effective_use_custom_kv_cache:
-            logger.info(
-                "Disabling custom KV cache for Gemma 4 and falling back to the baseline export path"
-            )
-            effective_use_custom_kv_cache = False
-
-    if effective_use_custom_sdpa or effective_use_custom_kv_cache:
+    if use_custom_sdpa or use_custom_kv_cache:
        logger.info(
-            f"Using custom components: sdpa={effective_use_custom_sdpa}, "
-            f"kv_cache={effective_use_custom_kv_cache}"
+            f"Using custom components: sdpa={use_custom_sdpa}, "
+            f"kv_cache={use_custom_kv_cache}"
        )
        _export_with_custom_components(
            model_id=model_id,
@@ -448,8 +362,8 @@ def export_llama_hf(
            dtype=dtype,
            qlinear=qlinear,
            qembedding=qembedding,
-            use_custom_sdpa=effective_use_custom_sdpa,
-            use_custom_kv_cache=effective_use_custom_kv_cache,
+            use_custom_sdpa=use_custom_sdpa,
+            use_custom_kv_cache=use_custom_kv_cache,
            no_tie_word_embeddings=no_tie_word_embeddings,
            qlinear_group_size=qlinear_group_size,
            qembedding_group_size=qembedding_group_size,