Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class GoogleGenAIDocumentEmbedder:

```python
from haystack import Document
from haystack_integrations.components.embedders import GoogleGenAIDocumentEmbedder
from haystack_integrations.components.embedders.google_genai import GoogleGenAIDocumentEmbedder

doc = Document(content="I love pizza!")

Expand All @@ -48,7 +48,7 @@ def __init__(
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
config: Optional[Dict[str, Any]] = None,
):
) -> None:
"""
Creates a GoogleGenAIDocumentEmbedder component.

Expand Down Expand Up @@ -139,32 +139,40 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:

return texts_to_embed

def _embed_batch(
    self, texts_to_embed: List[str], batch_size: int
) -> Tuple[List[Optional[List[float]]], Dict[str, Any]]:
    """
    Embed a list of texts in batches.

    :param texts_to_embed: The texts to send to the embedding model.
    :param batch_size: Number of texts passed to `batched` for each `embed_content` call.
    :returns: A tuple `(embeddings, meta)`. `embeddings` contains the values of every embedding
        returned by the client, with `None` for entries that carry no values; when a response
        has no embeddings at all, one `None` per text of that batch is added instead. `meta`
        holds the model name once at least one batch has been processed.
    """
    # The config does not change between batches, so build it once up front.
    resolved_config = types.EmbedContentConfig(**self._config) if self._config else None

    all_embeddings: List[Optional[List[float]]] = []
    meta: Dict[str, Any] = {}
    for batch in tqdm(
        batched(texts_to_embed, batch_size), disable=not self._progress_bar, desc="Calculating embeddings"
    ):
        # NOTE(review): `b[1]` indexes into each batch element; with `texts_to_embed` annotated
        # as List[str] this would select a single character - confirm the element shape that
        # `batched` yields here before relying on it.
        args: Dict[str, Any] = {"model": self._model, "contents": [b[1] for b in batch]}
        if resolved_config:
            args["config"] = resolved_config

        response = self._client.models.embed_content(**args)

        if response.embeddings:
            # Extend exactly once per batch; an embedding without values is recorded as None.
            all_embeddings.extend(el.values if el.values else None for el in response.embeddings)
        else:
            # Nothing came back: keep one placeholder per input text of this batch.
            all_embeddings.extend([None] * len(batch))

        if "model" not in meta:
            meta["model"] = self._model

    return all_embeddings, meta

@component.output_types(documents=List[Document], meta=Dict[str, Any])
def run(self, documents: List[Document]) -> Dict[str, Union[List[Document], Dict[str, Any]]]:
def run(self, documents: List[Document]) -> Union[Dict[str, List[Document]], Dict[str, Any]]:
"""
Embeds a list of documents.

Expand All @@ -185,6 +193,7 @@ def run(self, documents: List[Document]) -> Dict[str, Union[List[Document], Dict

texts_to_embed = self._prepare_texts_to_embed(documents=documents)

meta: Dict[str, Any]
embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self._batch_size)

for doc, emb in zip(documents, embeddings):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(
prefix: str = "",
suffix: str = "",
config: Optional[Dict[str, Any]] = None,
):
) -> None:
"""
Creates a GoogleGenAITextEmbedder component.

Expand Down Expand Up @@ -119,7 +119,8 @@ def _prepare_input(self, text: str) -> Dict[str, Any]:
return kwargs

def _prepare_output(self, result: types.EmbedContentResponse) -> Dict[str, Any]:
return {"embedding": result.embeddings[0].values, "meta": {"model": self._model_name}}
embedding = result.embeddings[0].values if result.embeddings else []
return {"embedding": embedding, "meta": {"model": self._model_name}}

@component.output_types(embedding=List[float], meta=Dict[str, Any])
def run(self, text: str) -> Union[Dict[str, List[float]], Dict[str, Any]]:
Expand Down
35 changes: 32 additions & 3 deletions integrations/google_genai/tests/test_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,35 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):
},
}

def test_from_dict(self, monkeypatch):
    """Deserialize an embedder from its dict form and check every init parameter is restored."""
    # The API key is declared as a strict env-var secret, so the variable must be set.
    monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key")

    init_parameters = {
        "model": "text-embedding-004",
        "prefix": "",
        "suffix": "",
        "batch_size": 32,
        "progress_bar": True,
        "meta_fields_to_embed": [],
        "embedding_separator": "\n",
        "api_key": {"type": "env_var", "env_vars": ["GOOGLE_API_KEY"], "strict": True},
        "config": {"task_type": "SEMANTIC_SIMILARITY"},
    }
    serialized = {
        "type": (
            "haystack_integrations.components.embedders.google_genai.document_embedder.GoogleGenAIDocumentEmbedder"
        ),
        "init_parameters": init_parameters,
    }

    embedder = GoogleGenAIDocumentEmbedder.from_dict(serialized)

    # Secret resolution goes through the patched environment.
    assert embedder._api_key.resolve_value() == "fake-api-key"
    # Remaining parameters must come back exactly as serialized.
    assert embedder._model == "text-embedding-004"
    assert embedder._prefix == ""
    assert embedder._suffix == ""
    assert embedder._batch_size == 32
    assert embedder._progress_bar is True
    assert embedder._meta_fields_to_embed == []
    assert embedder._embedding_separator == "\n"
    assert embedder._config == {"task_type": "SEMANTIC_SIMILARITY"}

def test_prepare_texts_to_embed_w_metadata(self):
documents = [
Document(id=f"{i}", content=f"document number {i}:\ncontent", meta={"meta_field": f"meta_value {i}"})
Expand Down Expand Up @@ -204,6 +233,6 @@ def test_run(self):
assert len(doc.embedding) == 768
assert all(isinstance(x, float) for x in doc.embedding)

assert "text" in result["meta"]["model"] and "004" in result["meta"]["model"], (
"The model name does not contain 'text' and '004'"
)
assert result["documents"][0].meta == {"topic": "Cuisine"}
assert result["documents"][1].meta == {"topic": "ML"}
assert result["meta"] == {"model": model}
4 changes: 1 addition & 3 deletions integrations/google_genai/tests/test_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,4 @@ def test_run(self):
assert len(result["embedding"]) == 768
assert all(isinstance(x, float) for x in result["embedding"])

assert "text" in result["meta"]["model"] and "004" in result["meta"]["model"], (
"The model name does not contain 'text' and '004'"
)
assert result["meta"] == {"model": model}
Loading