Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,20 @@ HF_TOKEN=your_huggingface_token_here
# Optional — defaults to 12
# GRAPH_MAX_RELATIONSHIPS=12

# ── Vision / Image Captioning (VLM Providers) ──────────────
# Set VISION_PROVIDER to one of: openai | anthropic | gemini | ollama
# Leave unset to use OCR / placeholder only.
# VISION_PROVIDER=openai

# VISION_MODEL=gpt-4o-mini # openai default
# VISION_MODEL=claude-3-haiku-20240307 # anthropic default
# VISION_MODEL=gemini-1.5-flash # gemini default
# VISION_MODEL=llava # ollama default

# OPENAI_API_KEY=sk-...
# ANTHROPIC_API_KEY=sk-ant-...
# GOOGLE_API_KEY=AIza...
# OLLAMA_BASE_URL=http://localhost:11434

# ── ChromaDB (Vector Store) ─────────────────────────────────

Expand Down
10 changes: 8 additions & 2 deletions backend/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,15 @@ class Settings(BaseSettings):
# ── Reranker ─────────────────────────────────────────
RERANKER_MODEL: str = "BAAI/bge-reranker-v2-m3" # Lightweight 384-dim model fine-tuned for relevance ranking
# ── Vision / Image captioning ─────────────────────
VISION_PROVIDER: str | None = None # e.g. 'openai'
VISION_MODEL: str | None = None
# Set to: openai | anthropic | gemini | ollama (or leave None)
VISION_PROVIDER: str | None = None
VISION_MODEL: str | None = None # overrides provider default model

# Provider API keys — only the active provider's key is required
OPENAI_API_KEY: str = ""
ANTHROPIC_API_KEY: str = ""
GOOGLE_API_KEY: str = ""
OLLAMA_BASE_URL: str = "http://localhost:11434"

# ── Workspace Invitation ─────────────────────────
APP_URL: str = "http://localhost:3000"
Expand Down
47 changes: 35 additions & 12 deletions backend/app/rag/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
"""
import base64
import logging
import app.vision.providers # noqa: F401 — triggers self-registration
from app.vision.registry import get_vision_provider
from io import BytesIO
from typing import Any, Dict, List, Optional

Expand Down Expand Up @@ -187,24 +189,45 @@ def _openai_caption(image_bytes: bytes) -> str:

# ── Public API ───────────────────────────────────────────────────────────────

def caption_image(image_bytes: bytes, page: Optional[int] = None) -> str:
"""Generate a caption for a single image (bytes).

Resolution order: OpenAI (if configured) → OCR → placeholder.
"""
def caption_image(image_bytes: bytes | List[bytes], page: int | List[int] | None = None) -> str | List[str]:
def caption_image(
image_bytes: "bytes | List[bytes]",
page: "int | List[int] | None" = None,
) -> "str | List[str]":
"""Generate a caption for a single image or a batch of images.

Order of operations:
- If a list of image bytes is passed, returns a list of captions.
- If an external VLM provider is configured, attempt to call it.
- Fall back to local OCR (pytesseract) if available.
- Otherwise return a simple placeholder caption including the page number.
Resolution order:
1. Configured VLM provider (set VISION_PROVIDER in .env)
2. Local OCR via pytesseract
3. Placeholder string with page number and dimensions
"""
if isinstance(image_bytes, list):
pages = page if isinstance(page, list) else ([page] * len(image_bytes) if page is not None else [None] * len(image_bytes))
pages = (
page if isinstance(page, list)
else ([page] * len(image_bytes) if page is not None else [None] * len(image_bytes))
)
return [caption_image(img, pg) for img, pg in zip(image_bytes, pages)]

# Strategy: try the configured VLM provider
provider = get_vision_provider(getattr(settings, "VISION_PROVIDER", None))
if provider is not None:
result = provider.caption(image_bytes)
if result:
return result

# Fallback 1: local OCR
ocr = _ocr_caption(image_bytes)
if ocr:
return ocr

# Fallback 2: placeholder
try:
pix = fitz.Pixmap(image_bytes)
dims = f"{pix.width}x{pix.height} px"
except Exception:
dims = "unknown size"

return f"Figure on page {page} ({dims})." if page else f"Figure ({dims})."

# Placeholder for provider-based captioning (e.g., OpenAI / LLaVA hooks)
provider = getattr(settings, "VISION_PROVIDER", None)

Expand Down
93 changes: 93 additions & 0 deletions backend/tests/test_vision_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Tests for the VLM provider Strategy Pattern (issue #592)."""
from unittest.mock import MagicMock, patch
import pytest

from app.vision.base import BaseVisionProvider
from app.vision.registry import _REGISTRY, get_vision_provider, register_provider


class TestBaseVisionProvider:
    """Contract checks for the abstract provider interface."""

    def test_cannot_instantiate_abstract_class(self):
        """The abstract base must refuse direct instantiation."""
        with pytest.raises(TypeError):
            BaseVisionProvider()

    def test_concrete_subclass_works(self):
        """A subclass that implements ``caption`` is usable as-is."""

        class _Concrete(BaseVisionProvider):
            def caption(self, image_bytes: bytes) -> str:
                return "dummy"

        provider = _Concrete()
        produced = provider.caption(b"x")
        assert produced == "dummy"


class TestRegistry:
    """Registration and lookup behaviour of the provider registry."""

    def setup_method(self):
        # Snapshot the module-level registry so each test can mutate it freely.
        self._original = {**_REGISTRY}

    def teardown_method(self):
        # Restore the snapshot taken in setup_method.
        _REGISTRY.clear()
        _REGISTRY.update(self._original)

    def test_register_and_retrieve(self):
        """A registered provider can be looked up by the same name."""

        class _Fake(BaseVisionProvider):
            def caption(self, image_bytes: bytes) -> str:
                return "fake"

        register_provider("fake", _Fake)
        resolved = get_vision_provider("fake")
        assert resolved is not None

    def test_case_insensitive(self):
        """Lookup ignores the case used at registration time."""

        class _Blank(BaseVisionProvider):
            def caption(self, image_bytes: bytes) -> str:
                return ""

        register_provider("UPPER", _Blank)
        resolved = get_vision_provider("upper")
        assert resolved is not None

    def test_unknown_returns_none(self):
        """An unregistered name yields ``None``."""
        resolved = get_vision_provider("doesnotexist")
        assert resolved is None

    def test_none_returns_none(self):
        """Passing ``None`` as the name yields ``None``."""
        resolved = get_vision_provider(None)
        assert resolved is None

    def test_broken_init_returns_none(self):
        """A provider whose constructor raises is reported as ``None``."""

        class _Exploding(BaseVisionProvider):
            def __init__(self):
                raise RuntimeError("fail")

            def caption(self, image_bytes: bytes) -> str:
                return ""

        register_provider("broken", _Exploding)
        resolved = get_vision_provider("broken")
        assert resolved is None


class TestCaptionImage:
    """Resolution order of ``caption_image``: provider, then OCR, then placeholder."""

    def test_uses_provider_when_configured(self):
        """A non-empty provider caption is returned unchanged."""
        from app.rag.vision import caption_image

        class _Stub(BaseVisionProvider):
            def caption(self, image_bytes: bytes) -> str:
                return "stub caption"

        with patch("app.rag.vision.get_vision_provider", return_value=_Stub()):
            produced = caption_image(b"img", page=1)
            assert produced == "stub caption"

    def test_falls_back_to_ocr(self):
        """An empty provider caption hands over to the OCR helper."""
        from app.rag.vision import caption_image

        class _Silent(BaseVisionProvider):
            def caption(self, image_bytes: bytes) -> str:
                return ""

        with patch("app.rag.vision.get_vision_provider", return_value=_Silent()), \
                patch("app.rag.vision._ocr_caption", return_value="ocr text"):
            produced = caption_image(b"img", page=1)
            assert produced == "ocr text"

    def test_falls_back_to_placeholder(self):
        """With no provider and no OCR text, the placeholder names the page."""
        from app.rag.vision import caption_image

        with patch("app.rag.vision.get_vision_provider", return_value=None), \
                patch("app.rag.vision._ocr_caption", return_value=""):
            produced = caption_image(b"img", page=3)
            assert "page 3" in produced

    def test_batch_mode(self):
        """A list of images produces a list of captions of the same length."""
        from app.rag.vision import caption_image

        with patch("app.rag.vision.get_vision_provider", return_value=None), \
                patch("app.rag.vision._ocr_caption", return_value=""):
            produced = caption_image([b"img1", b"img2"], page=[1, 2])
            assert isinstance(produced, list) and len(produced) == 2
5 changes: 5 additions & 0 deletions backend/vision/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Vision package: pluggable VLM provider strategy for image captioning."""
from app.vision.registry import get_vision_provider, register_provider
from app.vision.base import BaseVisionProvider

__all__ = ["BaseVisionProvider", "get_vision_provider", "register_provider"]
13 changes: 13 additions & 0 deletions backend/vision/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Abstract base class that every VLM provider must implement."""
from abc import ABC, abstractmethod


class BaseVisionProvider(ABC):
    """Strategy interface for Vision-Language Model providers.

    Subclasses implement :meth:`caption`. Because that method is abstract,
    this base class cannot be instantiated directly.
    """

    @abstractmethod
    def caption(self, image_bytes: bytes) -> str:
        """Generate a one-sentence caption for the given image.

        Args:
            image_bytes: The image to caption, as bytes.

        Returns:
            A non-empty caption string, or an empty string on failure so the
            caller can fall back to another captioning strategy.
        """
7 changes: 7 additions & 0 deletions backend/vision/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Auto-registers all built-in providers on import."""
from app.vision.providers import ( # noqa: F401
openai_provider,
anthropic_provider,
gemini_provider,
ollama_provider,
)
61 changes: 61 additions & 0 deletions backend/vision/providers/anthropic_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Anthropic Claude Vision provider.
Activated when VISION_PROVIDER=anthropic and ANTHROPIC_API_KEY is set.
"""
import base64
import logging

from app.config import get_settings
from app.vision.base import BaseVisionProvider
from app.vision.registry import register_provider

logger = logging.getLogger(__name__)
settings = get_settings()

_CAPTION_PROMPT = (
"Describe this figure or diagram in one concise sentence "
"suitable for use as a search index caption."
)


# Leading magic bytes for the fixed-prefix image formats accepted by the
# Anthropic Messages API (WebP is handled separately: it is a RIFF container).
_IMAGE_SIGNATURES = (
    (b"\x89PNG\r\n\x1a\n", "image/png"),
    (b"\xff\xd8\xff", "image/jpeg"),
    (b"GIF87a", "image/gif"),
    (b"GIF89a", "image/gif"),
)


def _sniff_media_type(image_bytes: bytes) -> str:
    """Return the MIME type implied by the image's magic bytes.

    Falls back to ``"image/png"`` (the previously hard-coded value) when the
    format is not recognised.
    """
    for signature, media_type in _IMAGE_SIGNATURES:
        if image_bytes.startswith(signature):
            return media_type
    # WebP: "RIFF" <4-byte size> "WEBP"
    if image_bytes[:4] == b"RIFF" and image_bytes[8:12] == b"WEBP":
        return "image/webp"
    return "image/png"


class AnthropicVisionProvider(BaseVisionProvider):
    """Caption images through the Anthropic Messages API.

    Reads ``ANTHROPIC_API_KEY`` and ``VISION_MODEL`` from the module-level
    ``settings`` object; the model defaults to ``claude-3-haiku-20240307``.
    """

    def __init__(self) -> None:
        """Load credentials and model name from settings.

        Raises:
            ValueError: If ``ANTHROPIC_API_KEY`` is empty or missing.
        """
        self._api_key: str = getattr(settings, "ANTHROPIC_API_KEY", "")
        self._model: str = getattr(settings, "VISION_MODEL", None) or "claude-3-haiku-20240307"
        if not self._api_key:
            raise ValueError(
                "ANTHROPIC_API_KEY must be set when VISION_PROVIDER=anthropic."
            )

    def caption(self, image_bytes: bytes) -> str:
        """Return a one-sentence caption for ``image_bytes``.

        Args:
            image_bytes: The image to caption, as bytes.

        Returns:
            The stripped text of the first text block in the response, or an
            empty string if the SDK is missing, the input is empty, the
            request fails, or the response has no text block.
        """
        try:
            import anthropic
        except ImportError:
            logger.error("Run: pip install anthropic")
            return ""

        # Nothing to describe; skip the network round-trip.
        if not image_bytes:
            return ""

        try:
            client = anthropic.Anthropic(api_key=self._api_key)
            b64 = base64.b64encode(image_bytes).decode("utf-8")
            # Declare the real format instead of always claiming PNG, so
            # JPEG/GIF/WebP payloads are labelled correctly.
            media_type = _sniff_media_type(image_bytes)
            message = client.messages.create(
                model=self._model,
                max_tokens=120,
                messages=[{
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "source": {"type": "base64", "media_type": media_type, "data": b64},
                        },
                        {"type": "text", "text": _CAPTION_PROMPT},
                    ],
                }],
            )
            content = message.content
            if not content:
                return ""
            text_block = next((b for b in content if getattr(b, "type", None) == "text"), None)
            return text_block.text.strip() if text_block else ""
        except Exception as exc:
            # Best-effort: any failure yields "" so the caller can fall back.
            logger.debug("Anthropic vision caption failed: %s", exc)
            return ""


register_provider("anthropic", AnthropicVisionProvider)
50 changes: 50 additions & 0 deletions backend/vision/providers/gemini_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Google Gemini Vision provider.
Activated when VISION_PROVIDER=gemini and GOOGLE_API_KEY is set.
"""
import logging

from app.config import get_settings
from app.vision.base import BaseVisionProvider
from app.vision.registry import register_provider

logger = logging.getLogger(__name__)
settings = get_settings()

_CAPTION_PROMPT = (
"Describe this figure or diagram in one concise sentence "
"suitable for use as a search index caption."
)


class GeminiVisionProvider(BaseVisionProvider):
    """Caption images through Google's ``google.generativeai`` SDK."""

    def __init__(self) -> None:
        """Read the API key and model name from settings; require the key."""
        api_key = getattr(settings, "GOOGLE_API_KEY", "")
        model_name = getattr(settings, "VISION_MODEL", None) or "gemini-1.5-flash"
        self._api_key: str = api_key
        self._model: str = model_name
        if not api_key:
            raise ValueError(
                "GOOGLE_API_KEY must be set when VISION_PROVIDER=gemini."
            )

    def caption(self, image_bytes: bytes) -> str:
        """Return a stripped caption for the image, or ``""`` on any failure."""
        try:
            import google.generativeai as genai
        except ImportError:
            logger.error("Run: pip install google-generativeai")
            return ""

        try:
            from io import BytesIO
            import PIL.Image

            genai.configure(api_key=self._api_key)
            vlm = genai.GenerativeModel(self._model)
            picture = PIL.Image.open(BytesIO(image_bytes))
            reply = vlm.generate_content([_CAPTION_PROMPT, picture])
            caption_text = getattr(reply, "text", None)
            if not caption_text:
                return ""
            return caption_text.strip()
        except Exception as exc:
            # Best-effort: the caller falls back when "" is returned.
            logger.debug("Gemini vision caption failed: %s", exc)
            return ""


register_provider("gemini", GeminiVisionProvider)
51 changes: 51 additions & 0 deletions backend/vision/providers/ollama_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Ollama / LLaVA local Vision provider.
Activated when VISION_PROVIDER=ollama. No API key needed.
Make sure the model is pulled first: ollama pull llava
"""
import base64
import logging

from app.config import get_settings
from app.vision.base import BaseVisionProvider
from app.vision.registry import register_provider

logger = logging.getLogger(__name__)
settings = get_settings()

_CAPTION_PROMPT = (
"Describe this figure or diagram in one concise sentence "
"suitable for use as a search index caption."
)


class OllamaVisionProvider(BaseVisionProvider):
    """Caption images by POSTing to an Ollama server's ``/api/generate`` endpoint."""

    def __init__(self) -> None:
        """Read the server URL and model name from settings, applying defaults."""
        configured_url = getattr(settings, "OLLAMA_BASE_URL", None)
        # Drop trailing slashes so the endpoint path can be appended directly.
        self._base_url: str = (configured_url or "http://localhost:11434").rstrip("/")
        self._model: str = getattr(settings, "VISION_MODEL", None) or "llava"

    def caption(self, image_bytes: bytes) -> str:
        """Return a stripped caption for the image, or ``""`` on any failure."""
        try:
            import httpx
        except ImportError:
            logger.error("Run: pip install httpx")
            return ""

        try:
            encoded = base64.b64encode(image_bytes).decode("utf-8")
            payload = {
                "model": self._model,
                "prompt": _CAPTION_PROMPT,
                "images": [encoded],
                "stream": False,
            }
            reply = httpx.post(
                f"{self._base_url}/api/generate",
                json=payload,
                timeout=60.0,
            )
            reply.raise_for_status()
            caption_text = reply.json().get("response", "")
            if not caption_text:
                return ""
            return caption_text.strip()
        except Exception as exc:
            # Best-effort: the caller falls back when "" is returned.
            logger.debug("Ollama vision caption failed: %s", exc)
            return ""


register_provider("ollama", OllamaVisionProvider)
Loading
Loading