Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions aperag/llm/embed/embedding_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import hashlib
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Sequence, Tuple
from typing import Dict, List, Sequence, Tuple

import litellm

Expand Down Expand Up @@ -226,19 +226,22 @@ async def aembed_image(self, image_bytes: bytes, alt_text: str = "") -> List[flo
def _embed_image_via_litellm(self, *, image_bytes: bytes, alt_text: str) -> List[float]:
"""Underlying LiteLLM multimodal embedding call.

Encodes the image as base64 data URL + builds the LiteLLM
``input`` payload. Provider-specific input shape variations are
Wave 6 follow-up (per §K.10 Wave 6 backlog cross-cutting
refactor). Currently uses the documented LiteLLM-shaped
``[{"image_url": {"url": "data:..."}}]`` input that
multimodal-capable providers (Voyage / Jina v3 / OpenAI multi-
modal / etc.) accept natively.
Encodes the image as a base64 data URL and dispatches to a
provider-specific input payload shape via
:func:`build_multimodal_input_payload` (Wave 6 task #39 per
§G.2.5.1). Voyage / Jina / Cohere / OpenAI all document
different multimodal embedding wire shapes; the dispatcher
emits the canonical shape for the configured provider and
falls back to the Wave 5 P2 LiteLLM-documented default for
unknown providers.
"""
import base64
import imghdr

from litellm import embedding as litellm_embedding

from aperag.llm.embed.multimodal_input import build_multimodal_input_payload

# Detect MIME type from the image bytes header (avoids relying
# on caller-provided alt_text format hints). Falls back to
# image/jpeg if detection fails — most providers tolerate
Expand All @@ -248,17 +251,17 @@ def _embed_image_via_litellm(self, *, image_bytes: bytes, alt_text: str) -> List
b64 = base64.b64encode(image_bytes).decode("ascii")
data_url = f"data:{mime};base64,{b64}"

input_payload: List[dict[str, Any]] = [{"image_url": {"url": data_url}}]
if alt_text and alt_text.strip():
# Pair the image with text for embedders that accept multi-
# part inputs; embedders that ignore text simply drop it.
input_payload.append({"text": alt_text.strip()})
input_payload = build_multimodal_input_payload(
provider=self.embedding_provider,
image_data_url=data_url,
alt_text=alt_text,
)

response = litellm_embedding(
model=f"{self.embedding_provider}/{self.model}" if self.embedding_provider else self.model,
input=input_payload,
api_key=self.api_key,
api_base=self.base_url,
api_base=self.api_base,
)
# LiteLLM normalises response shape to OpenAI-style; pull the
# embedding from the first (and only) data element.
Expand Down
158 changes: 158 additions & 0 deletions aperag/llm/embed/multimodal_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Copyright 2025 ApeCloud, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provider-specific multimodal embedding input payload builders.

Wave 6 task #39 per `docs/modularization/indexing-redesign-design-pack.md`
§G.2.5.1: the Wave 5 P2 chunk 1 :func:`EmbeddingService._embed_image_via_litellm`
shipped a single LiteLLM-style ``input=[{"image_url": {"url": "data:..."}}, {"text":...}]``
shape that mirrors the OpenAI multimodal *chat completion* request.
That shape is **not** the canonical *embedding* request shape for the
real multimodal-capable providers (Voyage AI, Jina, Cohere, etc.) —
LiteLLM may translate it transparently for some providers but ships
no guarantee. This module dispatches a per-provider payload so the
operator-configured embedder receives the input shape it actually
documents.

The dispatcher matches the ``embedding_provider`` string used by
:class:`EmbeddingService` (LiteLLM provider keyword), ignoring case and
surrounding whitespace, and is purposely **hard-cut**: a recognised
provider always receives its own payload shape, with no shim or
"compatibility" variant layered on top. Only providers that match no
alias fall through to the LiteLLM-documented default that the rest of
the embedding stack already uses.
"""

from __future__ import annotations

from typing import Any

# Canonical provider keywords (matches LiteLLM ``custom_llm_provider`` /
# the provider prefix on ``model="<provider>/<model>"`` calls). The
# accepted-aliases tuple lets operators name the provider in any
# common way (LiteLLM itself accepts ``"voyage_ai"`` and ``"voyage"``).
#
# ``build_multimodal_input_payload`` strips and lower-cases the provider
# string before testing membership, so every alias listed here must be
# lower-case and free of surrounding whitespace to be reachable.

# Aliases routed to the Voyage builder.
_VOYAGE_ALIASES = ("voyage_ai", "voyageai", "voyage")
# Aliases routed to the Jina builder.
_JINA_ALIASES = ("jina_ai", "jinaai", "jina")
# Aliases routed to the OpenAI builder; plain "openai" is included.
_OPENAI_MULTIMODAL_ALIASES = ("openai_multimodal", "openai")
# Aliases routed to the Cohere builder (single-element tuple).
_COHERE_ALIASES = ("cohere",)


def build_multimodal_input_payload(
    *,
    provider: str | None,
    image_data_url: str,
    alt_text: str,
) -> list[dict[str, Any]]:
    """Build the ``input=`` argument for a multimodal ``litellm.embedding`` call.

    The provider name is normalised (surrounding whitespace removed,
    lower-cased; ``None`` is treated as the empty string) and compared
    against the module's alias tuples. The first matching family's
    builder produces the payload; a name that matches nothing gets the
    default shape.

    Args:
        provider: Embedding provider keyword, or ``None`` when unset.
        image_data_url: Image already encoded as a base64 data URL
            (``data:image/<kind>;base64,<...>``). MIME detection is the
            caller's job; the string is passed through untouched.
        alt_text: Optional caption. Empty or whitespace-only text yields
            no text part in any of the builders.

    Returns:
        A list of dicts shaped for the selected provider.
    """

    normalized = provider.strip().lower() if provider else ""

    # Checked in order: Voyage, Jina, Cohere, then OpenAI.
    routes = (
        (_VOYAGE_ALIASES, _voyage_payload),
        (_JINA_ALIASES, _jina_payload),
        (_COHERE_ALIASES, _cohere_payload),
        (_OPENAI_MULTIMODAL_ALIASES, _openai_payload),
    )
    for aliases, builder in routes:
        if normalized in aliases:
            return builder(image_data_url, alt_text)

    # Unset or unrecognised provider: keep the pre-existing default shape.
    return _default_payload(image_data_url, alt_text)


def _voyage_payload(image_data_url: str, alt_text: str) -> list[dict[str, Any]]:
    """Build the payload for Voyage AI multimodal embedding models.

    Emits a single ``{"content": [...]}`` envelope. The image is carried
    as ``{"type": "image_base64", "image_base64": "<data url>"}``; when
    ``alt_text`` contains non-whitespace characters, a trailing
    ``{"type": "text", "text": "<stripped text>"}`` part is added.

    Returns:
        A one-element list holding the envelope.
    """

    caption = alt_text.strip() if alt_text else ""
    image_part: dict[str, Any] = {"type": "image_base64", "image_base64": image_data_url}

    if not caption:
        # Image-only request: no text part at all.
        return [{"content": [image_part]}]
    return [{"content": [image_part, {"type": "text", "text": caption}]}]


def _jina_payload(image_data_url: str, alt_text: str) -> list[dict[str, Any]]:
    """Build the payload for Jina multimodal embedding models.

    Emits a flat list of single-key dicts: ``{"image": "<data url>"}``
    first, followed by ``{"text": "<stripped text>"}`` when ``alt_text``
    contains non-whitespace characters.

    NOTE(review): an earlier version of this docstring stated that Jina
    fuses the image and text items into one vector. Nothing in this
    module establishes that; confirm against the Jina embeddings API
    before relying on it.
    """

    caption = alt_text.strip() if alt_text else ""
    text_items: list[dict[str, Any]] = [{"text": caption}] if caption else []
    return [{"image": image_data_url}, *text_items]


def _cohere_payload(image_data_url: str, alt_text: str) -> list[dict[str, Any]]:
    """Build the payload for Cohere image-capable embedding models.

    Emits ``{"image": "<data url>"}`` and, when ``alt_text`` contains
    non-whitespace characters, a second independent item
    ``{"text": "<stripped text>"}``. Both items are dicts; the text is
    not sent as a bare string.

    NOTE(review): whether Cohere returns one vector per item rather than
    a fused vector, and whether it accepts this exact shape through
    LiteLLM, is not shown here. Verify against the provider docs.
    """

    payload: list[dict[str, Any]] = [{"image": image_data_url}]
    caption = (alt_text or "").strip()
    if caption:
        payload += [{"text": caption}]
    return payload


def _openai_payload(image_data_url: str, alt_text: str) -> list[dict[str, Any]]:
    """Build the payload used for OpenAI-family providers.

    Emits the multipart list used by OpenAI multimodal *chat* requests:
    ``{"type": "image_url", "image_url": {"url": "<data url>"}}``
    followed, when ``alt_text`` contains non-whitespace characters, by
    ``{"type": "text", "text": "<stripped text>"}``.

    The standard ``text-embedding-3-*`` models do not take image input,
    so flagging such a model as multimodal is expected to fail at the
    provider. Sending this explicit shape is meant to make that failure
    a clear provider-side error rather than a silent text-only embed.
    """

    image_part: dict[str, Any] = {"type": "image_url", "image_url": {"url": image_data_url}}
    caption = alt_text.strip() if alt_text else ""
    return [image_part, {"type": "text", "text": caption}] if caption else [image_part]


def _default_payload(image_data_url: str, alt_text: str) -> list[dict[str, Any]]:
    """Build the generic payload for unset or unrecognised providers.

    Emits ``{"image_url": {"url": "<data url>"}}`` followed, when
    ``alt_text`` contains non-whitespace characters, by
    ``{"text": "<stripped text>"}``. Unmapped providers get this shape
    instead of an exception, so their behaviour stays what it was before
    provider-specific dispatch existed.
    """

    caption = alt_text.strip() if alt_text else ""
    candidates = (
        {"image_url": {"url": image_data_url}},
        {"text": caption} if caption else None,
    )
    # Drop the placeholder when there is no usable caption.
    return [entry for entry in candidates if entry is not None]


# Public surface: only the dispatcher is exported; the per-provider
# builders are underscore-prefixed implementation details.
__all__ = ["build_multimodal_input_payload"]
124 changes: 124 additions & 0 deletions tests/unit_test/llm/test_multimodal_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Copyright 2025 ApeCloud, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for Wave 6 task #39: provider-specific multimodal input
payload dispatcher (`build_multimodal_input_payload`).
"""

from __future__ import annotations

import pytest

from aperag.llm.embed.multimodal_input import build_multimodal_input_payload

# Shared fixtures. The builders pass the data URL through verbatim, so a
# tiny placeholder payload is enough for shape assertions.
_DATA_URL = "data:image/jpeg;base64,AAAA"
# Non-empty caption with no leading/trailing whitespace, so the stripped
# text emitted by the builders equals this value exactly.
_ALT = "two cats on the sofa"


@pytest.mark.parametrize("provider", ["voyage_ai", "voyageai", "voyage", "VOYAGE", " Voyage "])
def test_voyage_payload_uses_content_envelope_with_image_base64_part(provider):
    """Every Voyage alias (any case, padded or not) yields exactly one
    ``{"content": [...]}`` envelope holding one ``image_base64`` part
    with the data URL and one ``text`` part with the caption.
    """

    result = build_multimodal_input_payload(provider=provider, image_data_url=_DATA_URL, alt_text=_ALT)

    assert isinstance(result, list)
    assert len(result) == 1
    envelope = result[0]
    assert "content" in envelope

    # Bucket the parts by their "type" discriminator.
    by_type = {}
    for part in envelope["content"]:
        by_type.setdefault(part.get("type"), []).append(part)
    images = by_type.get("image_base64", [])
    texts = by_type.get("text", [])

    assert len(images) == 1
    assert images[0]["image_base64"] == _DATA_URL
    assert len(texts) == 1
    assert texts[0]["text"] == _ALT


def test_voyage_payload_omits_text_part_when_alt_text_empty():
    """An empty caption must leave the Voyage envelope without any text part."""

    result = build_multimodal_input_payload(provider="voyage_ai", image_data_url=_DATA_URL, alt_text="")
    kinds = [part.get("type") for part in result[0]["content"]]
    assert "text" not in kinds, "empty alt_text must not produce a text part"


@pytest.mark.parametrize("provider", ["jina_ai", "jinaai", "jina"])
def test_jina_payload_uses_flat_image_and_text_items(provider):
    """Every Jina alias produces a flat two-item list: the image dict
    first, then the text dict.
    """

    expected = [{"image": _DATA_URL}, {"text": _ALT}]
    actual = build_multimodal_input_payload(provider=provider, image_data_url=_DATA_URL, alt_text=_ALT)
    assert actual == expected


def test_jina_payload_omits_text_when_alt_text_empty():
    """A whitespace-only caption leaves just the image item for Jina."""

    actual = build_multimodal_input_payload(provider="jina_ai", image_data_url=_DATA_URL, alt_text=" ")
    expected = [{"image": _DATA_URL}]
    assert actual == expected


def test_cohere_payload_uses_image_then_text_items():
    """Cohere gets the image dict followed by a separate text dict."""

    expected = [{"image": _DATA_URL}, {"text": _ALT}]
    actual = build_multimodal_input_payload(provider="cohere", image_data_url=_DATA_URL, alt_text=_ALT)
    assert actual == expected


def test_openai_payload_uses_chat_multimodal_envelope():
    """The OpenAI builder emits typed parts: an ``image_url`` part
    wrapping the data URL, then a ``text`` part with the caption.
    """

    image_part = {"type": "image_url", "image_url": {"url": _DATA_URL}}
    text_part = {"type": "text", "text": _ALT}

    actual = build_multimodal_input_payload(provider="openai", image_data_url=_DATA_URL, alt_text=_ALT)

    assert actual == [image_part, text_part]


def test_openai_payload_alias_openai_multimodal_resolves_same_shape():
    """``openai_multimodal`` and ``openai`` must produce identical payloads."""

    via_alias, via_canonical = (
        build_multimodal_input_payload(provider=name, image_data_url=_DATA_URL, alt_text="")
        for name in ("openai_multimodal", "openai")
    )
    assert via_alias == via_canonical


def test_unknown_provider_falls_back_to_litellm_default_shape():
    """A provider name that matches no alias receives the default
    ``image_url`` + ``text`` shape rather than raising.
    """

    expected = [{"image_url": {"url": _DATA_URL}}, {"text": _ALT}]
    actual = build_multimodal_input_payload(provider="some-new-provider", image_data_url=_DATA_URL, alt_text=_ALT)
    assert actual == expected


def test_none_provider_resolves_to_default():
    """``provider=None`` is routed to the default builder; only the
    leading image item is checked here.
    """

    result = build_multimodal_input_payload(provider=None, image_data_url=_DATA_URL, alt_text="x")
    first_item = result[0]
    assert first_item == {"image_url": {"url": _DATA_URL}}


def test_alt_text_whitespace_treated_as_empty_across_providers():
    """No builder may emit a text part for a whitespace-only caption:
    pairing the image with " " would change the cache key and may
    confuse the embedder.
    """

    for provider in ("voyage_ai", "jina_ai", "cohere", "openai", "unknown"):
        payload = build_multimodal_input_payload(provider=provider, image_data_url=_DATA_URL, alt_text=" ")

        # Voyage nests its parts under "content"; the other shapes are
        # already flat, so fall back to the payload list itself.
        candidate = payload[0].get("content", payload)
        parts = candidate if isinstance(candidate, list) else [candidate]

        offenders = [
            part
            for part in parts
            if ("text" in part and part["text"]) or (part.get("type") == "text")
        ]
        assert not offenders, f"{provider}: whitespace alt_text must not produce a text part"
Loading