44
55import os
66import random
7+ from unittest .mock import AsyncMock , MagicMock
78
9+ import numpy as np
810import pytest
911from haystack import Document
1012from haystack .utils .auth import Secret
1113
12- from haystack_integrations .components .embedders .google_genai import GoogleGenAIDocumentEmbedder
14+ from haystack_integrations .components .embedders .google_genai import GoogleGenAIDocumentEmbedder , GoogleGenAITextEmbedder
1315
1416
1517def mock_google_response (contents : list [str ], model : str = "gemini-embedding-001" , ** kwargs ) -> dict :
@@ -256,20 +258,66 @@ async def mock_embed_batch_async(texts_to_embed, batch_size):
256258 for doc_with_embedding in result ["documents" ]:
257259 assert doc_with_embedding .embedding == [0.1 , 0.2 , 0.3 ]
258260
261+ def test_embed_batch_passes_full_texts (self , monkeypatch ):
262+ monkeypatch .setenv ("GOOGLE_API_KEY" , "fake-api-key" )
263+ embedder = GoogleGenAIDocumentEmbedder (batch_size = 2 )
264+
265+ texts = ["first document text" , "second document text" , "third document text" ]
266+
267+ mock_embedding = MagicMock ()
268+ mock_embedding .values = [0.1 , 0.2 , 0.3 ]
269+
270+ mock_response = MagicMock ()
271+ mock_response .embeddings = [mock_embedding ]
272+
273+ embedder ._client = MagicMock ()
274+ embedder ._client .models .embed_content .return_value = mock_response
275+
276+ embedder ._embed_batch (texts , batch_size = 2 )
277+
278+ calls = embedder ._client .models .embed_content .call_args_list
279+ assert len (calls ) == 2
280+ assert calls [0 ].kwargs ["contents" ] == ["first document text" , "second document text" ]
281+ assert calls [1 ].kwargs ["contents" ] == ["third document text" ]
282+
283+ @pytest .mark .asyncio
284+ async def test_embed_batch_async_passes_full_texts (self , monkeypatch ):
285+ monkeypatch .setenv ("GOOGLE_API_KEY" , "fake-api-key" )
286+ embedder = GoogleGenAIDocumentEmbedder (batch_size = 2 )
287+
288+ texts = ["first document text" , "second document text" , "third document text" ]
289+
290+ mock_embedding = MagicMock ()
291+ mock_embedding .values = [0.1 , 0.2 , 0.3 ]
292+
293+ mock_response = MagicMock ()
294+ mock_response .embeddings = [mock_embedding ]
295+
296+ embedder ._client = MagicMock ()
297+ embedder ._client .aio .models .embed_content = AsyncMock (return_value = mock_response )
298+
299+ await embedder ._embed_batch_async (texts , batch_size = 2 )
300+
301+ calls = embedder ._client .aio .models .embed_content .call_args_list
302+ assert len (calls ) == 2
303+ assert calls [0 ].kwargs ["contents" ] == ["first document text" , "second document text" ]
304+ assert calls [1 ].kwargs ["contents" ] == ["third document text" ]
305+
259306 @pytest .mark .skipif (
260307 not os .environ .get ("GOOGLE_API_KEY" , None ),
261308 reason = "Export an env var called GOOGLE_API_KEY containing the Google API key to run this test." ,
262309 )
263310 @pytest .mark .integration
264311 def test_run (self ):
312+ model = "gemini-embedding-001"
313+
265314 docs = [
266- Document (content = "I love cheese" , meta = {"topic" : "Cuisine" }),
267- Document (content = "A transformer is a deep learning architecture" , meta = {"topic" : "ML" }),
315+ Document (content = "The capybara is the largest rodent in the world and lives near rivers in South America." ),
316+ Document (content = "Dogs are domesticated mammals known for their loyalty and bond with humans." ),
317+ Document (content = "The tiger is the largest big cat, recognized by its orange coat with black stripes." ),
268318 ]
269319
270- model = "gemini-embedding-001"
271-
272- embedder = GoogleGenAIDocumentEmbedder (model = model , meta_fields_to_embed = ["topic" ], embedding_separator = " | " )
320+ embedder = GoogleGenAIDocumentEmbedder (model = model , config = {"task_type" : "RETRIEVAL_DOCUMENT" })
273321
274322 result = embedder .run (documents = docs )
275323 documents_with_embeddings = result ["documents" ]
@@ -283,6 +331,18 @@ def test_run(self):
283331
284332 assert result ["meta" ]["model" ] == model
285333
334+ text_embedder = GoogleGenAITextEmbedder (model = model , config = {"task_type" : "RETRIEVAL_QUERY" })
335+ query_embedding = text_embedder .run ("capybara" )["embedding" ]
336+ query_vec = np .array (query_embedding )
337+
338+ similarities = []
339+ for doc in documents_with_embeddings :
340+ doc_vec = np .array (doc .embedding )
341+ cosine_sim = np .dot (query_vec , doc_vec ) / (np .linalg .norm (query_vec ) * np .linalg .norm (doc_vec ))
342+ similarities .append (cosine_sim )
343+
344+ assert similarities [0 ] == max (similarities )
345+
286346 @pytest .mark .asyncio
287347 @pytest .mark .skipif (
288348 not os .environ .get ("GOOGLE_API_KEY" , None ),
0 commit comments