@@ -421,10 +421,24 @@ def export_llama_hf(
421421 use_custom_sdpa: Use MLX custom SDPA (mlx::custom_sdpa)
422422 use_custom_kv_cache: Use MLX custom KV cache (mlx::kv_cache_update)
423423 """
424- if use_custom_sdpa or use_custom_kv_cache :
424+ effective_use_custom_sdpa = use_custom_sdpa
425+ effective_use_custom_kv_cache = use_custom_kv_cache
426+ if model_id == _GEMMA4_MODEL_ID :
427+ if effective_use_custom_sdpa :
428+ logger .info (
429+ "Disabling custom SDPA for Gemma 4 and falling back to the baseline export path"
430+ )
431+ effective_use_custom_sdpa = False
432+ if effective_use_custom_kv_cache :
433+ logger .info (
434+ "Disabling custom KV cache for Gemma 4 and falling back to the baseline export path"
435+ )
436+ effective_use_custom_kv_cache = False
437+
438+ if effective_use_custom_sdpa or effective_use_custom_kv_cache :
425439 logger .info (
426- f"Using custom components: sdpa={ use_custom_sdpa } , "
427- f"kv_cache={ use_custom_kv_cache } "
440+ f"Using custom components: sdpa={ effective_use_custom_sdpa } , "
441+ f"kv_cache={ effective_use_custom_kv_cache } "
428442 )
429443 _export_with_custom_components (
430444 model_id = model_id ,
@@ -434,8 +448,8 @@ def export_llama_hf(
434448 dtype = dtype ,
435449 qlinear = qlinear ,
436450 qembedding = qembedding ,
437- use_custom_sdpa = use_custom_sdpa ,
438- use_custom_kv_cache = use_custom_kv_cache ,
451+ use_custom_sdpa = effective_use_custom_sdpa ,
452+ use_custom_kv_cache = effective_use_custom_kv_cache ,
439453 no_tie_word_embeddings = no_tie_word_embeddings ,
440454 qlinear_group_size = qlinear_group_size ,
441455 qembedding_group_size = qembedding_group_size ,
0 commit comments