BatchRotatingKVCache: rotated=False reloads as True from saved prompt cache #1250

@odysa

Description

BatchRotatingKVCache.meta_state round-trips self.rotated through str() on save and bool() on load. Since bool("False") is True (any non-empty string is truthy), a cache with rotated=False reloads as rotated=True, silently corrupting the rotation mask.

Affects any sliding-window model using BatchRotatingKVCache with batched generation that goes through save_prompt_cache / load_prompt_cache.
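The truthiness pitfall described above can be seen without mlx_lm at all. The following is a minimal pure-Python sketch; the variable names `flags`, `saved`, and `loaded` are illustrative only and do not come from the library.

```python
# Illustration only: mimics the str() on save / bool() on load round trip.
flags = [True, False]

# "Save": each flag is serialized with str(), as meta_state does.
saved = [str(f) for f in flags]

# "Load": each string is passed to bool(). Any non-empty string is truthy,
# so both entries come back as True.
loaded = [bool(s) for s in saved]

print(saved)
print(loaded)
```

Only the empty string would load as False under this scheme, and str() never produces an empty string for a bool.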

Repro

import os, tempfile
import mlx.core as mx
from mlx_lm.models.cache import (
    BatchRotatingKVCache, save_prompt_cache, load_prompt_cache,
)

c = BatchRotatingKVCache(max_size=10, left_padding=[0])
c.update_and_fetch(mx.zeros((1, 1, 3, 4)), mx.zeros((1, 1, 3, 4)))

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "c.safetensors")
    save_prompt_cache(path, [c])
    [loaded] = load_prompt_cache(path)

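# Print the flag and the serialized meta_state before and after the round trip.
print(f"before save: {c.rotated!s:<5} {c.meta_state}")
print(f"after  load: {loaded.rotated!s:<5} {loaded.meta_state}")

assert c.rotated is False
assert loaded.rotated is True   # bug: should be False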

Output:

before save: False ('10', '3', '3', 'False')
after  load: True  ('10', '3', '3', 'True')

Root cause

mlx_lm/models/cache.py:1306-1315:

@property
def meta_state(self):
    return tuple(map(str, (self.max_size, self._offset, self._idx, self.rotated)))

@meta_state.setter
def meta_state(self, v):
    self.max_size, self._offset, self._idx = map(int, v[:3])
    self.rotated = bool(v[3])   # bool("False") is True

Suggested fix

self.rotated = v[3] == "True"
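As a rough check of the suggested fix, the sketch below applies it to a stand-in class. `ToyCache` and `round_trip` are hypothetical names introduced here for illustration; they mirror only the four meta_state fields shown above and are not part of mlx_lm.

```python
class ToyCache:
    """Hypothetical stand-in that carries only the meta_state fields."""

    def __init__(self, max_size, offset=0, idx=0, rotated=False):
        self.max_size = max_size
        self._offset = offset
        self._idx = idx
        self.rotated = rotated

    @property
    def meta_state(self):
        # Same serialization as the original: everything goes through str().
        return tuple(map(str, (self.max_size, self._offset, self._idx, self.rotated)))

    @meta_state.setter
    def meta_state(self, v):
        self.max_size, self._offset, self._idx = map(int, v[:3])
        # Suggested fix: compare against the literal produced by str(True).
        self.rotated = v[3] == "True"


def round_trip(cache):
    """Copy meta_state into a fresh ToyCache, as a save/load cycle would."""
    restored = ToyCache(max_size=0)
    restored.meta_state = cache.meta_state
    return restored


not_rotated = round_trip(ToyCache(10, 3, 3, rotated=False))
was_rotated = round_trip(ToyCache(10, 3, 3, rotated=True))
```

With the string comparison, both values of the flag survive the round trip. Because the buggy loader already wrote "True" or "False" as strings, files saved before the fix should parse the same way under this setter, assuming the saved flag itself was correct at save time.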
