@@ -282,52 +282,26 @@ def _export_with_custom_components(
282282 )
283283
284284 if use_custom_kv_cache :
285- if sliding_window is not None :
286- # Use ring buffer cache for sliding window models
287- from executorch .backends .mlx .llm .source_transformation import (
288- replace_hf_cache_with_mlx_ring_buffer ,
289- )
285+ from executorch .backends .mlx .llm .source_transformation import (
286+ replace_hf_cache_with_mlx ,
287+ )
290288
289+ if sliding_window is not None :
291290 logger .info (
292- f "Replacing StaticCache with RingBuffer KV cache "
293- f"(window_size= { effective_cache_len } )..."
291+ "Replacing HuggingFace StaticCache with HFStaticCache "
292+ f"(capped to sliding window: { effective_cache_len } )..."
294293 )
295- replace_hf_cache_with_mlx_ring_buffer (
296- exportable ,
297- model .config ,
298- max_batch_size = 1 ,
299- window_size = effective_cache_len ,
300- dtype = torch_dtype ,
301- )
302-
303- if use_custom_sdpa :
304- # Re-register attention with sliding window closure
305- from executorch .backends .mlx .llm .hf_attention import (
306- register_mlx_sliding_window_attention ,
307- )
308-
309- register_mlx_sliding_window_attention (exportable )
310- model .config ._attn_implementation = "mlx_sliding_window"
311- logger .info (
312- " Registered sliding window attention (mlx_sliding_window)"
313- )
314-
315- logger .info (" RingBuffer KV cache installed successfully" )
316294 else :
317- # Use standard linear cache for non-sliding-window models
318- from executorch .backends .mlx .llm .source_transformation import (
319- replace_hf_cache_with_mlx ,
320- )
321-
322295 logger .info ("Replacing HuggingFace StaticCache with HFStaticCache..." )
323- replace_hf_cache_with_mlx (
324- exportable ,
325- model .config ,
326- max_batch_size = 1 ,
327- max_cache_len = effective_cache_len ,
328- dtype = torch_dtype ,
329- )
330- logger .info (" HFStaticCache installed successfully" )
296+
297+ replace_hf_cache_with_mlx (
298+ exportable ,
299+ model .config ,
300+ max_batch_size = 1 ,
301+ max_cache_len = effective_cache_len ,
302+ dtype = torch_dtype ,
303+ )
304+ logger .info (" HFStaticCache installed successfully" )
331305
332306 from executorch .backends .mlx .llm .quantization import quantize_model_
333307
0 commit comments