diff --git a/src/autointent/generation/_cache.py b/src/autointent/generation/_cache.py index 8a34eb07a..87ae93163 100644 --- a/src/autointent/generation/_cache.py +++ b/src/autointent/generation/_cache.py @@ -138,13 +138,22 @@ def _load_single_cache_file(self, cache_file: Path) -> tuple[str, BaseModel] | N return None - def _get_cache_key(self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any]) -> str: + def _get_cache_key( + self, + messages: list[Message], + output_model: type[T], + generation_params: dict[str, Any], + model_name: str, + base_url: str | None, + ) -> str: """Generate a cache key for the given parameters. Args: messages: List of messages to send to the model. output_model: Pydantic model class to parse the response into. generation_params: Generation parameters. + model_name: Name of the language model that will serve the request. + base_url: Base URL of the API endpoint, or None for the default. Returns: Cache key as a hexadecimal string. @@ -153,6 +162,8 @@ def _get_cache_key(self, messages: list[Message], output_model: type[T], generat hasher.update(json.dumps(messages)) hasher.update(json.dumps(output_model.model_json_schema())) hasher.update(json.dumps(generation_params)) + hasher.update(model_name) + hasher.update(base_url) return hasher.hexdigest() def _check_memory_cache(self, cache_key: str, output_model: type[T]) -> T | None: @@ -216,13 +227,22 @@ def _save_to_disk(self, cache_key: str, result: T) -> None: cache_path.parent.mkdir(parents=True, exist_ok=True) PydanticModelDumper.dump(result, cache_path, exists_ok=True) - def get(self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any]) -> T | None: + def get( + self, + messages: list[Message], + output_model: type[T], + generation_params: dict[str, Any], + model_name: str, + base_url: str | None, + ) -> T | None: """Get cached result if available. Args: messages: List of messages to send to the model. output_model: Pydantic model class to parse the response into. generation_params: Generation parameters. + model_name: Name of the language model that will serve the request. + base_url: Base URL of the API endpoint, or None for the default. Returns: Cached result if available, None otherwise. @@ -230,7 +250,7 @@ def get(self, messages: list[Message], output_model: type[T], generation_params: if not self.use_cache: return None - cache_key = self._get_cache_key(messages, output_model, generation_params) + cache_key = self._get_cache_key(messages, output_model, generation_params, model_name, base_url) # First check in-memory cache memory_result = self._check_memory_cache(cache_key, output_model) @@ -240,20 +260,29 @@ def get(self, messages: list[Message], output_model: type[T], generation_params: # Fallback to disk cache return self._load_from_disk(cache_key, output_model) - def set(self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any], result: T) -> None: + def set( + self, + messages: list[Message], + output_model: type[T], + generation_params: dict[str, Any], + result: T, + model_name: str, + base_url: str | None, + ) -> None: """Cache the result. Args: messages: List of messages to send to the model. output_model: Pydantic model class to parse the response into. - backend: Backend to use for structured output. generation_params: Generation parameters. result: The result to cache. + model_name: Name of the language model that will serve the request. + base_url: Base URL of the API endpoint, or None for the default. """ if not self.use_cache: return - cache_key = self._get_cache_key(messages, output_model, generation_params) + cache_key = self._get_cache_key(messages, output_model, generation_params, model_name, base_url) # Store in memory cache self._memory_cache[cache_key] = result @@ -304,7 +333,12 @@ async def _save_to_disk_async(self, cache_key: str, result: T) -> None: await PydanticModelDumper.dump_async(result, cache_path, exists_ok=True) async def get_async( - self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any] + self, + messages: list[Message], + output_model: type[T], + generation_params: dict[str, Any], + model_name: str, + base_url: str | None, ) -> T | None: """Get cached result if available (async version). @@ -312,6 +346,8 @@ async def get_async( messages: List of messages to send to the model. output_model: Pydantic model class to parse the response into. generation_params: Generation parameters. + model_name: Name of the language model that will serve the request. + base_url: Base URL of the API endpoint, or None for the default. Returns: Cached result if available, None otherwise. @@ -319,7 +355,7 @@ async def get_async( if not self.use_cache: return None - cache_key = self._get_cache_key(messages, output_model, generation_params) + cache_key = self._get_cache_key(messages, output_model, generation_params, model_name, base_url) # First check in-memory cache memory_result = self._check_memory_cache(cache_key, output_model) @@ -330,21 +366,28 @@ async def get_async( return await self._load_from_disk_async(cache_key, output_model) async def set_async( - self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any], result: T + self, + messages: list[Message], + output_model: type[T], + generation_params: dict[str, Any], + result: T, + model_name: str, + base_url: str | None, ) -> None: """Cache the result (async version). Args: messages: List of messages to send to the model. output_model: Pydantic model class to parse the response into. - backend: Backend to use for structured output. generation_params: Generation parameters. result: The result to cache. + model_name: Name of the language model that will serve the request. + base_url: Base URL of the API endpoint, or None for the default. """ if not self.use_cache: return - cache_key = self._get_cache_key(messages, output_model, generation_params) + cache_key = self._get_cache_key(messages, output_model, generation_params, model_name, base_url) # Store in memory cache self._memory_cache[cache_key] = result diff --git a/src/autointent/generation/_generator.py b/src/autointent/generation/_generator.py index c4c626ad6..3a919c656 100644 --- a/src/autointent/generation/_generator.py +++ b/src/autointent/generation/_generator.py @@ -263,7 +263,9 @@ async def get_structured_output_async( Parsed response as an instance of the provided Pydantic model. """ # Check cache first - cached_result = await self.cache.get_async(messages, output_model, self.generation_params) + cached_result = await self.cache.get_async( + messages, output_model, self.generation_params, self.model_name, self.base_url + ) if cached_result is not None: return cached_result @@ -292,7 +294,7 @@ async def get_structured_output_async( raise RetriesExceededError(max_retries=max_retries, messages=current_messages) # Cache the successful result - await self.cache.set_async(messages, output_model, self.generation_params, res) + await self.cache.set_async(messages, output_model, self.generation_params, res, self.model_name, self.base_url) return res @@ -350,7 +352,7 @@ def get_structured_output_sync( Parsed response as an instance of the provided Pydantic model. """ # Check cache first - cached_result = self.cache.get(messages, output_model, self.generation_params) + cached_result = self.cache.get(messages, output_model, self.generation_params, self.model_name, self.base_url) if cached_result is not None: return cached_result @@ -376,7 +378,7 @@ def get_structured_output_sync( raise RetriesExceededError(max_retries=max_retries, messages=current_messages) # Cache the successful result - self.cache.set(messages, output_model, self.generation_params, res) + self.cache.set(messages, output_model, self.generation_params, res, self.model_name, self.base_url) return res diff --git a/tests/generation/structured_output/test_cache_unit.py b/tests/generation/structured_output/test_cache_unit.py index e526648b1..5c510fa4b 100644 --- a/tests/generation/structured_output/test_cache_unit.py +++ b/tests/generation/structured_output/test_cache_unit.py @@ -36,31 +36,37 @@ def _isolated_cache(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr("autointent.generation._cache.user_cache_dir", lambda *_: str(tmp_path)) +MODEL_NAME = "test-model" +BASE_URL: str | None = None + + def test_set_get_memory_roundtrip() -> None: cache = StructuredOutputCache(use_cache=True) result = CacheModel(name="a", value=1) - cache.set(MESSAGES, CacheModel, PARAMS, result) - assert cache.get(MESSAGES, CacheModel, PARAMS) == result + cache.set(MESSAGES, CacheModel, PARAMS, result, MODEL_NAME, BASE_URL) + assert cache.get(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL) == result def test_disabled_cache_is_noop() -> None: cache = StructuredOutputCache(use_cache=False) - cache.set(MESSAGES, CacheModel, PARAMS, CacheModel(name="a", value=1)) - assert cache.get(MESSAGES, CacheModel, PARAMS) is None + cache.set(MESSAGES, CacheModel, PARAMS, CacheModel(name="a", value=1), MODEL_NAME, BASE_URL) + assert cache.get(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL) is None def test_get_misses_for_unknown_key() -> None: cache = StructuredOutputCache(use_cache=True) - assert cache.get(MESSAGES, CacheModel, PARAMS) is None + assert cache.get(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL) is None def test_get_loads_from_disk_in_fresh_instance() -> None: """A second instance has empty memory and must read the entry back from disk.""" - StructuredOutputCache(use_cache=True).set(MESSAGES, CacheModel, PARAMS, CacheModel(name="x", value=9)) + StructuredOutputCache(use_cache=True).set( + MESSAGES, CacheModel, PARAMS, CacheModel(name="x", value=9), MODEL_NAME, BASE_URL + ) fresh = StructuredOutputCache(use_cache=True) fresh._memory_cache.clear() # force the disk path even if eager load changes - loaded = fresh.get(MESSAGES, CacheModel, PARAMS) + loaded = fresh.get(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL) assert isinstance(loaded, CacheModel) assert loaded.value == 9 # disk hit populates the memory cache for next time @@ -69,7 +75,7 @@ def test_get_loads_from_disk_in_fresh_instance() -> None: def test_memory_type_mismatch_evicts() -> None: cache = StructuredOutputCache(use_cache=True) - key = cache._get_cache_key(MESSAGES, CacheModel, PARAMS) + key = cache._get_cache_key(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL) cache._memory_cache[key] = OtherModel(text="wrong") assert cache._check_memory_cache(key, CacheModel) is None assert key not in cache._memory_cache @@ -79,17 +85,19 @@ def test_memory_type_mismatch_evicts() -> None: async def test_async_set_get_roundtrip() -> None: cache = StructuredOutputCache(use_cache=True) result = CacheModel(name="async", value=7) - await cache.set_async(MESSAGES, CacheModel, PARAMS, result) - assert await cache.get_async(MESSAGES, CacheModel, PARAMS) == result + await cache.set_async(MESSAGES, CacheModel, PARAMS, result, MODEL_NAME, BASE_URL) + assert await cache.get_async(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL) == result @pytest.mark.asyncio async def test_async_get_loads_from_disk() -> None: - await StructuredOutputCache(use_cache=True).set_async(MESSAGES, CacheModel, PARAMS, CacheModel(name="d", value=3)) + await StructuredOutputCache(use_cache=True).set_async( + MESSAGES, CacheModel, PARAMS, CacheModel(name="d", value=3), MODEL_NAME, BASE_URL + ) fresh = StructuredOutputCache(use_cache=True) fresh._memory_cache.clear() - loaded = await fresh.get_async(MESSAGES, CacheModel, PARAMS) + loaded = await fresh.get_async(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL) assert isinstance(loaded, CacheModel) assert loaded.value == 3 @@ -97,8 +105,8 @@ async def test_async_get_loads_from_disk() -> None: @pytest.mark.asyncio async def test_async_disabled_cache_is_noop() -> None: cache = StructuredOutputCache(use_cache=False) - await cache.set_async(MESSAGES, CacheModel, PARAMS, CacheModel(name="a", value=1)) - assert await cache.get_async(MESSAGES, CacheModel, PARAMS) is None + await cache.set_async(MESSAGES, CacheModel, PARAMS, CacheModel(name="a", value=1), MODEL_NAME, BASE_URL) + assert await cache.get_async(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL) is None # --- Regression tests for the on-disk-cache bugs (#326 eager load, #327 eviction) --- @@ -109,10 +117,12 @@ async def test_async_disabled_cache_is_noop() -> None: def test_eager_load_populates_memory_from_disk() -> None: """A fresh instance eagerly batch-loads existing on-disk entries into memory (#326).""" - StructuredOutputCache(use_cache=True).set(MESSAGES, CacheModel, PARAMS, CacheModel(name="x", value=9)) + StructuredOutputCache(use_cache=True).set( + MESSAGES, CacheModel, PARAMS, CacheModel(name="x", value=9), MODEL_NAME, BASE_URL + ) fresh = StructuredOutputCache(use_cache=True) - key = fresh._get_cache_key(MESSAGES, CacheModel, PARAMS) + key = fresh._get_cache_key(MESSAGES, CacheModel, PARAMS, MODEL_NAME, BASE_URL) # populated at construction by the eager load, before any get() call assert key in fresh._memory_cache @@ -137,7 +147,7 @@ def test_disk_type_mismatch_evicts_entry() -> None: """A type-mismatched disk entry is evicted (rmtree) instead of crashing on unlink (#327).""" cache = StructuredOutputCache(use_cache=True) # plant a CacheModel at the key the cache derives for OtherModel inputs - key = cache._get_cache_key(MESSAGES, OtherModel, PARAMS) + key = cache._get_cache_key(MESSAGES, OtherModel, PARAMS, MODEL_NAME, BASE_URL) cache._save_to_disk(key, CacheModel(name="x", value=1)) cache._memory_cache.clear() @@ -149,9 +159,84 @@ def test_disk_type_mismatch_evicts_entry() -> None: async def test_async_disk_type_mismatch_evicts_entry() -> None: """Async type-mismatched disk entry is evicted (rmtree) instead of crashing on unlink (#327).""" cache = StructuredOutputCache(use_cache=True) - key = cache._get_cache_key(MESSAGES, OtherModel, PARAMS) + key = cache._get_cache_key(MESSAGES, OtherModel, PARAMS, MODEL_NAME, BASE_URL) await cache._save_to_disk_async(key, CacheModel(name="x", value=1)) cache._memory_cache.clear() assert await cache._load_from_disk_async(key, OtherModel) is None assert not _get_structured_output_cache_path(key).exists() + + +# --- Regression tests for model-identity cache collision (#334) --- + + +def test_different_model_names_do_not_collide() -> None: + """Two generators with different model_name must NOT share a cache entry (#334).""" + result_a = CacheModel(name="from-model-a", value=1) + result_b = CacheModel(name="from-model-b", value=2) + + cache = StructuredOutputCache(use_cache=True) + cache.set(MESSAGES, CacheModel, PARAMS, result_a, model_name="model-a", base_url=None) + cache.set(MESSAGES, CacheModel, PARAMS, result_b, model_name="model-b", base_url=None) + + hit_a = cache.get(MESSAGES, CacheModel, PARAMS, model_name="model-a", base_url=None) + hit_b = cache.get(MESSAGES, CacheModel, PARAMS, model_name="model-b", base_url=None) + + assert hit_a == result_a, "model-a should get its own cached value" + assert hit_b == result_b, "model-b must NOT get model-a's value" + + +def test_different_base_urls_do_not_collide() -> None: + """Two generators with different base_url must NOT share a cache entry (#334).""" + result_x = CacheModel(name="from-url-x", value=10) + result_y = CacheModel(name="from-url-y", value=20) + + cache = StructuredOutputCache(use_cache=True) + cache.set(MESSAGES, CacheModel, PARAMS, result_x, model_name="gpt-4o", base_url="http://host-x/v1") + cache.set(MESSAGES, CacheModel, PARAMS, result_y, model_name="gpt-4o", base_url="http://host-y/v1") + + hit_x = cache.get(MESSAGES, CacheModel, PARAMS, model_name="gpt-4o", base_url="http://host-x/v1") + hit_y = cache.get(MESSAGES, CacheModel, PARAMS, model_name="gpt-4o", base_url="http://host-y/v1") + + assert hit_x == result_x, "host-x should get its own cached value" + assert hit_y == result_y, "host-y must NOT get host-x's value" + + +def test_same_identity_still_hits_cache() -> None: + """Same model_name + base_url + inputs must continue to yield a cache hit (#334).""" + result = CacheModel(name="same", value=42) + + cache = StructuredOutputCache(use_cache=True) + cache.set(MESSAGES, CacheModel, PARAMS, result, model_name="gpt-4o", base_url="http://host/v1") + + hit = cache.get(MESSAGES, CacheModel, PARAMS, model_name="gpt-4o", base_url="http://host/v1") + assert hit == result + + +@pytest.mark.asyncio +async def test_async_different_model_names_do_not_collide() -> None: + """Async paths: two model names must NOT collide (#334).""" + result_a = CacheModel(name="async-a", value=1) + result_b = CacheModel(name="async-b", value=2) + + cache = StructuredOutputCache(use_cache=True) + await cache.set_async(MESSAGES, CacheModel, PARAMS, result_a, model_name="async-model-a", base_url=None) + await cache.set_async(MESSAGES, CacheModel, PARAMS, result_b, model_name="async-model-b", base_url=None) + + hit_a = await cache.get_async(MESSAGES, CacheModel, PARAMS, model_name="async-model-a", base_url=None) + hit_b = await cache.get_async(MESSAGES, CacheModel, PARAMS, model_name="async-model-b", base_url=None) + + assert hit_a == result_a + assert hit_b == result_b + + +@pytest.mark.asyncio +async def test_async_same_identity_still_hits_cache() -> None: + """Async paths: same identity must still yield a hit (#334).""" + result = CacheModel(name="async-same", value=99) + + cache = StructuredOutputCache(use_cache=True) + await cache.set_async(MESSAGES, CacheModel, PARAMS, result, model_name="gpt-4o", base_url=None) + + hit = await cache.get_async(MESSAGES, CacheModel, PARAMS, model_name="gpt-4o", base_url=None) + assert hit == result diff --git a/tests/generation/structured_output/test_caching.py b/tests/generation/structured_output/test_caching.py index a77d85316..a5a12e718 100644 --- a/tests/generation/structured_output/test_caching.py +++ b/tests/generation/structured_output/test_caching.py @@ -96,6 +96,8 @@ async def test_cache_hit( messages=messages, output_model=SimpleModel, generation_params=generator_with_cache.generation_params, + model_name=generator_with_cache.model_name, + base_url=generator_with_cache.base_url, ) assert isinstance(cached_res, SimpleModel) assert cached_res.name == result1.name