Skip to content

Commit 19d6f09

Browse files
committed
Use static cache for Gemma 4 MLX custom export
1 parent 391cde4 commit 19d6f09

1 file changed

Lines changed: 15 additions & 41 deletions

File tree

backends/mlx/examples/llm/export_llm_hf.py

Lines changed: 15 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -282,52 +282,26 @@ def _export_with_custom_components(
282282
)
283283

284284
if use_custom_kv_cache:
285-
if sliding_window is not None:
286-
# Use ring buffer cache for sliding window models
287-
from executorch.backends.mlx.llm.source_transformation import (
288-
replace_hf_cache_with_mlx_ring_buffer,
289-
)
285+
from executorch.backends.mlx.llm.source_transformation import (
286+
replace_hf_cache_with_mlx,
287+
)
290288

289+
if sliding_window is not None:
291290
logger.info(
292-
f"Replacing StaticCache with RingBuffer KV cache "
293-
f"(window_size={effective_cache_len})..."
291+
"Replacing HuggingFace StaticCache with HFStaticCache "
292+
f"(capped to sliding window: {effective_cache_len})..."
294293
)
295-
replace_hf_cache_with_mlx_ring_buffer(
296-
exportable,
297-
model.config,
298-
max_batch_size=1,
299-
window_size=effective_cache_len,
300-
dtype=torch_dtype,
301-
)
302-
303-
if use_custom_sdpa:
304-
# Re-register attention with sliding window closure
305-
from executorch.backends.mlx.llm.hf_attention import (
306-
register_mlx_sliding_window_attention,
307-
)
308-
309-
register_mlx_sliding_window_attention(exportable)
310-
model.config._attn_implementation = "mlx_sliding_window"
311-
logger.info(
312-
" Registered sliding window attention (mlx_sliding_window)"
313-
)
314-
315-
logger.info(" RingBuffer KV cache installed successfully")
316294
else:
317-
# Use standard linear cache for non-sliding-window models
318-
from executorch.backends.mlx.llm.source_transformation import (
319-
replace_hf_cache_with_mlx,
320-
)
321-
322295
logger.info("Replacing HuggingFace StaticCache with HFStaticCache...")
323-
replace_hf_cache_with_mlx(
324-
exportable,
325-
model.config,
326-
max_batch_size=1,
327-
max_cache_len=effective_cache_len,
328-
dtype=torch_dtype,
329-
)
330-
logger.info(" HFStaticCache installed successfully")
296+
297+
replace_hf_cache_with_mlx(
298+
exportable,
299+
model.config,
300+
max_batch_size=1,
301+
max_cache_len=effective_cache_len,
302+
dtype=torch_dtype,
303+
)
304+
logger.info(" HFStaticCache installed successfully")
331305

332306
from executorch.backends.mlx.llm.quantization import quantize_model_
333307

0 commit comments

Comments
 (0)