Commit 818a51d

Prefer HF early cache init for Gemma 4 MLX path
1 parent ca37250 commit 818a51d

1 file changed

Lines changed: 25 additions & 17 deletions

File tree

backends/mlx/llm/cache.py

@@ -400,30 +400,38 @@ def __init__(
             device=device,
             dtype=dtype,
         )
-        # The HF cache API pinned in CI expects scalar num_heads/head_dim in
-        # early_initialization(). Gemma 4-style hybrid layouts need per-layer
-        # shapes, so initialize each cache layer directly using the resolved
-        # backing-cache geometry instead of relying on the helper.
-        for layer, layer_num_heads, layer_head_dim in zip(
-            self.layers, num_heads, head_dims
-        ):
-            fake_keys_tensor = torch.zeros(
-                (max_batch_size, layer_num_heads, 0, layer_head_dim),
+        # Newer HF cache implementations already support per-layer layouts in
+        # early_initialization(). Keep that path for Gemma 4, and only fall
+        # back to manual layer initialization for the older CI-pinned API.
+        try:
+            self.early_initialization(
+                batch_size=max_batch_size,
+                num_heads=num_heads,
+                head_dim=head_dims,
                 dtype=dtype,
                 device=device,
             )
-            lazy_init_sig = inspect.signature(layer.lazy_initialization)
-            # Older pinned HF caches take a single fake tensor, while newer
-            # versions expect both key_states and value_states separately.
-            if len(lazy_init_sig.parameters) == 1:
-                layer.lazy_initialization(fake_keys_tensor)
-            else:
-                fake_values_tensor = torch.zeros(
+        except TypeError:
+            for layer, layer_num_heads, layer_head_dim in zip(
+                self.layers, num_heads, head_dims
+            ):
+                fake_keys_tensor = torch.zeros(
                     (max_batch_size, layer_num_heads, 0, layer_head_dim),
                     dtype=dtype,
                     device=device,
                 )
-                layer.lazy_initialization(fake_keys_tensor, fake_values_tensor)
+                lazy_init_sig = inspect.signature(layer.lazy_initialization)
+                # Older pinned HF caches take a single fake tensor, while newer
+                # versions expect both key_states and value_states separately.
+                if len(lazy_init_sig.parameters) == 1:
+                    layer.lazy_initialization(fake_keys_tensor)
+                else:
+                    fake_values_tensor = torch.zeros(
+                        (max_batch_size, layer_num_heads, 0, layer_head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+                    layer.lazy_initialization(fake_keys_tensor, fake_values_tensor)
 
         # Some models (for example Gemma 4) only allocate cache entries for the
         # non-shared KV layers. Mirror the parent StaticCache layout exactly so
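
For reference, below is a minimal, self-contained sketch of the dispatch pattern this commit relies on. The names OldStyleLayer, NewStyleLayer, the free-function early_initialization, and init_cache_layers are hypothetical stand-ins rather than the real HF or backend API; only the control flow mirrors the diff above: try the batched early init first, fall back to per-layer initialization on TypeError, and pick the lazy_initialization call shape via inspect.signature.

import inspect
import torch


class OldStyleLayer:
    """Stand-in for an older pinned cache layer: lazy_initialization takes one tensor."""

    def lazy_initialization(self, fake_keys):
        self.keys = torch.zeros_like(fake_keys)
        self.values = torch.zeros_like(fake_keys)


class NewStyleLayer:
    """Stand-in for a newer cache layer: lazy_initialization takes keys and values."""

    def lazy_initialization(self, key_states, value_states):
        self.keys = torch.zeros_like(key_states)
        self.values = torch.zeros_like(value_states)


def early_initialization(batch_size, num_heads, head_dim, dtype, device):
    # Stand-in for the newer HF helper. An older pinned version only understands
    # scalar num_heads/head_dim, which is modelled here by raising TypeError.
    if not isinstance(num_heads, int) or not isinstance(head_dim, int):
        raise TypeError("per-layer shapes not supported by this cache version")


def init_cache_layers(layers, max_batch_size, num_heads, head_dims, dtype, device):
    try:
        # Preferred path: let the cache helper handle the per-layer layout itself.
        early_initialization(max_batch_size, num_heads, head_dims, dtype, device)
    except TypeError:
        # Fallback: initialize each layer manually from the resolved geometry.
        for layer, n_heads, h_dim in zip(layers, num_heads, head_dims):
            fake_keys = torch.zeros(
                (max_batch_size, n_heads, 0, h_dim), dtype=dtype, device=device
            )
            sig = inspect.signature(layer.lazy_initialization)
            if len(sig.parameters) == 1:
                layer.lazy_initialization(fake_keys)
            else:
                fake_values = torch.zeros(
                    (max_batch_size, n_heads, 0, h_dim), dtype=dtype, device=device
                )
                layer.lazy_initialization(fake_keys, fake_values)


# Usage: per-layer head counts and head dims force the TypeError fallback here.
layers = [OldStyleLayer(), NewStyleLayer()]
init_cache_layers(layers, 2, [4, 8], [64, 128], torch.float16, "cpu")

The real code calls self.early_initialization(...) with keyword arguments, as in the diff; the stand-in above only makes the TypeError behavior of the older API explicit so the fallback branch is exercised.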
