Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 29 additions & 28 deletions src/lightspeed_evaluation/core/embedding/ragas.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,44 +20,45 @@ def __init__(self, embedding_manager: EmbeddingManager):
Args:
embedding_manager: Pre-configured EmbeddingManager with validated parameters
"""
config = embedding_manager.config
self.config = config
self.config = embedding_manager.config

# Map provider names to litellm format
provider = config.provider.lower()
model = config.model
provider = self.config.provider.lower()
model = self.config.model
# Get additional provider kwargs
kwargs: dict[str, Any] = {}
if self.config.provider_kwargs:
kwargs.update(self.config.provider_kwargs)

# Build the model string for litellm
# Only OpenAI, Gemini, and HuggingFace are supported
if provider == "openai":
model_str = model # OpenAI models don't need prefix
# Provider-specific configuration
if provider in ["openai", "gemini"]:
logger.debug("Using %s provider with model: %s", provider, model)
actual_provider = (
"litellm" # Litellm provider auto-creates client in embedding_factory
)
elif provider == "huggingface":
model_str = f"huggingface/{model}"
elif provider == "gemini":
model_str = f"gemini/{model}"
# HuggingFace default is use_api=False (local sentence-transformers)
# Only set explicitly if user hasn't overridden in provider_kwargs
kwargs.setdefault("use_api", False)
logger.debug(
"Using HuggingFace provider with model: %s (local=%s)",
model,
not kwargs["use_api"],
)
actual_provider = "huggingface"
else:
logger.error("Unknown embedding provider: %s", config.provider)
raise ConfigurationError(f"Unknown embedding provider {config.provider}")

logger.debug(
"Using embedding provider: %s with model: %s -> %s",
provider,
model,
model_str,
)

# Get additional provider kwargs
kwargs: dict[str, Any] = {}
if config.provider_kwargs:
kwargs.update(config.provider_kwargs)
logger.error("Unknown embedding provider: %s", self.config.provider)
raise ConfigurationError(
f"Unknown embedding provider {self.config.provider}"
)

# Create embeddings using ragas 0.4+ embedding_factory with litellm
# Create embeddings using ragas 0.4+ embedding_factory
# Cast to BaseRagasEmbedding as embedding_factory returns union type
self.embeddings: BaseRagasEmbedding = cast(
BaseRagasEmbedding,
embedding_factory(
"litellm",
model=model_str,
provider=actual_provider,
model=model,
**kwargs,
),
)
Expand Down
1 change: 1 addition & 0 deletions tests/unit/core/embedding/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Unit tests for embedding module."""
10 changes: 10 additions & 0 deletions tests/unit/core/embedding/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Fixtures for embedding unit tests."""

import pytest
from pytest_mock import MockerFixture, MockType


@pytest.fixture
def mock_embedding_factory(mocker: MockerFixture) -> MockType:
    """Patch ``embedding_factory`` as referenced by the ragas embedding module.

    Returns:
        The mock object produced by ``mocker.patch`` for the patched name,
        so tests can inspect how it was called.
    """
    # Patch the name inside the ragas embedding module's namespace.
    patch_target = "lightspeed_evaluation.core.embedding.ragas.embedding_factory"
    return mocker.patch(patch_target)
71 changes: 71 additions & 0 deletions tests/unit/core/embedding/test_ragas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Unit tests for RagasEmbeddingManager."""

import pytest
from pytest_mock import MockerFixture, MockType

from lightspeed_evaluation.core.embedding.manager import EmbeddingManager
from lightspeed_evaluation.core.embedding.ragas import RagasEmbeddingManager
from lightspeed_evaluation.core.models import EmbeddingConfig


class TestRagasEmbeddingManagerProviderRouting:
    """Check which provider RagasEmbeddingManager passes to embedding_factory."""

    def test_huggingface_uses_native_provider(
        self, mocker: MockerFixture, mock_embedding_factory: MockType
    ) -> None:
        """HuggingFace configs must call the factory with the native provider.

        The factory is expected to receive ``provider="huggingface"`` (not
        ``"litellm"``) together with ``use_api=False``, so that local
        sentence-transformers models can be used.
        """
        # Stub out the availability check performed by the embedding manager.
        mocker.patch(
            "lightspeed_evaluation.core.embedding.manager.check_huggingface_available"
        )
        hf_model = "sentence-transformers/all-MiniLM-L6-v2"

        manager = EmbeddingManager(
            EmbeddingConfig(provider="huggingface", model=hf_model)
        )
        RagasEmbeddingManager(manager)

        expected_kwargs = {
            "provider": "huggingface",
            "model": hf_model,
            "use_api": False,
        }
        mock_embedding_factory.assert_called_once_with(**expected_kwargs)

    @pytest.mark.parametrize(
        "provider,model",
        [
            ("openai", "text-embedding-3-small"),
            ("gemini", "text-embedding-004"),
        ],
    )
    def test_cloud_providers_use_litellm(
        self,
        mocker: MockerFixture,
        mock_embedding_factory: MockType,
        provider: str,
        model: str,
    ) -> None:
        """OpenAI and Gemini configs must be routed through ``litellm``.

        For these cloud providers the factory is expected to be called with
        ``provider="litellm"`` and the configured model name unchanged, with
        no additional keyword arguments.
        """
        # Stub out the provider environment validation done by the manager.
        mocker.patch(
            "lightspeed_evaluation.core.embedding.manager.validate_provider_env"
        )

        manager = EmbeddingManager(EmbeddingConfig(provider=provider, model=model))
        RagasEmbeddingManager(manager)

        expected_kwargs = {"provider": "litellm", "model": model}
        mock_embedding_factory.assert_called_once_with(**expected_kwargs)
Loading