Skip to content

Commit 80fff83

Browse files
authored
feat(celery Wave 6 #37): wire EmbeddingService.embed_image into application cache (#1735)
Wave 5 P2 chunk 4 (#1733) landed the canonical multimodal embedding API surface (`EmbeddingService.embed_image(image_bytes, alt_text)`) but left the call un-cached — every vision modality embed now hits the LiteLLM provider, even for the same image. This wires it into the canonical `aperag.cache.NAMESPACE_EMBEDDING` infra (PR #1734), mirroring the existing text `_embed_batch` pattern.

Cache key shape (per the `aperag/cache/README.md` no-raw-bytes policy):

    {
        "kind": "image",
        "provider": ...,
        "model": ...,
        "api_base": ...,
        "api_key_hash": sha256(api_key),
        "file_hash": sha256(image_bytes),
        "alt_text": ...,
        "multimodal": True,
    }

Image bytes are identified by their sha256 hex digest so the Redis key stays bounded; alt_text is part of the key because providers that accept paired text+image inputs return a different vector when the textual hint changes (alt_text="" collapses to one key for image-only callers).

Tests
-----
New `tests/unit_test/llm/test_embed_image_cache.py` (7 tests):

* identical (bytes, alt_text) → second call hits cache (no upstream)
* same bytes + different alt_text → distinct keys, both compute
* different bytes + same alt_text → distinct keys, both compute
* key shape uses sha256 file_hash, raw bytes never appear in key
* `caching=False` bypasses cache (always upstream)
* `multimodal=False` raises EmbeddingError (defense-in-depth)
* empty image_bytes raises EmptyTextError

Full unit suite: 1022 passed, 29 skipped, ruff + format clean.

Out of scope (per task #37 boundary)
------------------------------------
Provider-specific multimodal embedder format variations (Voyage / Jina v3 / OpenAI multimodal SDK input shapes) stay on task #39 per PM dispatch + simple-stable directive (`feedback_simple_stable_zero_maintenance.md`).
1 parent 7c9e4c3 commit 80fff83

2 files changed

Lines changed: 237 additions & 7 deletions

File tree

aperag/llm/embed/embedding_service.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -199,13 +199,26 @@ def embed_image(self, image_bytes: bytes, alt_text: str = "") -> List[float]:
199199
)
200200
if not image_bytes:
201201
raise EmptyTextError(1)
202-
try:
203-
return self._embed_image_via_litellm(image_bytes=image_bytes, alt_text=alt_text)
204-
except (EmptyTextError, EmbeddingError):
205-
raise
206-
except Exception as e:
207-
logger.error(f"Image embedding failed: {str(e)}")
208-
raise wrap_litellm_error(e, "embedding", self.embedding_provider, self.model) from e
202+
203+
def _compute() -> List[float]:
204+
try:
205+
return self._embed_image_via_litellm(image_bytes=image_bytes, alt_text=alt_text)
206+
except (EmptyTextError, EmbeddingError):
207+
raise
208+
except Exception as e:
209+
logger.error(f"Image embedding failed: {str(e)}")
210+
raise wrap_litellm_error(e, "embedding", self.embedding_provider, self.model) from e
211+
212+
if not self.caching:
213+
return _compute()
214+
215+
cache = get_sync_application_cache()
216+
return cache.get_or_compute(
217+
namespace=NAMESPACE_EMBEDDING,
218+
key_data=self._cache_key_for_image(image_bytes, alt_text),
219+
compute=_compute,
220+
policy=application_cache_policy(NAMESPACE_EMBEDDING),
221+
)
209222

210223
async def aembed_image(self, image_bytes: bytes, alt_text: str = "") -> List[float]:
    """Async counterpart of ``embed_image``.

    Hands the synchronous ``self.embed_image`` call to a worker thread via
    ``asyncio.to_thread`` and returns whatever that call returns.

    Args:
        image_bytes: Raw image content, forwarded unchanged.
        alt_text: Textual hint paired with the image, forwarded unchanged.

    Returns:
        The embedding vector produced by ``embed_image``.
    """
    vector = await asyncio.to_thread(self.embed_image, image_bytes, alt_text)
    return vector
@@ -352,3 +365,25 @@ def _cache_key_for_input(self, text: str) -> dict:
352365
"multimodal": self.multimodal,
353366
"encoding_format": "float",
354367
}
368+
369+
def _cache_key_for_image(self, image_bytes: bytes, alt_text: str) -> dict:
    """Build the cache key payload for one ``embed_image`` call (Wave 6 task #37).

    Per the cache README the namespaced key never embeds raw bytes, so the
    image is represented only by the hex SHA-256 digest of ``image_bytes``.
    The API key is likewise represented by the digest of its UTF-8 encoding
    (a missing key is hashed as the empty string). ``alt_text`` takes part
    in the key because providers that accept paired text+image inputs
    return a different vector when the textual hint changes; a falsy
    ``alt_text`` is stored as ``""`` so image-only callers share one key.

    Args:
        image_bytes: Raw image content to fingerprint.
        alt_text: Textual hint paired with the image.

    Returns:
        A dict with the keys ``kind``, ``provider``, ``model``, ``api_base``,
        ``api_key_hash``, ``file_hash``, ``alt_text`` and ``multimodal``.
    """
    provider = self.embedding_provider
    model = self.model
    api_base = self.api_base

    credential = (self.api_key or "").encode("utf-8")
    credential_digest = hashlib.sha256(credential).hexdigest()
    image_digest = hashlib.sha256(image_bytes).hexdigest()
    hint = alt_text if alt_text else ""

    key_data = {}
    key_data["kind"] = "image"
    key_data["provider"] = provider
    key_data["model"] = model
    key_data["api_base"] = api_base
    key_data["api_key_hash"] = credential_digest
    key_data["file_hash"] = image_digest
    key_data["alt_text"] = hint
    key_data["multimodal"] = True
    return key_data
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
# Copyright 2025 ApeCloud, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for Wave 6 task #37: ``EmbeddingService.embed_image`` cache wiring.
16+
17+
Pins the contract that ``embed_image`` honours the canonical
18+
application cache (``NAMESPACE_EMBEDDING``):
19+
20+
* identical ``(image_bytes, alt_text)`` calls hit the cache and skip the
21+
upstream LiteLLM call,
22+
* the cache key shape is ``sha256(image_bytes)``-based, never embedding
23+
raw bytes (per `aperag/cache/README.md` policy),
24+
* changing ``alt_text`` (or the image bytes) yields a distinct key,
25+
* ``caching=False`` bypasses the cache and always calls the upstream.
26+
"""
27+
28+
from __future__ import annotations
29+
30+
import hashlib
31+
32+
import pytest
33+
34+
from aperag.cache.application import ApplicationCachePolicy, SyncApplicationCache
35+
from aperag.cache.key import build_cache_key
36+
from aperag.llm.embed.embedding_service import EmbeddingService
37+
from aperag.llm.llm_error_types import EmbeddingError, EmptyTextError
38+
39+
40+
class _FakeSyncBackend:
    """Minimal in-memory backend matching ``SyncApplicationRedisCacheBackend``.

    Values live in a plain dict. ``get_calls`` and ``set_calls`` count how
    often ``get`` and ``set`` were invoked so tests can assert on cache
    traffic.
    """

    def __init__(self):
        self.store: dict[str, str | bytes] = {}
        self.get_calls = 0
        self.set_calls = 0

    def get(self, key: str):
        """Return the stored value for ``key`` (``None`` if absent) and count the call."""
        self.get_calls += 1
        return self.store.get(key)

    def mget(self, keys):
        """Return one value per key, in order; each lookup goes through ``get``."""
        return [self.get(key) for key in keys]

    def set(self, key: str, value: str, ttl_seconds: int) -> None:
        """Store ``value`` under ``key`` and count the call.

        ``ttl_seconds`` is accepted for signature compatibility but ignored:
        entries in this fake never expire.
        """
        self.set_calls += 1
        self.store[key] = value

    def delete(self, *keys: str) -> int:
        """Remove the given keys and return how many were actually present.

        Counting only real removals (rather than ``len(keys)``) follows the
        Redis ``DEL`` contract, so callers that inspect the return value see
        0 for keys that were never stored.
        """
        removed = 0
        for key in keys:
            if key in self.store:
                del self.store[key]
                removed += 1
        return removed
63+
64+
65+
def _make_cache() -> SyncApplicationCache:
    """Return a ``SyncApplicationCache`` backed by a fresh ``_FakeSyncBackend``.

    The default policy uses the ``"embedding"`` namespace with a 60-second
    TTL and a 4096-byte value limit.
    """
    in_memory_backend = _FakeSyncBackend()
    embedding_policy = ApplicationCachePolicy(
        namespace="embedding",
        ttl_seconds=60,
        max_value_bytes=4096,
    )
    return SyncApplicationCache(backend=in_memory_backend, default_policy=embedding_policy)
70+
71+
72+
def _make_service(*, caching: bool = True) -> EmbeddingService:
    """Build a multimodal ``EmbeddingService`` with fixed Jina test settings.

    Args:
        caching: Forwarded to the service's ``caching`` argument.
    """
    settings = {
        "embedding_provider": "jina_ai",
        "embedding_model": "jina-embeddings-v4",
        "embedding_service_url": "https://api.jina.ai/v1",
        "embedding_service_api_key": "sk-test",
        "embedding_max_chunks_in_batch": 1,
        "multimodal": True,
        "caching": caching,
    }
    return EmbeddingService(**settings)
82+
83+
84+
@pytest.fixture
def cache(monkeypatch):
    """Install an in-memory application cache for the embedding service.

    Replaces ``get_sync_application_cache`` in the embedding service module
    with a callable that returns the cache built by ``_make_cache``, and
    returns that same cache to the test.
    """
    in_memory_cache = _make_cache()

    def _provide_cache():
        return in_memory_cache

    monkeypatch.setattr("aperag.llm.embed.embedding_service.get_sync_application_cache", _provide_cache)
    return in_memory_cache
89+
90+
91+
def _stub_litellm(monkeypatch, calls: list[bytes]):
    """Replace ``EmbeddingService._embed_image_via_litellm`` with a recording stub.

    Each invocation of the stub appends the received ``image_bytes`` to
    ``calls`` and returns the fixed vector ``[0.1, 0.2, 0.3]``, so tests can
    count how many times the upstream path was taken.
    """

    def _record_and_return(self, *, image_bytes: bytes, alt_text: str):
        calls.append(image_bytes)
        return [0.1, 0.2, 0.3]

    monkeypatch.setattr(EmbeddingService, "_embed_image_via_litellm", _record_and_return)
97+
98+
99+
def test_embed_image_caches_identical_calls(cache, monkeypatch):
    """A repeated ``(image_bytes, alt_text)`` pair is answered without a second compute."""
    upstream_calls: list[bytes] = []
    _stub_litellm(monkeypatch, upstream_calls)
    svc = _make_service()
    payload = b"\x89PNG\r\n\x1a\nfake-bytes"

    results = [svc.embed_image(payload, alt_text="cat") for _ in range(2)]

    assert results[0] == results[1] == [0.1, 0.2, 0.3]
    assert len(upstream_calls) == 1, "second call should hit cache, not invoke LiteLLM"
112+
113+
114+
def test_embed_image_distinct_alt_text_yields_distinct_keys(cache, monkeypatch):
    """The same image with two different alt texts must be computed twice."""
    upstream_calls: list[bytes] = []
    _stub_litellm(monkeypatch, upstream_calls)
    svc = _make_service()
    payload = b"\x89PNG\r\n\x1a\nfake-bytes"

    for hint in ("cat", "dog"):
        svc.embed_image(payload, alt_text=hint)

    assert len(upstream_calls) == 2, "alt_text change must miss the cache and re-compute"
127+
128+
129+
def test_embed_image_distinct_bytes_yield_distinct_keys(cache, monkeypatch):
    """Two different images sharing one alt text must be computed twice."""
    upstream_calls: list[bytes] = []
    _stub_litellm(monkeypatch, upstream_calls)
    svc = _make_service()

    for payload in (b"\x89PNG\r\n\x1a\nimage-A", b"\x89PNG\r\n\x1a\nimage-B"):
        svc.embed_image(payload, alt_text="x")

    assert len(upstream_calls) == 2, "different image bytes must produce a different cache key"
141+
142+
143+
def test_embed_image_cache_key_uses_sha256_not_raw_bytes(monkeypatch):
    """The image is fingerprinted by its sha256 hex digest in the key data.

    Per `aperag/cache/README.md` raw bytes must never appear in the Redis
    key, so the key built from the key data is checked for the payload text.
    """
    svc = _make_service()
    payload = b"sensitive-binary-payload"
    expected_digest = hashlib.sha256(payload).hexdigest()

    key_data = svc._cache_key_for_image(payload, alt_text="hint")

    assert key_data["file_hash"] == expected_digest
    assert key_data["kind"] == "image"
    assert key_data["multimodal"] is True

    # Build the actual Redis key and confirm the payload text is not in it.
    redis_key = build_cache_key("embedding", key_data)
    assert "sensitive-binary-payload" not in redis_key
    assert redis_key.startswith("aperag:cache:v1:embedding:")
159+
160+
161+
def test_embed_image_cache_disabled_always_calls_upstream(monkeypatch):
    """With ``caching=False`` two identical calls must both reach the upstream stub."""
    upstream_calls: list[bytes] = []
    _stub_litellm(monkeypatch, upstream_calls)
    svc = _make_service(caching=False)

    for _ in range(2):
        svc.embed_image(b"img", alt_text="x")

    assert len(upstream_calls) == 2, "caching=False must skip the cache and recompute every call"
173+
174+
175+
def test_embed_image_rejects_non_multimodal():
    """Defense-in-depth: a ``multimodal=False`` service must raise ``EmbeddingError``.

    No cache fixture or LiteLLM stub is installed for this test.
    """
    text_only_settings = {
        "embedding_provider": "openai",
        "embedding_model": "text-embedding-3-small",
        "embedding_service_url": "https://api.openai.com/v1",
        "embedding_service_api_key": "sk-test",
        "embedding_max_chunks_in_batch": 1,
        "multimodal": False,
        "caching": True,
    }
    svc = EmbeddingService(**text_only_settings)

    with pytest.raises(EmbeddingError):
        svc.embed_image(b"any", alt_text="x")
190+
191+
192+
def test_embed_image_rejects_empty_bytes():
    """Empty ``image_bytes`` must raise ``EmptyTextError``."""
    svc = _make_service()
    empty_payload = b""

    with pytest.raises(EmptyTextError):
        svc.embed_image(empty_payload, alt_text="x")

0 commit comments

Comments
 (0)