Skip to content

Commit adde318

Browse files
authored
fix: Fix types in the Google Gen AI embedders (#1916)
* Types cleanup * Add more tests * Fix more tests * Remove todo * Fix test * Fix tests
1 parent 4739bc9 commit adde318

4 files changed

Lines changed: 53 additions & 16 deletions

File tree

integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class GoogleGenAIDocumentEmbedder:
2323
2424
```python
2525
from haystack import Document
26-
from haystack_integrations.components.embedders import GoogleGenAIDocumentEmbedder
26+
from haystack_integrations.components.embedders.google_genai import GoogleGenAIDocumentEmbedder
2727
2828
doc = Document(content="I love pizza!")
2929
@@ -48,7 +48,7 @@ def __init__(
4848
meta_fields_to_embed: Optional[List[str]] = None,
4949
embedding_separator: str = "\n",
5050
config: Optional[Dict[str, Any]] = None,
51-
):
51+
) -> None:
5252
"""
5353
Creates an GoogleGenAIDocumentEmbedder component.
5454
@@ -139,32 +139,40 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
139139

140140
return texts_to_embed
141141

142-
def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]:
142+
def _embed_batch(
143+
self, texts_to_embed: List[str], batch_size: int
144+
) -> Tuple[List[Optional[List[float]]], Dict[str, Any]]:
143145
"""
144146
Embed a list of texts in batches.
145147
"""
148+
resolved_config = types.EmbedContentConfig(**self._config) if self._config else None
146149

147150
all_embeddings = []
148151
meta: Dict[str, Any] = {}
149152
for batch in tqdm(
150153
batched(texts_to_embed, batch_size), disable=not self._progress_bar, desc="Calculating embeddings"
151154
):
152155
args: Dict[str, Any] = {"model": self._model, "contents": [b[1] for b in batch]}
153-
if self._config:
154-
args["config"] = types.EmbedContentConfig(**self._config) if self._config else None
156+
if resolved_config:
157+
args["config"] = resolved_config
155158

156159
response = self._client.models.embed_content(**args)
157160

158-
embeddings = [el.values for el in response.embeddings]
159-
all_embeddings.extend(embeddings)
161+
embeddings = []
162+
if response.embeddings:
163+
for el in response.embeddings:
164+
embeddings.append(el.values if el.values else None)
165+
all_embeddings.extend(embeddings)
166+
else:
167+
all_embeddings.extend([None] * len(batch))
160168

161169
if "model" not in meta:
162170
meta["model"] = self._model
163171

164172
return all_embeddings, meta
165173

166174
@component.output_types(documents=List[Document], meta=Dict[str, Any])
167-
def run(self, documents: List[Document]) -> Dict[str, Union[List[Document], Dict[str, Any]]]:
175+
def run(self, documents: List[Document]) -> Union[Dict[str, List[Document]], Dict[str, Any]]:
168176
"""
169177
Embeds a list of documents.
170178
@@ -185,6 +193,7 @@ def run(self, documents: List[Document]) -> Dict[str, Union[List[Document], Dict
185193

186194
texts_to_embed = self._prepare_texts_to_embed(documents=documents)
187195

196+
meta: Dict[str, Any]
188197
embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self._batch_size)
189198

190199
for doc, emb in zip(documents, embeddings):

integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __init__(
4444
prefix: str = "",
4545
suffix: str = "",
4646
config: Optional[Dict[str, Any]] = None,
47-
):
47+
) -> None:
4848
"""
4949
Creates an GoogleGenAITextEmbedder component.
5050
@@ -119,7 +119,8 @@ def _prepare_input(self, text: str) -> Dict[str, Any]:
119119
return kwargs
120120

121121
def _prepare_output(self, result: types.EmbedContentResponse) -> Dict[str, Any]:
122-
return {"embedding": result.embeddings[0].values, "meta": {"model": self._model_name}}
122+
embedding = result.embeddings[0].values if result.embeddings else []
123+
return {"embedding": embedding, "meta": {"model": self._model_name}}
123124

124125
@component.output_types(embedding=List[float], meta=Dict[str, Any])
125126
def run(self, text: str) -> Union[Dict[str, List[float]], Dict[str, Any]]:

integrations/google_genai/tests/test_document_embedder.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,35 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):
138138
},
139139
}
140140

141+
def test_from_dict(self, monkeypatch):
142+
data = {
143+
"type": (
144+
"haystack_integrations.components.embedders.google_genai.document_embedder.GoogleGenAIDocumentEmbedder"
145+
),
146+
"init_parameters": {
147+
"model": "text-embedding-004",
148+
"prefix": "",
149+
"suffix": "",
150+
"batch_size": 32,
151+
"progress_bar": True,
152+
"meta_fields_to_embed": [],
153+
"embedding_separator": "\n",
154+
"api_key": {"type": "env_var", "env_vars": ["GOOGLE_API_KEY"], "strict": True},
155+
"config": {"task_type": "SEMANTIC_SIMILARITY"},
156+
},
157+
}
158+
monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key")
159+
embedder = GoogleGenAIDocumentEmbedder.from_dict(data)
160+
assert embedder._api_key.resolve_value() == "fake-api-key"
161+
assert embedder._model == "text-embedding-004"
162+
assert embedder._prefix == ""
163+
assert embedder._suffix == ""
164+
assert embedder._batch_size == 32
165+
assert embedder._progress_bar is True
166+
assert embedder._meta_fields_to_embed == []
167+
assert embedder._embedding_separator == "\n"
168+
assert embedder._config == {"task_type": "SEMANTIC_SIMILARITY"}
169+
141170
def test_prepare_texts_to_embed_w_metadata(self):
142171
documents = [
143172
Document(id=f"{i}", content=f"document number {i}:\ncontent", meta={"meta_field": f"meta_value {i}"})
@@ -204,6 +233,6 @@ def test_run(self):
204233
assert len(doc.embedding) == 768
205234
assert all(isinstance(x, float) for x in doc.embedding)
206235

207-
assert "text" in result["meta"]["model"] and "004" in result["meta"]["model"], (
208-
"The model name does not contain 'text' and '004'"
209-
)
236+
assert result["documents"][0].meta == {"topic": "Cuisine"}
237+
assert result["documents"][1].meta == {"topic": "ML"}
238+
assert result["meta"] == {"model": model}

integrations/google_genai/tests/test_text_embedder.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,4 @@ def test_run(self):
160160
assert len(result["embedding"]) == 768
161161
assert all(isinstance(x, float) for x in result["embedding"])
162162

163-
assert "text" in result["meta"]["model"] and "004" in result["meta"]["model"], (
164-
"The model name does not contain 'text' and '004'"
165-
)
163+
assert result["meta"] == {"model": model}

0 commit comments

Comments
 (0)