Skip to content

Commit 719d2e8

Browse files
committed
Route Gemma 4 MLX export through optimum fallback path
1 parent 9d3f841 commit 719d2e8

1 file changed

Lines changed: 19 additions & 5 deletions

File tree

backends/mlx/examples/llm/export_llm_hf.py

Lines changed: 19 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -421,10 +421,24 @@ def export_llama_hf(
421421
use_custom_sdpa: Use MLX custom SDPA (mlx::custom_sdpa)
422422
use_custom_kv_cache: Use MLX custom KV cache (mlx::kv_cache_update)
423423
"""
424-
if use_custom_sdpa or use_custom_kv_cache:
424+
effective_use_custom_sdpa = use_custom_sdpa
425+
effective_use_custom_kv_cache = use_custom_kv_cache
426+
if model_id == _GEMMA4_MODEL_ID:
427+
if effective_use_custom_sdpa:
428+
logger.info(
429+
"Disabling custom SDPA for Gemma 4 and falling back to the baseline export path"
430+
)
431+
effective_use_custom_sdpa = False
432+
if effective_use_custom_kv_cache:
433+
logger.info(
434+
"Disabling custom KV cache for Gemma 4 and falling back to the baseline export path"
435+
)
436+
effective_use_custom_kv_cache = False
437+
438+
if effective_use_custom_sdpa or effective_use_custom_kv_cache:
425439
logger.info(
426-
f"Using custom components: sdpa={use_custom_sdpa}, "
427-
f"kv_cache={use_custom_kv_cache}"
440+
f"Using custom components: sdpa={effective_use_custom_sdpa}, "
441+
f"kv_cache={effective_use_custom_kv_cache}"
428442
)
429443
_export_with_custom_components(
430444
model_id=model_id,
@@ -434,8 +448,8 @@ def export_llama_hf(
434448
dtype=dtype,
435449
qlinear=qlinear,
436450
qembedding=qembedding,
437-
use_custom_sdpa=use_custom_sdpa,
438-
use_custom_kv_cache=use_custom_kv_cache,
451+
use_custom_sdpa=effective_use_custom_sdpa,
452+
use_custom_kv_cache=effective_use_custom_kv_cache,
439453
no_tie_word_embeddings=no_tie_word_embeddings,
440454
qlinear_group_size=qlinear_group_size,
441455
qembedding_group_size=qembedding_group_size,

0 commit comments

Comments
 (0)