Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 54 additions & 11 deletions src/autointent/generation/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,22 @@ def _load_single_cache_file(self, cache_file: Path) -> tuple[str, BaseModel] | N

return None

def _get_cache_key(self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any]) -> str:
def _get_cache_key(
self,
messages: list[Message],
output_model: type[T],
generation_params: dict[str, Any],
model_name: str,
base_url: str | None,
) -> str:
"""Generate a cache key for the given parameters.

Args:
messages: List of messages to send to the model.
output_model: Pydantic model class to parse the response into.
generation_params: Generation parameters.
model_name: Name of the language model that will serve the request.
base_url: Base URL of the API endpoint, or None for the default.

Returns:
Cache key as a hexadecimal string.
Expand All @@ -153,6 +162,8 @@ def _get_cache_key(self, messages: list[Message], output_model: type[T], generat
hasher.update(json.dumps(messages))
hasher.update(json.dumps(output_model.model_json_schema()))
hasher.update(json.dumps(generation_params))
hasher.update(model_name)
hasher.update(base_url)
return hasher.hexdigest()

def _check_memory_cache(self, cache_key: str, output_model: type[T]) -> T | None:
Expand Down Expand Up @@ -216,21 +227,30 @@ def _save_to_disk(self, cache_key: str, result: T) -> None:
cache_path.parent.mkdir(parents=True, exist_ok=True)
PydanticModelDumper.dump(result, cache_path, exists_ok=True)

def get(self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any]) -> T | None:
def get(
self,
messages: list[Message],
output_model: type[T],
generation_params: dict[str, Any],
model_name: str,
base_url: str | None,
) -> T | None:
"""Get cached result if available.

Args:
messages: List of messages to send to the model.
output_model: Pydantic model class to parse the response into.
generation_params: Generation parameters.
model_name: Name of the language model that will serve the request.
base_url: Base URL of the API endpoint, or None for the default.

Returns:
Cached result if available, None otherwise.
"""
if not self.use_cache:
return None

cache_key = self._get_cache_key(messages, output_model, generation_params)
cache_key = self._get_cache_key(messages, output_model, generation_params, model_name, base_url)

# First check in-memory cache
memory_result = self._check_memory_cache(cache_key, output_model)
Expand All @@ -240,20 +260,29 @@ def get(self, messages: list[Message], output_model: type[T], generation_params:
# Fallback to disk cache
return self._load_from_disk(cache_key, output_model)

def set(self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any], result: T) -> None:
def set(
self,
messages: list[Message],
output_model: type[T],
generation_params: dict[str, Any],
result: T,
model_name: str,
base_url: str | None,
) -> None:
"""Cache the result.

Args:
messages: List of messages to send to the model.
output_model: Pydantic model class to parse the response into.
backend: Backend to use for structured output.
generation_params: Generation parameters.
result: The result to cache.
model_name: Name of the language model that will serve the request.
base_url: Base URL of the API endpoint, or None for the default.
"""
if not self.use_cache:
return

cache_key = self._get_cache_key(messages, output_model, generation_params)
cache_key = self._get_cache_key(messages, output_model, generation_params, model_name, base_url)

# Store in memory cache
self._memory_cache[cache_key] = result
Expand Down Expand Up @@ -304,22 +333,29 @@ async def _save_to_disk_async(self, cache_key: str, result: T) -> None:
await PydanticModelDumper.dump_async(result, cache_path, exists_ok=True)

async def get_async(
self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any]
self,
messages: list[Message],
output_model: type[T],
generation_params: dict[str, Any],
model_name: str,
base_url: str | None,
) -> T | None:
"""Get cached result if available (async version).

Args:
messages: List of messages to send to the model.
output_model: Pydantic model class to parse the response into.
generation_params: Generation parameters.
model_name: Name of the language model that will serve the request.
base_url: Base URL of the API endpoint, or None for the default.

Returns:
Cached result if available, None otherwise.
"""
if not self.use_cache:
return None

cache_key = self._get_cache_key(messages, output_model, generation_params)
cache_key = self._get_cache_key(messages, output_model, generation_params, model_name, base_url)

# First check in-memory cache
memory_result = self._check_memory_cache(cache_key, output_model)
Expand All @@ -330,21 +366,28 @@ async def get_async(
return await self._load_from_disk_async(cache_key, output_model)

async def set_async(
self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any], result: T
self,
messages: list[Message],
output_model: type[T],
generation_params: dict[str, Any],
result: T,
model_name: str,
base_url: str | None,
) -> None:
"""Cache the result (async version).

Args:
messages: List of messages to send to the model.
output_model: Pydantic model class to parse the response into.
backend: Backend to use for structured output.
generation_params: Generation parameters.
result: The result to cache.
model_name: Name of the language model that will serve the request.
base_url: Base URL of the API endpoint, or None for the default.
"""
if not self.use_cache:
return

cache_key = self._get_cache_key(messages, output_model, generation_params)
cache_key = self._get_cache_key(messages, output_model, generation_params, model_name, base_url)

# Store in memory cache
self._memory_cache[cache_key] = result
Expand Down
10 changes: 6 additions & 4 deletions src/autointent/generation/_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,9 @@ async def get_structured_output_async(
Parsed response as an instance of the provided Pydantic model.
"""
# Check cache first
cached_result = await self.cache.get_async(messages, output_model, self.generation_params)
cached_result = await self.cache.get_async(
messages, output_model, self.generation_params, self.model_name, self.base_url
)
if cached_result is not None:
return cached_result

Expand Down Expand Up @@ -292,7 +294,7 @@ async def get_structured_output_async(
raise RetriesExceededError(max_retries=max_retries, messages=current_messages)

# Cache the successful result
await self.cache.set_async(messages, output_model, self.generation_params, res)
await self.cache.set_async(messages, output_model, self.generation_params, res, self.model_name, self.base_url)

return res

Expand Down Expand Up @@ -350,7 +352,7 @@ def get_structured_output_sync(
Parsed response as an instance of the provided Pydantic model.
"""
# Check cache first
cached_result = self.cache.get(messages, output_model, self.generation_params)
cached_result = self.cache.get(messages, output_model, self.generation_params, self.model_name, self.base_url)
if cached_result is not None:
return cached_result

Expand All @@ -376,7 +378,7 @@ def get_structured_output_sync(
raise RetriesExceededError(max_retries=max_retries, messages=current_messages)

# Cache the successful result
self.cache.set(messages, output_model, self.generation_params, res)
self.cache.set(messages, output_model, self.generation_params, res, self.model_name, self.base_url)

return res

Expand Down
121 changes: 103 additions & 18 deletions tests/generation/structured_output/test_cache_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,31 +36,37 @@ def _isolated_cache(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr("autointent.generation._cache.user_cache_dir", lambda *_: str(tmp_path))


MODEL_NAME = "test-model"
BASE_URL: str | None = None


def test_set_get_memory_roundtrip() -> None:
cache = StructuredOutputCache(use_cache=True)
result = CacheModel(name="a", value=1)
cache.set(MESSAGES, CacheModel, PARAMS, result)
assert cache.get(MESSAGES, CacheModel, PARAMS) == result
cache.set(MESSAGES, CacheModel, PARAMS, result, MODEL_NAME, BASE_URL)
assert cache.get(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL) == result


def test_disabled_cache_is_noop() -> None:
cache = StructuredOutputCache(use_cache=False)
cache.set(MESSAGES, CacheModel, PARAMS, CacheModel(name="a", value=1))
assert cache.get(MESSAGES, CacheModel, PARAMS) is None
cache.set(MESSAGES, CacheModel, PARAMS, CacheModel(name="a", value=1), MODEL_NAME, BASE_URL)
assert cache.get(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL) is None


def test_get_misses_for_unknown_key() -> None:
cache = StructuredOutputCache(use_cache=True)
assert cache.get(MESSAGES, CacheModel, PARAMS) is None
assert cache.get(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL) is None


def test_get_loads_from_disk_in_fresh_instance() -> None:
"""A second instance has empty memory and must read the entry back from disk."""
StructuredOutputCache(use_cache=True).set(MESSAGES, CacheModel, PARAMS, CacheModel(name="x", value=9))
StructuredOutputCache(use_cache=True).set(
MESSAGES, CacheModel, PARAMS, CacheModel(name="x", value=9), MODEL_NAME, BASE_URL
)

fresh = StructuredOutputCache(use_cache=True)
fresh._memory_cache.clear() # force the disk path even if eager load changes
loaded = fresh.get(MESSAGES, CacheModel, PARAMS)
loaded = fresh.get(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL)
assert isinstance(loaded, CacheModel)
assert loaded.value == 9
# disk hit populates the memory cache for next time
Expand All @@ -69,7 +75,7 @@ def test_get_loads_from_disk_in_fresh_instance() -> None:

def test_memory_type_mismatch_evicts() -> None:
cache = StructuredOutputCache(use_cache=True)
key = cache._get_cache_key(MESSAGES, CacheModel, PARAMS)
key = cache._get_cache_key(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL)
cache._memory_cache[key] = OtherModel(text="wrong")
assert cache._check_memory_cache(key, CacheModel) is None
assert key not in cache._memory_cache
Expand All @@ -79,26 +85,28 @@ def test_memory_type_mismatch_evicts() -> None:
async def test_async_set_get_roundtrip() -> None:
cache = StructuredOutputCache(use_cache=True)
result = CacheModel(name="async", value=7)
await cache.set_async(MESSAGES, CacheModel, PARAMS, result)
assert await cache.get_async(MESSAGES, CacheModel, PARAMS) == result
await cache.set_async(MESSAGES, CacheModel, PARAMS, result, MODEL_NAME, BASE_URL)
assert await cache.get_async(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL) == result


@pytest.mark.asyncio
async def test_async_get_loads_from_disk() -> None:
await StructuredOutputCache(use_cache=True).set_async(MESSAGES, CacheModel, PARAMS, CacheModel(name="d", value=3))
await StructuredOutputCache(use_cache=True).set_async(
MESSAGES, CacheModel, PARAMS, CacheModel(name="d", value=3), MODEL_NAME, BASE_URL
)

fresh = StructuredOutputCache(use_cache=True)
fresh._memory_cache.clear()
loaded = await fresh.get_async(MESSAGES, CacheModel, PARAMS)
loaded = await fresh.get_async(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL)
assert isinstance(loaded, CacheModel)
assert loaded.value == 3


@pytest.mark.asyncio
async def test_async_disabled_cache_is_noop() -> None:
cache = StructuredOutputCache(use_cache=False)
await cache.set_async(MESSAGES, CacheModel, PARAMS, CacheModel(name="a", value=1))
assert await cache.get_async(MESSAGES, CacheModel, PARAMS) is None
await cache.set_async(MESSAGES, CacheModel, PARAMS, CacheModel(name="a", value=1), MODEL_NAME, BASE_URL)
assert await cache.get_async(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL) is None


# --- Regression tests for the on-disk-cache bugs (#326 eager load, #327 eviction) ---
Expand All @@ -109,10 +117,12 @@ async def test_async_disabled_cache_is_noop() -> None:

def test_eager_load_populates_memory_from_disk() -> None:
"""A fresh instance eagerly batch-loads existing on-disk entries into memory (#326)."""
StructuredOutputCache(use_cache=True).set(MESSAGES, CacheModel, PARAMS, CacheModel(name="x", value=9))
StructuredOutputCache(use_cache=True).set(
MESSAGES, CacheModel, PARAMS, CacheModel(name="x", value=9), MODEL_NAME, BASE_URL
)

fresh = StructuredOutputCache(use_cache=True)
key = fresh._get_cache_key(MESSAGES, CacheModel, PARAMS)
key = fresh._get_cache_key(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL)

# populated at construction by the eager load, before any get() call
assert key in fresh._memory_cache
Expand All @@ -137,7 +147,7 @@ def test_disk_type_mismatch_evicts_entry() -> None:
"""A type-mismatched disk entry is evicted (rmtree) instead of crashing on unlink (#327)."""
cache = StructuredOutputCache(use_cache=True)
# plant a CacheModel at the key the cache derives for OtherModel inputs
key = cache._get_cache_key(MESSAGES, OtherModel, PARAMS)
key = cache._get_cache_key(MESSAGES, OtherModel, PARAMS, MODEL_NAME, BASE_URL)
cache._save_to_disk(key, CacheModel(name="x", value=1))
cache._memory_cache.clear()

Expand All @@ -149,9 +159,84 @@ def test_disk_type_mismatch_evicts_entry() -> None:
async def test_async_disk_type_mismatch_evicts_entry() -> None:
"""Async type-mismatched disk entry is evicted (rmtree) instead of crashing on unlink (#327)."""
cache = StructuredOutputCache(use_cache=True)
key = cache._get_cache_key(MESSAGES, OtherModel, PARAMS)
key = cache._get_cache_key(MESSAGES, OtherModel, PARAMS, MODEL_NAME, BASE_URL)
await cache._save_to_disk_async(key, CacheModel(name="x", value=1))
cache._memory_cache.clear()

assert await cache._load_from_disk_async(key, OtherModel) is None
assert not _get_structured_output_cache_path(key).exists()


# --- Regression tests for model-identity cache collision (#334) ---


def test_different_model_names_do_not_collide() -> None:
"""Two generators with different model_name must NOT share a cache entry (#334)."""
result_a = CacheModel(name="from-model-a", value=1)
result_b = CacheModel(name="from-model-b", value=2)

cache = StructuredOutputCache(use_cache=True)
cache.set(MESSAGES, CacheModel, PARAMS, result_a, model_name="model-a", base_url=None)
cache.set(MESSAGES, CacheModel, PARAMS, result_b, model_name="model-b", base_url=None)

hit_a = cache.get(MESSAGES, CacheModel, PARAMS, model_name="model-a", base_url=None)
hit_b = cache.get(MESSAGES, CacheModel, PARAMS, model_name="model-b", base_url=None)

assert hit_a == result_a, "model-a should get its own cached value"
assert hit_b == result_b, "model-b must NOT get model-a's value"


def test_different_base_urls_do_not_collide() -> None:
"""Two generators with different base_url must NOT share a cache entry (#334)."""
result_x = CacheModel(name="from-url-x", value=10)
result_y = CacheModel(name="from-url-y", value=20)

cache = StructuredOutputCache(use_cache=True)
cache.set(MESSAGES, CacheModel, PARAMS, result_x, model_name="gpt-4o", base_url="http://host-x/v1")
cache.set(MESSAGES, CacheModel, PARAMS, result_y, model_name="gpt-4o", base_url="http://host-y/v1")

hit_x = cache.get(MESSAGES, CacheModel, PARAMS, model_name="gpt-4o", base_url="http://host-x/v1")
hit_y = cache.get(MESSAGES, CacheModel, PARAMS, model_name="gpt-4o", base_url="http://host-y/v1")

assert hit_x == result_x, "host-x should get its own cached value"
assert hit_y == result_y, "host-y must NOT get host-x's value"


def test_same_identity_still_hits_cache() -> None:
"""Same model_name + base_url + inputs must continue to yield a cache hit (#334)."""
result = CacheModel(name="same", value=42)

cache = StructuredOutputCache(use_cache=True)
cache.set(MESSAGES, CacheModel, PARAMS, result, model_name="gpt-4o", base_url="http://host/v1")

hit = cache.get(MESSAGES, CacheModel, PARAMS, model_name="gpt-4o", base_url="http://host/v1")
assert hit == result


@pytest.mark.asyncio
async def test_async_different_model_names_do_not_collide() -> None:
"""Async paths: two model names must NOT collide (#334)."""
result_a = CacheModel(name="async-a", value=1)
result_b = CacheModel(name="async-b", value=2)

cache = StructuredOutputCache(use_cache=True)
await cache.set_async(MESSAGES, CacheModel, PARAMS, result_a, model_name="async-model-a", base_url=None)
await cache.set_async(MESSAGES, CacheModel, PARAMS, result_b, model_name="async-model-b", base_url=None)

hit_a = await cache.get_async(MESSAGES, CacheModel, PARAMS, model_name="async-model-a", base_url=None)
hit_b = await cache.get_async(MESSAGES, CacheModel, PARAMS, model_name="async-model-b", base_url=None)

assert hit_a == result_a
assert hit_b == result_b


@pytest.mark.asyncio
async def test_async_same_identity_still_hits_cache() -> None:
"""Async paths: same identity must still yield a hit (#334)."""
result = CacheModel(name="async-same", value=99)

cache = StructuredOutputCache(use_cache=True)
await cache.set_async(MESSAGES, CacheModel, PARAMS, result, model_name="gpt-4o", base_url=None)

hit = await cache.get_async(MESSAGES, CacheModel, PARAMS, model_name="gpt-4o", base_url=None)
assert hit == result
Loading