Skip to content

Commit ffc80e2

Browse files
committed
Abstract VLM providers into an extensible Strategy Pattern
1 parent cd74ff7 commit ffc80e2

12 files changed

Lines changed: 432 additions & 14 deletions

File tree

.env.example

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,20 @@ HF_TOKEN=your_huggingface_token_here
208208
# Optional — defaults to 12
209209
# GRAPH_MAX_RELATIONSHIPS=12
210210

211+
# ── Vision / Image Captioning (VLM Providers) ──────────────
212+
# Set VISION_PROVIDER to one of: openai | anthropic | gemini | ollama
213+
# Leave unset to skip VLM captioning; captions then fall back to local OCR, then a placeholder.
214+
# VISION_PROVIDER=openai
215+
216+
# VISION_MODEL=gpt-4o-mini # openai default
217+
# VISION_MODEL=claude-3-haiku-20240307 # anthropic default
218+
# VISION_MODEL=gemini-1.5-flash # gemini default
219+
# VISION_MODEL=llava # ollama default
220+
221+
# OPENAI_API_KEY=sk-...
222+
# ANTHROPIC_API_KEY=sk-ant-...
223+
# GOOGLE_API_KEY=AIza...
224+
# OLLAMA_BASE_URL=http://localhost:11434
211225

212226
# ── ChromaDB (Vector Store) ─────────────────────────────────
213227

backend/app/config.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,15 @@ class Settings(BaseSettings):
139139
# ── Reranker ─────────────────────────────────────────
140140
RERANKER_MODEL: str = "BAAI/bge-reranker-v2-m3" # Lightweight multilingual cross-encoder reranker fine-tuned for relevance ranking
141141
# ── Vision / Image captioning ─────────────────────
142-
VISION_PROVIDER: str | None = None # e.g. 'openai'
143-
VISION_MODEL: str | None = None
142+
# Set to: openai | anthropic | gemini | ollama (or leave None)
143+
VISION_PROVIDER: str | None = None
144+
VISION_MODEL: str | None = None # overrides provider default model
145+
146+
# Provider credentials / endpoints — only the active provider's entry is required (Ollama needs no API key, only a base URL)
144147
OPENAI_API_KEY: str = ""
148+
ANTHROPIC_API_KEY: str = ""
149+
GOOGLE_API_KEY: str = ""
150+
OLLAMA_BASE_URL: str = "http://localhost:11434"
145151

146152
# ── Workspace Invitation ─────────────────────────
147153
APP_URL: str = "http://localhost:3000"

backend/app/rag/vision.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
"""
1212
import base64
1313
import logging
14+
import app.vision.providers # noqa: F401 — triggers self-registration
15+
from app.vision.registry import get_vision_provider
1416
from io import BytesIO
1517
from typing import Any, Dict, List, Optional
1618

@@ -187,24 +189,45 @@ def _openai_caption(image_bytes: bytes) -> str:
187189

188190
# ── Public API ───────────────────────────────────────────────────────────────
189191

190-
def caption_image(image_bytes: bytes, page: Optional[int] = None) -> str:
191-
"""Generate a caption for a single image (bytes).
192-
193-
Resolution order: OpenAI (if configured) → OCR → placeholder.
194-
"""
195-
def caption_image(image_bytes: bytes | List[bytes], page: int | List[int] | None = None) -> str | List[str]:
192+
def caption_image(
193+
image_bytes: "bytes | List[bytes]",
194+
page: "int | List[int] | None" = None,
195+
) -> "str | List[str]":
196196
"""Generate a caption for a single image or a batch of images.
197197
198-
Order of operations:
199-
- If a list of image bytes is passed, returns a list of captions.
200-
- If an external VLM provider is configured, attempt to call it.
201-
- Fall back to local OCR (pytesseract) if available.
202-
- Otherwise return a simple placeholder caption including the page number.
198+
Resolution order:
199+
1. Configured VLM provider (set VISION_PROVIDER in .env)
200+
2. Local OCR via pytesseract
201+
3. Placeholder string with page number and dimensions
203202
"""
204203
if isinstance(image_bytes, list):
205-
pages = page if isinstance(page, list) else ([page] * len(image_bytes) if page is not None else [None] * len(image_bytes))
204+
pages = (
205+
page if isinstance(page, list)
206+
else ([page] * len(image_bytes) if page is not None else [None] * len(image_bytes))
207+
)
206208
return [caption_image(img, pg) for img, pg in zip(image_bytes, pages)]
207209

210+
# Strategy: try the configured VLM provider
211+
provider = get_vision_provider(getattr(settings, "VISION_PROVIDER", None))
212+
if provider is not None:
213+
result = provider.caption(image_bytes)
214+
if result:
215+
return result
216+
217+
# Fallback 1: local OCR
218+
ocr = _ocr_caption(image_bytes)
219+
if ocr:
220+
return ocr
221+
222+
# Fallback 2: placeholder
223+
try:
224+
pix = fitz.Pixmap(image_bytes)
225+
dims = f"{pix.width}x{pix.height} px"
226+
except Exception:
227+
dims = "unknown size"
228+
229+
return f"Figure on page {page} ({dims})." if page else f"Figure ({dims})."
230+
208231
# Placeholder for provider-based captioning (e.g., OpenAI / LLaVA hooks)
209232
provider = getattr(settings, "VISION_PROVIDER", None)
210233

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""Tests for the VLM provider Strategy Pattern (issue #592)."""
2+
from unittest.mock import MagicMock, patch
3+
import pytest
4+
5+
from app.vision.base import BaseVisionProvider
6+
from app.vision.registry import _REGISTRY, get_vision_provider, register_provider
7+
8+
9+
class TestBaseVisionProvider:
10+
def test_cannot_instantiate_abstract_class(self):
11+
with pytest.raises(TypeError):
12+
BaseVisionProvider()
13+
14+
def test_concrete_subclass_works(self):
15+
class Dummy(BaseVisionProvider):
16+
def caption(self, image_bytes: bytes) -> str:
17+
return "dummy"
18+
assert Dummy().caption(b"x") == "dummy"
19+
20+
21+
class TestRegistry:
22+
def setup_method(self):
23+
self._original = dict(_REGISTRY)
24+
25+
def teardown_method(self):
26+
_REGISTRY.clear()
27+
_REGISTRY.update(self._original)
28+
29+
def test_register_and_retrieve(self):
30+
class FakeProvider(BaseVisionProvider):
31+
def caption(self, image_bytes: bytes) -> str:
32+
return "fake"
33+
register_provider("fake", FakeProvider)
34+
assert get_vision_provider("fake") is not None
35+
36+
def test_case_insensitive(self):
37+
class P(BaseVisionProvider):
38+
def caption(self, image_bytes: bytes) -> str:
39+
return ""
40+
register_provider("UPPER", P)
41+
assert get_vision_provider("upper") is not None
42+
43+
def test_unknown_returns_none(self):
44+
assert get_vision_provider("doesnotexist") is None
45+
46+
def test_none_returns_none(self):
47+
assert get_vision_provider(None) is None
48+
49+
def test_broken_init_returns_none(self):
50+
class Broken(BaseVisionProvider):
51+
def __init__(self): raise RuntimeError("fail")
52+
def caption(self, image_bytes: bytes) -> str: return ""
53+
register_provider("broken", Broken)
54+
assert get_vision_provider("broken") is None
55+
56+
57+
class TestCaptionImage:
58+
def test_uses_provider_when_configured(self):
59+
from app.rag.vision import caption_image
60+
61+
class StubProvider(BaseVisionProvider):
62+
def caption(self, image_bytes: bytes) -> str:
63+
return "stub caption"
64+
65+
with patch("app.rag.vision.get_vision_provider", return_value=StubProvider()):
66+
assert caption_image(b"img", page=1) == "stub caption"
67+
68+
def test_falls_back_to_ocr(self):
69+
from app.rag.vision import caption_image
70+
71+
class EmptyProvider(BaseVisionProvider):
72+
def caption(self, image_bytes: bytes) -> str:
73+
return ""
74+
75+
with patch("app.rag.vision.get_vision_provider", return_value=EmptyProvider()):
76+
with patch("app.rag.vision._ocr_caption", return_value="ocr text"):
77+
assert caption_image(b"img", page=1) == "ocr text"
78+
79+
def test_falls_back_to_placeholder(self):
80+
from app.rag.vision import caption_image
81+
82+
with patch("app.rag.vision.get_vision_provider", return_value=None):
83+
with patch("app.rag.vision._ocr_caption", return_value=""):
84+
result = caption_image(b"img", page=3)
85+
assert "page 3" in result
86+
87+
def test_batch_mode(self):
88+
from app.rag.vision import caption_image
89+
90+
with patch("app.rag.vision.get_vision_provider", return_value=None):
91+
with patch("app.rag.vision._ocr_caption", return_value=""):
92+
results = caption_image([b"img1", b"img2"], page=[1, 2])
93+
assert isinstance(results, list) and len(results) == 2

backend/vision/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Vision package: pluggable VLM provider strategy for image captioning."""
2+
from app.vision.registry import get_vision_provider, register_provider
3+
from app.vision.base import BaseVisionProvider
4+
5+
__all__ = ["BaseVisionProvider", "get_vision_provider", "register_provider"]

backend/vision/base.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Abstract base class that every VLM provider must implement."""
2+
from abc import ABC, abstractmethod
3+
4+
5+
class BaseVisionProvider(ABC):
6+
"""Strategy interface for Vision-Language Model providers."""
7+
8+
@abstractmethod
9+
def caption(self, image_bytes: bytes) -> str:
10+
"""Generate a one-sentence caption for the given image.
11+
12+
Returns the caption as a non-empty string, or an empty string on failure so the caller can fall back.
13+
"""
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Auto-registers all built-in providers on import."""
2+
from app.vision.providers import ( # noqa: F401
3+
openai_provider,
4+
anthropic_provider,
5+
gemini_provider,
6+
ollama_provider,
7+
)
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""Anthropic Claude Vision provider.
2+
Activated when VISION_PROVIDER=anthropic and ANTHROPIC_API_KEY is set.
3+
"""
4+
import base64
5+
import logging
6+
7+
from app.config import get_settings
8+
from app.vision.base import BaseVisionProvider
9+
from app.vision.registry import register_provider
10+
11+
logger = logging.getLogger(__name__)
12+
settings = get_settings()
13+
14+
_CAPTION_PROMPT = (
15+
"Describe this figure or diagram in one concise sentence "
16+
"suitable for use as a search index caption."
17+
)
18+
19+
20+
class AnthropicVisionProvider(BaseVisionProvider):
21+
22+
def __init__(self) -> None:
23+
self._api_key: str = getattr(settings, "ANTHROPIC_API_KEY", "")
24+
self._model: str = getattr(settings, "VISION_MODEL", None) or "claude-3-haiku-20240307"
25+
if not self._api_key:
26+
raise ValueError(
27+
"ANTHROPIC_API_KEY must be set when VISION_PROVIDER=anthropic."
28+
)
29+
30+
def caption(self, image_bytes: bytes) -> str:
31+
try:
32+
import anthropic
33+
except ImportError:
34+
logger.error("Run: pip install anthropic")
35+
return ""
36+
37+
try:
38+
client = anthropic.Anthropic(api_key=self._api_key)
39+
b64 = base64.b64encode(image_bytes).decode("utf-8")
40+
message = client.messages.create(
41+
model=self._model,
42+
max_tokens=120,
43+
messages=[{
44+
"role": "user",
45+
"content": [
46+
{"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": b64}},
47+
{"type": "text", "text": _CAPTION_PROMPT},
48+
],
49+
}],
50+
)
51+
content = message.content
52+
if not content:
53+
return ""
54+
text_block = next((b for b in content if getattr(b, "type", None) == "text"), None)
55+
return text_block.text.strip() if text_block else ""
56+
except Exception as exc:
57+
logger.debug("Anthropic vision caption failed: %s", exc)
58+
return ""
59+
60+
61+
register_provider("anthropic", AnthropicVisionProvider)
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""Google Gemini Vision provider.
2+
Activated when VISION_PROVIDER=gemini and GOOGLE_API_KEY is set.
3+
"""
4+
import logging
5+
6+
from app.config import get_settings
7+
from app.vision.base import BaseVisionProvider
8+
from app.vision.registry import register_provider
9+
10+
logger = logging.getLogger(__name__)
11+
settings = get_settings()
12+
13+
_CAPTION_PROMPT = (
14+
"Describe this figure or diagram in one concise sentence "
15+
"suitable for use as a search index caption."
16+
)
17+
18+
19+
class GeminiVisionProvider(BaseVisionProvider):
20+
21+
def __init__(self) -> None:
22+
self._api_key: str = getattr(settings, "GOOGLE_API_KEY", "")
23+
self._model: str = getattr(settings, "VISION_MODEL", None) or "gemini-1.5-flash"
24+
if not self._api_key:
25+
raise ValueError(
26+
"GOOGLE_API_KEY must be set when VISION_PROVIDER=gemini."
27+
)
28+
29+
def caption(self, image_bytes: bytes) -> str:
30+
try:
31+
import google.generativeai as genai
32+
except ImportError:
33+
logger.error("Run: pip install google-generativeai")
34+
return ""
35+
36+
try:
37+
from io import BytesIO
38+
import PIL.Image
39+
genai.configure(api_key=self._api_key)
40+
model = genai.GenerativeModel(self._model)
41+
image = PIL.Image.open(BytesIO(image_bytes))
42+
response = model.generate_content([_CAPTION_PROMPT, image])
43+
text = getattr(response, "text", None)
44+
return text.strip() if text else ""
45+
except Exception as exc:
46+
logger.debug("Gemini vision caption failed: %s", exc)
47+
return ""
48+
49+
50+
register_provider("gemini", GeminiVisionProvider)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""Ollama / LLaVA local Vision provider.
2+
Activated when VISION_PROVIDER=ollama. No API key needed.
3+
Make sure the model is pulled first: ollama pull llava
4+
"""
5+
import base64
6+
import logging
7+
8+
from app.config import get_settings
9+
from app.vision.base import BaseVisionProvider
10+
from app.vision.registry import register_provider
11+
12+
logger = logging.getLogger(__name__)
13+
settings = get_settings()
14+
15+
_CAPTION_PROMPT = (
16+
"Describe this figure or diagram in one concise sentence "
17+
"suitable for use as a search index caption."
18+
)
19+
20+
21+
class OllamaVisionProvider(BaseVisionProvider):
22+
23+
def __init__(self) -> None:
24+
self._base_url: str = (
25+
getattr(settings, "OLLAMA_BASE_URL", None) or "http://localhost:11434"
26+
).rstrip("/")
27+
self._model: str = getattr(settings, "VISION_MODEL", None) or "llava"
28+
29+
def caption(self, image_bytes: bytes) -> str:
30+
try:
31+
import httpx
32+
except ImportError:
33+
logger.error("Run: pip install httpx")
34+
return ""
35+
36+
try:
37+
b64 = base64.b64encode(image_bytes).decode("utf-8")
38+
response = httpx.post(
39+
f"{self._base_url}/api/generate",
40+
json={"model": self._model, "prompt": _CAPTION_PROMPT, "images": [b64], "stream": False},
41+
timeout=60.0,
42+
)
43+
response.raise_for_status()
44+
text = response.json().get("response", "")
45+
return text.strip() if text else ""
46+
except Exception as exc:
47+
logger.debug("Ollama vision caption failed: %s", exc)
48+
return ""
49+
50+
51+
register_provider("ollama", OllamaVisionProvider)

0 commit comments

Comments
 (0)