Skip to content

Commit 9ed8e7d

Browse files
feat: Google GenAI - add support for Vertex API (#2058)
* feat: Google GenAI - add support for Vertex API * fix * remove from sanitized_schema * introduce api parameter * improvements --------- Co-authored-by: David S. Batista <dsbatista@gmail.com>
1 parent 0645ae6 commit 9ed8e7d

10 files changed

Lines changed: 354 additions & 68 deletions

File tree

integrations/google_genai/pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ all = 'pytest {args:tests}'
7070
cov-retry = 'all --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x'
7171

7272
types = """mypy -p haystack_integrations.components.generators.google_genai \
73-
-p haystack_integrations.components.embedders.google_genai {args}"""
73+
-p haystack_integrations.components.embedders.google_genai \
74+
-p haystack_integrations.components.common.google_genai {args}"""
7475

7576
[tool.mypy]
7677
install_types = true
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from typing import Literal, Optional
6+
7+
from google.genai import Client
8+
from haystack import logging
9+
from haystack.utils import Secret
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
def _get_client(
15+
api_key: Secret,
16+
api: Literal["gemini", "vertex"],
17+
vertex_ai_project: Optional[str],
18+
vertex_ai_location: Optional[str],
19+
) -> Client:
20+
"""
21+
Internal utility function to get a Google GenAI client.
22+
23+
Supports:
24+
- Gemini Developer API (API Key Authentication)
25+
- Vertex AI (Application Default Credentials)
26+
- Vertex AI (API Key Authentication).
27+
28+
:param api_key: Google API key, defaults to the `GOOGLE_API_KEY` and `GEMINI_API_KEY` environment variables.
29+
:param api: Which API to use. Either "gemini" for the Gemini Developer API or "vertex" for Vertex AI.
30+
:param vertex_ai_project: Google Cloud project ID for Vertex AI. Required when using Vertex AI with
31+
Application Default Credentials.
32+
:param vertex_ai_location: Google Cloud location for Vertex AI (e.g., "us-central1", "europe-west1"). Required
33+
when using Vertex AI with Application Default Credentials.
34+
35+
:returns: A Google GenAI client.
36+
37+
:raises: ValueError if Gemini API is used without providing an API key or if Vertex AI is used without providing
38+
an API key or both vertex_ai_project and vertex_ai_location.
39+
"""
40+
41+
if api not in ["gemini", "vertex"]:
42+
msg = f"Invalid API: {api}. Must be either 'gemini' or 'vertex'."
43+
raise ValueError(msg)
44+
45+
resolved_api_key = api_key.resolve_value()
46+
47+
if api == "vertex":
48+
if not resolved_api_key and not (vertex_ai_project and vertex_ai_location):
49+
msg = (
50+
"To use Vertex AI, you must provide both vertex_ai_project and vertex_ai_location or export "
51+
"the GOOGLE_API_KEY or GEMINI_API_KEY environment variable."
52+
)
53+
raise ValueError(msg)
54+
55+
if vertex_ai_project and vertex_ai_location:
56+
logger.info("Using vertex_ai_project and vertex_ai_location for authentication.")
57+
return Client(vertexai=True, project=vertex_ai_project, location=vertex_ai_location)
58+
59+
logger.info(
60+
"No vertex_ai_project or vertex_ai_location provided for Vertex AI. Using the API key for authentication."
61+
)
62+
return Client(vertexai=True, api_key=resolved_api_key)
63+
64+
# Gemini API
65+
if not resolved_api_key:
66+
msg = "To use Gemini API, you must export the GOOGLE_API_KEY or GEMINI_API_KEY environment variable."
67+
raise ValueError(msg)
68+
69+
return Client(api_key=resolved_api_key)

integrations/google_genai/src/haystack_integrations/components/common/py.typed

Whitespace-only changes.

integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from typing import Any, Dict, List, Optional, Tuple, Union
5+
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
66

7-
from google import genai
87
from google.genai import types
98
from haystack import Document, component, default_from_dict, default_to_dict, logging
109
from haystack.utils import Secret, deserialize_secrets_inplace
1110
from more_itertools import batched
1211
from tqdm import tqdm
1312

13+
from haystack_integrations.components.common.google_genai.utils import _get_client
14+
1415
logger = logging.getLogger(__name__)
1516

1617

@@ -19,6 +20,39 @@ class GoogleGenAIDocumentEmbedder:
1920
"""
2021
Computes document embeddings using Google AI models.
2122
23+
### Authentication examples
24+
25+
**1. Gemini Developer API (API Key Authentication)**
26+
```python
27+
from haystack_integrations.components.embedders.google_genai import GoogleGenAIDocumentEmbedder
28+
29+
# export the environment variable (GOOGLE_API_KEY or GEMINI_API_KEY)
30+
document_embedder = GoogleGenAIDocumentEmbedder(model="text-embedding-004")
31+
32+
**2. Vertex AI (Application Default Credentials)**
33+
```python
34+
from haystack_integrations.components.embedders.google_genai import GoogleGenAIDocumentEmbedder
35+
36+
# Using Application Default Credentials (requires gcloud auth setup)
37+
document_embedder = GoogleGenAIDocumentEmbedder(
38+
api="vertex",
39+
vertex_ai_project="my-project",
40+
vertex_ai_location="us-central1",
41+
model="text-embedding-004"
42+
)
43+
```
44+
45+
**3. Vertex AI (API Key Authentication)**
46+
```python
47+
from haystack_integrations.components.embedders.google_genai import GoogleGenAIDocumentEmbedder
48+
49+
# export the environment variable (GOOGLE_API_KEY or GEMINI_API_KEY)
50+
document_embedder = GoogleGenAIDocumentEmbedder(
51+
api="vertex",
52+
model="text-embedding-004"
53+
)
54+
```
55+
2256
### Usage example
2357
2458
```python
@@ -39,7 +73,10 @@ class GoogleGenAIDocumentEmbedder:
3973
def __init__(
4074
self,
4175
*,
42-
api_key: Secret = Secret.from_env_var(["GOOGLE_API_KEY", "GEMINI_API_KEY"]),
76+
api_key: Secret = Secret.from_env_var(["GOOGLE_API_KEY", "GEMINI_API_KEY"], strict=False),
77+
api: Literal["gemini", "vertex"] = "gemini",
78+
vertex_ai_project: Optional[str] = None,
79+
vertex_ai_location: Optional[str] = None,
4380
model: str = "text-embedding-004",
4481
prefix: str = "",
4582
suffix: str = "",
@@ -52,10 +89,15 @@ def __init__(
5289
"""
5390
Creates an GoogleGenAIDocumentEmbedder component.
5491
55-
:param api_key:
56-
The Google API key.
57-
You can set it with the environment variable `GOOGLE_API_KEY` or `GEMINI_API_KEY`, or pass it via
58-
this parameter during initialization.
92+
:param api_key: Google API key, defaults to the `GOOGLE_API_KEY` and `GEMINI_API_KEY` environment variables.
93+
Not needed if using Vertex AI with Application Default Credentials.
94+
Go to https://aistudio.google.com/app/apikey for a Gemini API key.
95+
Go to https://cloud.google.com/vertex-ai/generative-ai/docs/start/api-keys for a Vertex AI API key.
96+
:param api: Which API to use. Either "gemini" for the Gemini Developer API or "vertex" for Vertex AI.
97+
:param vertex_ai_project: Google Cloud project ID for Vertex AI. Required when using Vertex AI with
98+
Application Default Credentials.
99+
:param vertex_ai_location: Google Cloud location for Vertex AI (e.g., "us-central1", "europe-west1").
100+
Required when using Vertex AI with Application Default Credentials.
59101
:param model:
60102
The name of the model to use for calculating embeddings.
61103
The default model is `text-embedding-ada-002`.
@@ -77,16 +119,25 @@ def __init__(
77119
For more information, see the [Google AI Task types](https://ai.google.dev/gemini-api/docs/embeddings#task-types).
78120
"""
79121
self._api_key = api_key
122+
self._api = api
123+
self._vertex_ai_project = vertex_ai_project
124+
self._vertex_ai_location = vertex_ai_location
80125
self._model = model
81126
self._prefix = prefix
82127
self._suffix = suffix
83128
self._batch_size = batch_size
84129
self._progress_bar = progress_bar
85130
self._meta_fields_to_embed = meta_fields_to_embed or []
86131
self._embedding_separator = embedding_separator
87-
self._client = genai.Client(api_key=api_key.resolve_value())
88132
self._config = config if config is not None else {"task_type": "SEMANTIC_SIMILARITY"}
89133

134+
self._client = _get_client(
135+
api_key=api_key,
136+
api=api,
137+
vertex_ai_project=vertex_ai_project,
138+
vertex_ai_location=vertex_ai_location,
139+
)
140+
90141
def to_dict(self) -> Dict[str, Any]:
91142
"""
92143
Serializes the component to a dictionary.
@@ -104,6 +155,9 @@ def to_dict(self) -> Dict[str, Any]:
104155
meta_fields_to_embed=self._meta_fields_to_embed,
105156
embedding_separator=self._embedding_separator,
106157
api_key=self._api_key.to_dict(),
158+
api=self._api,
159+
vertex_ai_project=self._vertex_ai_project,
160+
vertex_ai_location=self._vertex_ai_location,
107161
config=self._config,
108162
)
109163

integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from typing import Any, Dict, List, Optional, Union
5+
from typing import Any, Dict, List, Literal, Optional, Union
66

7-
from google import genai
87
from google.genai import types
98
from haystack import component, default_from_dict, default_to_dict, logging
109
from haystack.utils import Secret, deserialize_secrets_inplace
1110

11+
from haystack_integrations.components.common.google_genai.utils import _get_client
12+
1213
logger = logging.getLogger(__name__)
1314

1415

@@ -19,6 +20,40 @@ class GoogleGenAITextEmbedder:
1920
2021
You can use it to embed user query and send it to an embedding Retriever.
2122
23+
### Authentication examples
24+
25+
**1. Gemini Developer API (API Key Authentication)**
26+
```python
27+
from haystack_integrations.components.embedders.google_genai import GoogleGenAITextEmbedder
28+
29+
# export the environment variable (GOOGLE_API_KEY or GEMINI_API_KEY)
30+
text_embedder = GoogleGenAITextEmbedder(model="text-embedding-004")
31+
32+
**2. Vertex AI (Application Default Credentials)**
33+
```python
34+
from haystack_integrations.components.embedders.google_genai import GoogleGenAITextEmbedder
35+
36+
# Using Application Default Credentials (requires gcloud auth setup)
37+
text_embedder = GoogleGenAITextEmbedder(
38+
api="vertex",
39+
vertex_ai_project="my-project",
40+
vertex_ai_location="us-central1",
41+
model="text-embedding-004"
42+
)
43+
```
44+
45+
**3. Vertex AI (API Key Authentication)**
46+
```python
47+
from haystack_integrations.components.embedders.google_genai import GoogleGenAITextEmbedder
48+
49+
# export the environment variable (GOOGLE_API_KEY or GEMINI_API_KEY)
50+
text_embedder = GoogleGenAITextEmbedder(
51+
api="vertex",
52+
model="text-embedding-004"
53+
)
54+
```
55+
56+
2257
### Usage example
2358
2459
```python
@@ -39,7 +74,10 @@ class GoogleGenAITextEmbedder:
3974
def __init__(
4075
self,
4176
*,
42-
api_key: Secret = Secret.from_env_var(["GOOGLE_API_KEY", "GEMINI_API_KEY"]),
77+
api_key: Secret = Secret.from_env_var(["GOOGLE_API_KEY", "GEMINI_API_KEY"], strict=False),
78+
api: Literal["gemini", "vertex"] = "gemini",
79+
vertex_ai_project: Optional[str] = None,
80+
vertex_ai_location: Optional[str] = None,
4381
model: str = "text-embedding-004",
4482
prefix: str = "",
4583
suffix: str = "",
@@ -48,10 +86,15 @@ def __init__(
4886
"""
4987
Creates an GoogleGenAITextEmbedder component.
5088
51-
:param api_key:
52-
The Google API key.
53-
You can set it with the environment variable `GOOGLE_API_KEY` or `GEMINI_API_KEY`, or pass it via
54-
this parameter during initialization.
89+
:param api_key: Google API key, defaults to the `GOOGLE_API_KEY` and `GEMINI_API_KEY` environment variables.
90+
Not needed if using Vertex AI with Application Default Credentials.
91+
Go to https://aistudio.google.com/app/apikey for a Gemini API key.
92+
Go to https://cloud.google.com/vertex-ai/generative-ai/docs/start/api-keys for a Vertex AI API key.
93+
:param api: Which API to use. Either "gemini" for the Gemini Developer API or "vertex" for Vertex AI.
94+
:param vertex_ai_project: Google Cloud project ID for Vertex AI. Required when using Vertex AI with
95+
Application Default Credentials.
96+
:param vertex_ai_location: Google Cloud location for Vertex AI (e.g., "us-central1", "europe-west1").
97+
Required when using Vertex AI with Application Default Credentials.
5598
:param model:
5699
The name of the model to use for calculating embeddings.
57100
The default model is `text-embedding-004`.
@@ -66,11 +109,19 @@ def __init__(
66109
"""
67110

68111
self._api_key = api_key
112+
self._api = api
113+
self._vertex_ai_project = vertex_ai_project
114+
self._vertex_ai_location = vertex_ai_location
69115
self._model_name = model
70116
self._prefix = prefix
71117
self._suffix = suffix
72118
self._config = config if config is not None else {"task_type": "SEMANTIC_SIMILARITY"}
73-
self._client = genai.Client(api_key=api_key.resolve_value())
119+
self._client = _get_client(
120+
api_key=api_key,
121+
api=api,
122+
vertex_ai_project=vertex_ai_project,
123+
vertex_ai_location=vertex_ai_location,
124+
)
74125

75126
def to_dict(self) -> Dict[str, Any]:
76127
"""
@@ -83,6 +134,9 @@ def to_dict(self) -> Dict[str, Any]:
83134
self,
84135
model=self._model_name,
85136
api_key=self._api_key.to_dict(),
137+
api=self._api,
138+
vertex_ai_project=self._vertex_ai_project,
139+
vertex_ai_location=self._vertex_ai_location,
86140
prefix=self._prefix,
87141
suffix=self._suffix,
88142
config=self._config,

0 commit comments

Comments
 (0)