Skip to content

Commit fc22a6e

Browse files
authored
fix(cache): handle directory cache entries in StructuredOutputCache (#326, #327) (#331)
1 parent d391a25 commit fc22a6e

2 files changed

Lines changed: 79 additions & 8 deletions

File tree

src/autointent/generation/_cache.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import json
66
import logging
7+
import shutil
78
from concurrent.futures import ThreadPoolExecutor, as_completed
89
from pathlib import Path
910
from typing import TYPE_CHECKING, Any, TypeVar
@@ -43,6 +44,17 @@ def _get_structured_output_cache_path(dirname: str) -> Path:
4344
return Path(user_cache_dir("autointent")) / "structured_outputs" / dirname
4445

4546

47+
def _remove_cache_entry(path: Path) -> None:
48+
"""Remove a single on-disk cache entry.
49+
50+
Each entry is a *directory* (``PydanticModelDumper.dump`` writes
51+
``class_info.json`` + ``model_dump.json`` inside it), so eviction must use
52+
``rmtree`` rather than ``unlink``. ``ignore_errors`` keeps a missing or
53+
partially removed entry from raising during cleanup.
54+
"""
55+
shutil.rmtree(path, ignore_errors=True)
56+
57+
4658
class StructuredOutputCache:
4759
"""Cache for structured output results."""
4860

@@ -70,8 +82,10 @@ def _load_existing_cache(self) -> None:
7082
if not cache_dir.exists():
7183
return
7284

73-
# Get all cache files to process
74-
cache_files = [f for f in cache_dir.iterdir() if f.is_file()]
85+
# Each cache entry is a directory written by PydanticModelDumper, so
86+
# collect directories (filtering on is_file() here matched nothing and
87+
# silently disabled eager loading entirely).
88+
cache_files = [f for f in cache_dir.iterdir() if f.is_dir()]
7589

7690
if not cache_files:
7791
return
@@ -118,7 +132,7 @@ def _load_single_cache_file(self, cache_file: Path) -> tuple[str, BaseModel] | N
118132
cached_data = PydanticModelDumper.load(cache_file)
119133
except (ValidationError, ImportError) as e:
120134
logger.warning("Failed to load cached item %s: %s", cache_file.name, e)
121-
cache_file.unlink(missing_ok=True)
135+
_remove_cache_entry(cache_file)
122136
else:
123137
return cache_file.name, cached_data
124138

@@ -184,10 +198,10 @@ def _load_from_disk(self, cache_key: str, output_model: type[T]) -> T | None:
184198
return cached_data
185199

186200
logger.warning("Cached data type mismatch on disk, removing invalid cache")
187-
cache_path.unlink()
201+
_remove_cache_entry(cache_path)
188202
except (ValidationError, ImportError) as e:
189203
logger.warning("Failed to load cached structured output from disk: %s", e)
190-
cache_path.unlink(missing_ok=True)
204+
_remove_cache_entry(cache_path)
191205

192206
return None
193207

@@ -271,10 +285,10 @@ async def _load_from_disk_async(self, cache_key: str, output_model: type[T]) ->
271285
return cached_data
272286

273287
logger.warning("Cached data type mismatch on disk, removing invalid cache")
274-
cache_path.unlink()
288+
_remove_cache_entry(cache_path)
275289
except (ValidationError, ImportError) as e:
276290
logger.warning("Failed to load cached structured output from disk: %s", e)
277-
cache_path.unlink(missing_ok=True)
291+
_remove_cache_entry(cache_path)
278292

279293
return None
280294

tests/generation/structured_output/test_cache_unit.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
from __future__ import annotations
44

5+
import json
56
from typing import TYPE_CHECKING, Any
67

78
import pytest
89
from pydantic import BaseModel
910

10-
from autointent.generation._cache import StructuredOutputCache
11+
from autointent.generation._cache import StructuredOutputCache, _get_structured_output_cache_path
1112
from autointent.generation.chat_templates import Role
1213

1314
if TYPE_CHECKING:
@@ -98,3 +99,59 @@ async def test_async_disabled_cache_is_noop() -> None:
9899
cache = StructuredOutputCache(use_cache=False)
99100
await cache.set_async(MESSAGES, CacheModel, PARAMS, CacheModel(name="a", value=1))
100101
assert await cache.get_async(MESSAGES, CacheModel, PARAMS) is None
102+
103+
104+
# --- Regression tests for the on-disk-cache bugs (#326 eager load, #327 eviction) ---
105+
# Disk entries are directories (PydanticModelDumper writes class_info.json +
106+
# model_dump.json), so eager load must collect directories and eviction must
107+
# rmtree rather than unlink.
108+
109+
110+
def test_eager_load_populates_memory_from_disk() -> None:
    """An entry stored by one instance is already in memory on a newly built instance (#326)."""
    writer = StructuredOutputCache(use_cache=True)
    writer.set(MESSAGES, CacheModel, PARAMS, CacheModel(name="x", value=9))

    reloaded = StructuredOutputCache(use_cache=True)
    cache_key = reloaded._get_cache_key(MESSAGES, CacheModel, PARAMS)

    # No get() has been issued on the new instance: the entry must come from
    # the eager load performed at construction.
    preloaded = reloaded._memory_cache
    assert cache_key in preloaded
    assert isinstance(preloaded[cache_key], CacheModel)
120+
121+
122+
def test_eager_load_removes_corrupted_entry() -> None:
    """An entry directory that cannot be loaded is dropped from disk and never reaches memory."""
    corrupted_dir = _get_structured_output_cache_path("corrupted-entry")
    corrupted_dir.mkdir(parents=True)

    class_info = {"name": CacheModel.__name__, "module": CacheModel.__module__}
    # "value" is required but absent -> ValidationError on load
    incomplete_payload = {"name": "x"}
    for filename, content in (("class_info.json", class_info), ("model_dump.json", incomplete_payload)):
        (corrupted_dir / filename).write_text(json.dumps(content))

    # Construction triggers the eager load, which must not raise.
    loaded = StructuredOutputCache(use_cache=True)

    assert not loaded._memory_cache
    assert not corrupted_dir.exists()
134+
135+
136+
def test_disk_type_mismatch_evicts_entry() -> None:
    """Loading a disk entry of the wrong model type returns None and deletes the entry (#327)."""
    cache = StructuredOutputCache(use_cache=True)

    # Store a CacheModel under the key derived for OtherModel inputs so the
    # on-disk type disagrees with the requested one.
    mismatched_key = cache._get_cache_key(MESSAGES, OtherModel, PARAMS)
    cache._save_to_disk(mismatched_key, CacheModel(name="x", value=1))
    cache._memory_cache.clear()

    loaded = cache._load_from_disk(mismatched_key, OtherModel)
    entry_path = _get_structured_output_cache_path(mismatched_key)

    assert loaded is None
    assert not entry_path.exists()
146+
147+
148+
@pytest.mark.asyncio
async def test_async_disk_type_mismatch_evicts_entry() -> None:
    """Async load of a disk entry of the wrong model type returns None and deletes the entry (#327)."""
    cache = StructuredOutputCache(use_cache=True)

    # Store a CacheModel under the key derived for OtherModel inputs so the
    # on-disk type disagrees with the requested one.
    mismatched_key = cache._get_cache_key(MESSAGES, OtherModel, PARAMS)
    await cache._save_to_disk_async(mismatched_key, CacheModel(name="x", value=1))
    cache._memory_cache.clear()

    loaded = await cache._load_from_disk_async(mismatched_key, OtherModel)
    entry_path = _get_structured_output_cache_path(mismatched_key)

    assert loaded is None
    assert not entry_path.exists()

0 commit comments

Comments
 (0)