Skip to content

Commit 411e6ce

Browse files
committed
fix:huggingface-local-embeddings-enabled
1 parent f25b262 commit 411e6ce

4 files changed

Lines changed: 107 additions & 28 deletions

File tree

src/lightspeed_evaluation/core/embedding/ragas.py

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -20,44 +20,45 @@ def __init__(self, embedding_manager: EmbeddingManager):
2020
Args:
2121
embedding_manager: Pre-configured EmbeddingManager with validated parameters
2222
"""
23-
config = embedding_manager.config
24-
self.config = config
23+
self.config = embedding_manager.config
2524

2625
# Map provider names to litellm format
27-
provider = config.provider.lower()
28-
model = config.model
26+
provider = self.config.provider.lower()
27+
model = self.config.model
28+
# Get additional provider kwargs
29+
kwargs: dict[str, Any] = {}
30+
if self.config.provider_kwargs:
31+
kwargs.update(self.config.provider_kwargs)
2932

30-
# Build the model string for litellm
31-
# Only OpenAI, Gemini, and HuggingFace are supported
32-
if provider == "openai":
33-
model_str = model # OpenAI models don't need prefix
33+
# Provider-specific configuration
34+
if provider in ["openai", "gemini"]:
35+
logger.debug("Using %s provider with model: %s", provider, model)
36+
actual_provider = (
37+
"litellm" # Litellm provider auto-creates client in embedding_factory
38+
)
3439
elif provider == "huggingface":
35-
model_str = f"huggingface/{model}"
36-
elif provider == "gemini":
37-
model_str = f"gemini/{model}"
40+
# HuggingFace default is use_api=False (local sentence-transformers)
41+
# Only set explicitly if user hasn't overridden in provider_kwargs
42+
kwargs.setdefault("use_api", False)
43+
logger.debug(
44+
"Using HuggingFace provider with model: %s (local=%s)",
45+
model,
46+
not kwargs["use_api"],
47+
)
48+
actual_provider = "huggingface"
3849
else:
39-
logger.error("Unknown embedding provider: %s", config.provider)
40-
raise ConfigurationError(f"Unknown embedding provider {config.provider}")
41-
42-
logger.debug(
43-
"Using embedding provider: %s with model: %s -> %s",
44-
provider,
45-
model,
46-
model_str,
47-
)
48-
49-
# Get additional provider kwargs
50-
kwargs: dict[str, Any] = {}
51-
if config.provider_kwargs:
52-
kwargs.update(config.provider_kwargs)
50+
logger.error("Unknown embedding provider: %s", self.config.provider)
51+
raise ConfigurationError(
52+
f"Unknown embedding provider {self.config.provider}"
53+
)
5354

54-
# Create embeddings using ragas 0.4+ embedding_factory with litellm
55+
# Create embeddings using ragas 0.4+ embedding_factory
5556
# Cast to BaseRagasEmbedding as embedding_factory returns union type
5657
self.embeddings: BaseRagasEmbedding = cast(
5758
BaseRagasEmbedding,
5859
embedding_factory(
59-
"litellm",
60-
model=model_str,
60+
provider=actual_provider,
61+
model=model,
6162
**kwargs,
6263
),
6364
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Unit tests for embedding module."""
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""Fixtures for embedding unit tests."""
2+
3+
import pytest
4+
from pytest_mock import MockerFixture, MockType
5+
6+
7+
@pytest.fixture
8+
def mock_embedding_factory(mocker: MockerFixture) -> MockType:
9+
"""Mock embedding_factory for ragas embedding tests."""
10+
return mocker.patch("lightspeed_evaluation.core.embedding.ragas.embedding_factory")
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Unit tests for RagasEmbeddingManager."""
2+
3+
import pytest
4+
from pytest_mock import MockerFixture, MockType
5+
6+
from lightspeed_evaluation.core.embedding.manager import EmbeddingManager
7+
from lightspeed_evaluation.core.embedding.ragas import RagasEmbeddingManager
8+
from lightspeed_evaluation.core.models import EmbeddingConfig
9+
10+
11+
class TestRagasEmbeddingManagerProviderRouting:
12+
"""Test provider routing logic in RagasEmbeddingManager."""
13+
14+
def test_huggingface_uses_native_provider(
15+
self, mock_embedding_factory: MockType
16+
) -> None:
17+
"""Verify HuggingFace uses native provider, not litellm.
18+
19+
HuggingFace embeddings should use the 'huggingface' provider directly
20+
to support local sentence-transformers models with use_api=False.
21+
"""
22+
config = EmbeddingConfig(
23+
provider="huggingface",
24+
model="sentence-transformers/all-MiniLM-L6-v2",
25+
)
26+
embedding_manager = EmbeddingManager(config)
27+
28+
RagasEmbeddingManager(embedding_manager)
29+
30+
mock_embedding_factory.assert_called_once_with(
31+
provider="huggingface",
32+
model="sentence-transformers/all-MiniLM-L6-v2",
33+
use_api=False,
34+
)
35+
36+
@pytest.mark.parametrize(
37+
"provider,model",
38+
[
39+
("openai", "text-embedding-3-small"),
40+
("gemini", "text-embedding-004"),
41+
],
42+
)
43+
def test_cloud_providers_use_litellm(
44+
self,
45+
mocker: MockerFixture,
46+
mock_embedding_factory: MockType,
47+
provider: str,
48+
model: str,
49+
) -> None:
50+
"""Verify OpenAI and Gemini use litellm provider.
51+
52+
Cloud providers (OpenAI, Gemini) should route through 'litellm'
53+
which auto-creates the appropriate client in embedding_factory.
54+
"""
55+
mocker.patch(
56+
"lightspeed_evaluation.core.embedding.manager.validate_provider_env"
57+
)
58+
59+
config = EmbeddingConfig(provider=provider, model=model)
60+
embedding_manager = EmbeddingManager(config)
61+
62+
RagasEmbeddingManager(embedding_manager)
63+
64+
mock_embedding_factory.assert_called_once_with(
65+
provider="litellm",
66+
model=model,
67+
)

0 commit comments

Comments
 (0)