Skip to content

Commit a855ef0

Browse files
authored
fix(cache): include model identity in structured-output cache key (#334) (#336)
1 parent 9a23dca commit a855ef0

4 files changed

Lines changed: 165 additions & 33 deletions

File tree

src/autointent/generation/_cache.py

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,22 @@ def _load_single_cache_file(self, cache_file: Path) -> tuple[str, BaseModel] | N
138138

139139
return None
140140

141-
def _get_cache_key(self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any]) -> str:
141+
def _get_cache_key(
142+
self,
143+
messages: list[Message],
144+
output_model: type[T],
145+
generation_params: dict[str, Any],
146+
model_name: str,
147+
base_url: str | None,
148+
) -> str:
142149
"""Generate a cache key for the given parameters.
143150
144151
Args:
145152
messages: List of messages to send to the model.
146153
output_model: Pydantic model class to parse the response into.
147154
generation_params: Generation parameters.
155+
model_name: Name of the language model that will serve the request.
156+
base_url: Base URL of the API endpoint, or None for the default.
148157
149158
Returns:
150159
Cache key as a hexadecimal string.
@@ -153,6 +162,8 @@ def _get_cache_key(self, messages: list[Message], output_model: type[T], generat
153162
hasher.update(json.dumps(messages))
154163
hasher.update(json.dumps(output_model.model_json_schema()))
155164
hasher.update(json.dumps(generation_params))
165+
hasher.update(model_name)
166+
hasher.update(base_url)
156167
return hasher.hexdigest()
157168

158169
def _check_memory_cache(self, cache_key: str, output_model: type[T]) -> T | None:
@@ -216,21 +227,30 @@ def _save_to_disk(self, cache_key: str, result: T) -> None:
216227
cache_path.parent.mkdir(parents=True, exist_ok=True)
217228
PydanticModelDumper.dump(result, cache_path, exists_ok=True)
218229

219-
def get(self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any]) -> T | None:
230+
def get(
231+
self,
232+
messages: list[Message],
233+
output_model: type[T],
234+
generation_params: dict[str, Any],
235+
model_name: str,
236+
base_url: str | None,
237+
) -> T | None:
220238
"""Get cached result if available.
221239
222240
Args:
223241
messages: List of messages to send to the model.
224242
output_model: Pydantic model class to parse the response into.
225243
generation_params: Generation parameters.
244+
model_name: Name of the language model that will serve the request.
245+
base_url: Base URL of the API endpoint, or None for the default.
226246
227247
Returns:
228248
Cached result if available, None otherwise.
229249
"""
230250
if not self.use_cache:
231251
return None
232252

233-
cache_key = self._get_cache_key(messages, output_model, generation_params)
253+
cache_key = self._get_cache_key(messages, output_model, generation_params, model_name, base_url)
234254

235255
# First check in-memory cache
236256
memory_result = self._check_memory_cache(cache_key, output_model)
@@ -240,20 +260,29 @@ def get(self, messages: list[Message], output_model: type[T], generation_params:
240260
# Fallback to disk cache
241261
return self._load_from_disk(cache_key, output_model)
242262

243-
def set(self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any], result: T) -> None:
263+
def set(
264+
self,
265+
messages: list[Message],
266+
output_model: type[T],
267+
generation_params: dict[str, Any],
268+
result: T,
269+
model_name: str,
270+
base_url: str | None,
271+
) -> None:
244272
"""Cache the result.
245273
246274
Args:
247275
messages: List of messages to send to the model.
248276
output_model: Pydantic model class to parse the response into.
249-
backend: Backend to use for structured output.
250277
generation_params: Generation parameters.
251278
result: The result to cache.
279+
model_name: Name of the language model that will serve the request.
280+
base_url: Base URL of the API endpoint, or None for the default.
252281
"""
253282
if not self.use_cache:
254283
return
255284

256-
cache_key = self._get_cache_key(messages, output_model, generation_params)
285+
cache_key = self._get_cache_key(messages, output_model, generation_params, model_name, base_url)
257286

258287
# Store in memory cache
259288
self._memory_cache[cache_key] = result
@@ -304,22 +333,29 @@ async def _save_to_disk_async(self, cache_key: str, result: T) -> None:
304333
await PydanticModelDumper.dump_async(result, cache_path, exists_ok=True)
305334

306335
async def get_async(
307-
self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any]
336+
self,
337+
messages: list[Message],
338+
output_model: type[T],
339+
generation_params: dict[str, Any],
340+
model_name: str,
341+
base_url: str | None,
308342
) -> T | None:
309343
"""Get cached result if available (async version).
310344
311345
Args:
312346
messages: List of messages to send to the model.
313347
output_model: Pydantic model class to parse the response into.
314348
generation_params: Generation parameters.
349+
model_name: Name of the language model that will serve the request.
350+
base_url: Base URL of the API endpoint, or None for the default.
315351
316352
Returns:
317353
Cached result if available, None otherwise.
318354
"""
319355
if not self.use_cache:
320356
return None
321357

322-
cache_key = self._get_cache_key(messages, output_model, generation_params)
358+
cache_key = self._get_cache_key(messages, output_model, generation_params, model_name, base_url)
323359

324360
# First check in-memory cache
325361
memory_result = self._check_memory_cache(cache_key, output_model)
@@ -330,21 +366,28 @@ async def get_async(
330366
return await self._load_from_disk_async(cache_key, output_model)
331367

332368
async def set_async(
333-
self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any], result: T
369+
self,
370+
messages: list[Message],
371+
output_model: type[T],
372+
generation_params: dict[str, Any],
373+
result: T,
374+
model_name: str,
375+
base_url: str | None,
334376
) -> None:
335377
"""Cache the result (async version).
336378
337379
Args:
338380
messages: List of messages to send to the model.
339381
output_model: Pydantic model class to parse the response into.
340-
backend: Backend to use for structured output.
341382
generation_params: Generation parameters.
342383
result: The result to cache.
384+
model_name: Name of the language model that will serve the request.
385+
base_url: Base URL of the API endpoint, or None for the default.
343386
"""
344387
if not self.use_cache:
345388
return
346389

347-
cache_key = self._get_cache_key(messages, output_model, generation_params)
390+
cache_key = self._get_cache_key(messages, output_model, generation_params, model_name, base_url)
348391

349392
# Store in memory cache
350393
self._memory_cache[cache_key] = result

src/autointent/generation/_generator.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,9 @@ async def get_structured_output_async(
263263
Parsed response as an instance of the provided Pydantic model.
264264
"""
265265
# Check cache first
266-
cached_result = await self.cache.get_async(messages, output_model, self.generation_params)
266+
cached_result = await self.cache.get_async(
267+
messages, output_model, self.generation_params, self.model_name, self.base_url
268+
)
267269
if cached_result is not None:
268270
return cached_result
269271

@@ -292,7 +294,7 @@ async def get_structured_output_async(
292294
raise RetriesExceededError(max_retries=max_retries, messages=current_messages)
293295

294296
# Cache the successful result
295-
await self.cache.set_async(messages, output_model, self.generation_params, res)
297+
await self.cache.set_async(messages, output_model, self.generation_params, res, self.model_name, self.base_url)
296298

297299
return res
298300

@@ -350,7 +352,7 @@ def get_structured_output_sync(
350352
Parsed response as an instance of the provided Pydantic model.
351353
"""
352354
# Check cache first
353-
cached_result = self.cache.get(messages, output_model, self.generation_params)
355+
cached_result = self.cache.get(messages, output_model, self.generation_params, self.model_name, self.base_url)
354356
if cached_result is not None:
355357
return cached_result
356358

@@ -376,7 +378,7 @@ def get_structured_output_sync(
376378
raise RetriesExceededError(max_retries=max_retries, messages=current_messages)
377379

378380
# Cache the successful result
379-
self.cache.set(messages, output_model, self.generation_params, res)
381+
self.cache.set(messages, output_model, self.generation_params, res, self.model_name, self.base_url)
380382

381383
return res
382384

tests/generation/structured_output/test_cache_unit.py

Lines changed: 103 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -36,31 +36,37 @@ def _isolated_cache(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
3636
monkeypatch.setattr("autointent.generation._cache.user_cache_dir", lambda *_: str(tmp_path))
3737

3838

39+
MODEL_NAME = "test-model"
40+
BASE_URL: str | None = None
41+
42+
3943
def test_set_get_memory_roundtrip() -> None:
4044
cache = StructuredOutputCache(use_cache=True)
4145
result = CacheModel(name="a", value=1)
42-
cache.set(MESSAGES, CacheModel, PARAMS, result)
43-
assert cache.get(MESSAGES, CacheModel, PARAMS) == result
46+
cache.set(MESSAGES, CacheModel, PARAMS, result, MODEL_NAME, BASE_URL)
47+
assert cache.get(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL) == result
4448

4549

4650
def test_disabled_cache_is_noop() -> None:
4751
cache = StructuredOutputCache(use_cache=False)
48-
cache.set(MESSAGES, CacheModel, PARAMS, CacheModel(name="a", value=1))
49-
assert cache.get(MESSAGES, CacheModel, PARAMS) is None
52+
cache.set(MESSAGES, CacheModel, PARAMS, CacheModel(name="a", value=1), MODEL_NAME, BASE_URL)
53+
assert cache.get(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL) is None
5054

5155

5256
def test_get_misses_for_unknown_key() -> None:
5357
cache = StructuredOutputCache(use_cache=True)
54-
assert cache.get(MESSAGES, CacheModel, PARAMS) is None
58+
assert cache.get(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL) is None
5559

5660

5761
def test_get_loads_from_disk_in_fresh_instance() -> None:
5862
"""A second instance has empty memory and must read the entry back from disk."""
59-
StructuredOutputCache(use_cache=True).set(MESSAGES, CacheModel, PARAMS, CacheModel(name="x", value=9))
63+
StructuredOutputCache(use_cache=True).set(
64+
MESSAGES, CacheModel, PARAMS, CacheModel(name="x", value=9), MODEL_NAME, BASE_URL
65+
)
6066

6167
fresh = StructuredOutputCache(use_cache=True)
6268
fresh._memory_cache.clear() # force the disk path even if eager load changes
63-
loaded = fresh.get(MESSAGES, CacheModel, PARAMS)
69+
loaded = fresh.get(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL)
6470
assert isinstance(loaded, CacheModel)
6571
assert loaded.value == 9
6672
# disk hit populates the memory cache for next time
@@ -69,7 +75,7 @@ def test_get_loads_from_disk_in_fresh_instance() -> None:
6975

7076
def test_memory_type_mismatch_evicts() -> None:
7177
cache = StructuredOutputCache(use_cache=True)
72-
key = cache._get_cache_key(MESSAGES, CacheModel, PARAMS)
78+
key = cache._get_cache_key(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL)
7379
cache._memory_cache[key] = OtherModel(text="wrong")
7480
assert cache._check_memory_cache(key, CacheModel) is None
7581
assert key not in cache._memory_cache
@@ -79,26 +85,28 @@ def test_memory_type_mismatch_evicts() -> None:
7985
async def test_async_set_get_roundtrip() -> None:
8086
cache = StructuredOutputCache(use_cache=True)
8187
result = CacheModel(name="async", value=7)
82-
await cache.set_async(MESSAGES, CacheModel, PARAMS, result)
83-
assert await cache.get_async(MESSAGES, CacheModel, PARAMS) == result
88+
await cache.set_async(MESSAGES, CacheModel, PARAMS, result, MODEL_NAME, BASE_URL)
89+
assert await cache.get_async(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL) == result
8490

8591

8692
@pytest.mark.asyncio
8793
async def test_async_get_loads_from_disk() -> None:
88-
await StructuredOutputCache(use_cache=True).set_async(MESSAGES, CacheModel, PARAMS, CacheModel(name="d", value=3))
94+
await StructuredOutputCache(use_cache=True).set_async(
95+
MESSAGES, CacheModel, PARAMS, CacheModel(name="d", value=3), MODEL_NAME, BASE_URL
96+
)
8997

9098
fresh = StructuredOutputCache(use_cache=True)
9199
fresh._memory_cache.clear()
92-
loaded = await fresh.get_async(MESSAGES, CacheModel, PARAMS)
100+
loaded = await fresh.get_async(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL)
93101
assert isinstance(loaded, CacheModel)
94102
assert loaded.value == 3
95103

96104

97105
@pytest.mark.asyncio
98106
async def test_async_disabled_cache_is_noop() -> None:
99107
cache = StructuredOutputCache(use_cache=False)
100-
await cache.set_async(MESSAGES, CacheModel, PARAMS, CacheModel(name="a", value=1))
101-
assert await cache.get_async(MESSAGES, CacheModel, PARAMS) is None
108+
await cache.set_async(MESSAGES, CacheModel, PARAMS, CacheModel(name="a", value=1), MODEL_NAME, BASE_URL)
109+
assert await cache.get_async(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL) is None
102110

103111

104112
# --- Regression tests for the on-disk-cache bugs (#326 eager load, #327 eviction) ---
@@ -109,10 +117,12 @@ async def test_async_disabled_cache_is_noop() -> None:
109117

110118
def test_eager_load_populates_memory_from_disk() -> None:
111119
"""A fresh instance eagerly batch-loads existing on-disk entries into memory (#326)."""
112-
StructuredOutputCache(use_cache=True).set(MESSAGES, CacheModel, PARAMS, CacheModel(name="x", value=9))
120+
StructuredOutputCache(use_cache=True).set(
121+
MESSAGES, CacheModel, PARAMS, CacheModel(name="x", value=9), MODEL_NAME, BASE_URL
122+
)
113123

114124
fresh = StructuredOutputCache(use_cache=True)
115-
key = fresh._get_cache_key(MESSAGES, CacheModel, PARAMS)
125+
key = fresh._get_cache_key(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL)
116126

117127
# populated at construction by the eager load, before any get() call
118128
assert key in fresh._memory_cache
@@ -137,7 +147,7 @@ def test_disk_type_mismatch_evicts_entry() -> None:
137147
"""A type-mismatched disk entry is evicted (rmtree) instead of crashing on unlink (#327)."""
138148
cache = StructuredOutputCache(use_cache=True)
139149
# plant a CacheModel at the key the cache derives for OtherModel inputs
140-
key = cache._get_cache_key(MESSAGES, OtherModel, PARAMS)
150+
key = cache._get_cache_key(MESSAGES, OtherModel, PARAMS, MODEL_NAME, BASE_URL)
141151
cache._save_to_disk(key, CacheModel(name="x", value=1))
142152
cache._memory_cache.clear()
143153

@@ -149,9 +159,84 @@ def test_disk_type_mismatch_evicts_entry() -> None:
149159
async def test_async_disk_type_mismatch_evicts_entry() -> None:
150160
"""Async type-mismatched disk entry is evicted (rmtree) instead of crashing on unlink (#327)."""
151161
cache = StructuredOutputCache(use_cache=True)
152-
key = cache._get_cache_key(MESSAGES, OtherModel, PARAMS)
162+
key = cache._get_cache_key(MESSAGES, OtherModel, PARAMS, MODEL_NAME, BASE_URL)
153163
await cache._save_to_disk_async(key, CacheModel(name="x", value=1))
154164
cache._memory_cache.clear()
155165

156166
assert await cache._load_from_disk_async(key, OtherModel) is None
157167
assert not _get_structured_output_cache_path(key).exists()
168+
169+
170+
# --- Regression tests for model-identity cache collision (#334) ---
171+
172+
173+
def test_different_model_names_do_not_collide() -> None:
174+
"""Two generators with different model_name must NOT share a cache entry (#334)."""
175+
result_a = CacheModel(name="from-model-a", value=1)
176+
result_b = CacheModel(name="from-model-b", value=2)
177+
178+
cache = StructuredOutputCache(use_cache=True)
179+
cache.set(MESSAGES, CacheModel, PARAMS, result_a, model_name="model-a", base_url=None)
180+
cache.set(MESSAGES, CacheModel, PARAMS, result_b, model_name="model-b", base_url=None)
181+
182+
hit_a = cache.get(MESSAGES, CacheModel, PARAMS, model_name="model-a", base_url=None)
183+
hit_b = cache.get(MESSAGES, CacheModel, PARAMS, model_name="model-b", base_url=None)
184+
185+
assert hit_a == result_a, "model-a should get its own cached value"
186+
assert hit_b == result_b, "model-b must NOT get model-a's value"
187+
188+
189+
def test_different_base_urls_do_not_collide() -> None:
190+
"""Two generators with different base_url must NOT share a cache entry (#334)."""
191+
result_x = CacheModel(name="from-url-x", value=10)
192+
result_y = CacheModel(name="from-url-y", value=20)
193+
194+
cache = StructuredOutputCache(use_cache=True)
195+
cache.set(MESSAGES, CacheModel, PARAMS, result_x, model_name="gpt-4o", base_url="http://host-x/v1")
196+
cache.set(MESSAGES, CacheModel, PARAMS, result_y, model_name="gpt-4o", base_url="http://host-y/v1")
197+
198+
hit_x = cache.get(MESSAGES, CacheModel, PARAMS, model_name="gpt-4o", base_url="http://host-x/v1")
199+
hit_y = cache.get(MESSAGES, CacheModel, PARAMS, model_name="gpt-4o", base_url="http://host-y/v1")
200+
201+
assert hit_x == result_x, "host-x should get its own cached value"
202+
assert hit_y == result_y, "host-y must NOT get host-x's value"
203+
204+
205+
def test_same_identity_still_hits_cache() -> None:
206+
"""Same model_name + base_url + inputs must continue to yield a cache hit (#334)."""
207+
result = CacheModel(name="same", value=42)
208+
209+
cache = StructuredOutputCache(use_cache=True)
210+
cache.set(MESSAGES, CacheModel, PARAMS, result, model_name="gpt-4o", base_url="http://host/v1")
211+
212+
hit = cache.get(MESSAGES, CacheModel, PARAMS, model_name="gpt-4o", base_url="http://host/v1")
213+
assert hit == result
214+
215+
216+
@pytest.mark.asyncio
217+
async def test_async_different_model_names_do_not_collide() -> None:
218+
"""Async paths: two model names must NOT collide (#334)."""
219+
result_a = CacheModel(name="async-a", value=1)
220+
result_b = CacheModel(name="async-b", value=2)
221+
222+
cache = StructuredOutputCache(use_cache=True)
223+
await cache.set_async(MESSAGES, CacheModel, PARAMS, result_a, model_name="async-model-a", base_url=None)
224+
await cache.set_async(MESSAGES, CacheModel, PARAMS, result_b, model_name="async-model-b", base_url=None)
225+
226+
hit_a = await cache.get_async(MESSAGES, CacheModel, PARAMS, model_name="async-model-a", base_url=None)
227+
hit_b = await cache.get_async(MESSAGES, CacheModel, PARAMS, model_name="async-model-b", base_url=None)
228+
229+
assert hit_a == result_a
230+
assert hit_b == result_b
231+
232+
233+
@pytest.mark.asyncio
234+
async def test_async_same_identity_still_hits_cache() -> None:
235+
"""Async paths: same identity must still yield a hit (#334)."""
236+
result = CacheModel(name="async-same", value=99)
237+
238+
cache = StructuredOutputCache(use_cache=True)
239+
await cache.set_async(MESSAGES, CacheModel, PARAMS, result, model_name="gpt-4o", base_url=None)
240+
241+
hit = await cache.get_async(MESSAGES, CacheModel, PARAMS, model_name="gpt-4o", base_url=None)
242+
assert hit == result

0 commit comments

Comments
 (0)