Skip to content

Commit 30fdc8e

Browse files
committed
Merge branch 'main' into hatch-commands-exp
2 parents 1fb6960 + 7c5e03f commit 30fdc8e

10 files changed

Lines changed: 580 additions & 7 deletions

File tree

.github/workflows/nvidia.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,13 @@ jobs:
5757
if: matrix.python-version == '3.9' && runner.os == 'Linux'
5858
run: hatch run fmt-check && hatch run lint:typing
5959

60-
- name: Run tests
61-
run: hatch run test:cov-retry
62-
6360
- name: Generate docs
6461
if: matrix.python-version == '3.9' && runner.os == 'Linux'
6562
run: hatch run docs
6663

64+
- name: Run tests
65+
run: hatch run test:cov-retry
66+
6767
- name: Run unit tests with lowest direct dependencies
6868
run: |
6969
hatch run uv pip compile pyproject.toml --resolution lowest-direct --output-file requirements_lowest_direct.txt

integrations/google_genai/CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# Changelog
22

3+
## [integrations/google_genai-v1.0.1] - 2025-06-05
4+
5+
### 🌀 Miscellaneous
6+
7+
- Style: Update to linting to allow function calls in default arguments (#1899)
8+
- Add examples, set safety_settings
9+
- Add print in examples
10+
311
## [integrations/google_genai-v1.0.0] - 2025-06-02
412

513
### 🚀 Features
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# To run this example, you will need to
2+
# 1) set `GOOGLE_API_KEY` environment variable
3+
# 2) install the google_genai_haystack integration: pip install google-genai-haystack
4+
# Note: if you change the model, update the model-specific inference parameters.
5+
6+
7+
from haystack.dataclasses import ChatMessage
8+
9+
from haystack_integrations.components.generators.google_genai import GoogleGenAIChatGenerator
10+
11+
generator = GoogleGenAIChatGenerator(
12+
model="gemini-2.0-flash",
13+
# model-specific inference parameters
14+
generation_kwargs={
15+
"temperature": 0.7,
16+
},
17+
)
18+
19+
system_prompt = """
20+
You are a helpful assistant that helps users learn more about Google Cloud services.
21+
Your audience is engineers with a decent technical background.
22+
Be very concise and specific in your answers, keeping them short.
23+
You may use technical terms, jargon, and abbreviations that are common among practitioners.
24+
"""
25+
26+
messages = [
27+
ChatMessage.from_system(system_prompt),
28+
ChatMessage.from_user("Which service should I use to train custom Machine Learning models?"),
29+
]
30+
31+
results = generator.run(messages)
32+
print(results["replies"][0].text)

integrations/google_genai/pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ ignore = [
135135
"ARG001",
136136
"ARG002",
137137
"ARG005",
138+
# Allow function call argument defaults e.g. `Secret.from_env_var`
139+
"B008",
138140
]
139141
unfixable = [
140142
# Don't touch unused imports
@@ -150,6 +152,8 @@ ban-relative-imports = "parents"
150152
[tool.ruff.lint.per-file-ignores]
151153
# Tests can use magic values, assertions, and relative imports
152154
"tests/**/*" = ["PLR2004", "S101", "TID252"]
155+
# Examples can use print statements
156+
"examples/**/*" = ["T201"]
153157

154158
[tool.coverage.run]
155159
source = ["haystack_integrations"]
@@ -180,4 +184,4 @@ markers = [
180184
"generators: generators tests",
181185
]
182186
log_cli = true
183-
asyncio_mode = "auto"
187+
asyncio_mode = "auto"

integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def weather_function(city: str):
230230
def __init__(
231231
self,
232232
*,
233-
api_key: Secret = Secret.from_env_var("GOOGLE_API_KEY"), # noqa: B008
233+
api_key: Secret = Secret.from_env_var("GOOGLE_API_KEY"),
234234
model: str = "gemini-2.0-flash",
235235
generation_kwargs: Optional[Dict[str, Any]] = None,
236236
safety_settings: Optional[List[Dict[str, Any]]] = None,
@@ -509,6 +509,10 @@ def run(
509509
if system_instruction:
510510
config_params["system_instruction"] = system_instruction
511511

512+
# Add safety settings if provided
513+
if safety_settings:
514+
config_params["safety_settings"] = safety_settings
515+
512516
# Add tools if provided
513517
if tools:
514518
config_params["tools"] = _convert_tools_to_google_genai_format(tools)
@@ -593,6 +597,10 @@ async def run_async(
593597
if system_instruction:
594598
config_params["system_instruction"] = system_instruction
595599

600+
# Add safety settings if provided
601+
if safety_settings:
602+
config_params["safety_settings"] = safety_settings
603+
596604
# Add tools if provided
597605
if tools:
598606
config_params["tools"] = _convert_tools_to_google_genai_format(tools)

integrations/nvidia/CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# Changelog
22

3+
## [integrations/nvidia-v0.2.0] - 2025-06-05
4+
5+
### 🚀 Features
6+
7+
- Add NvidiaChatGenerator based on OpenAIChatGenerator (#1776)
8+
9+
310
## [integrations/nvidia-v0.1.8] - 2025-05-28
411

512
### 🌀 Miscellaneous

integrations/nvidia/pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ classifiers = [
2323
"Programming Language :: Python :: Implementation :: CPython",
2424
"Programming Language :: Python :: Implementation :: PyPy",
2525
]
26-
dependencies = ["haystack-ai", "requests>=2.25.0", "tqdm>=4.21.0"]
26+
dependencies = ["haystack-ai>=2.13.0", "requests>=2.25.0", "tqdm>=4.21.0"]
2727

2828
[project.urls]
2929
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/nvidia#readme"
@@ -168,6 +168,7 @@ module = [
168168
"pytest.*",
169169
"numpy.*",
170170
"requests_mock.*",
171+
"openai.*",
171172
"pydantic.*",
172173
]
173174
ignore_missing_imports = true

integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from .chat.chat_generator import NvidiaChatGenerator
56
from .generator import NvidiaGenerator
67

7-
__all__ = ["NvidiaGenerator"]
8+
__all__ = ["NvidiaChatGenerator", "NvidiaGenerator"]
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import os
6+
from typing import Any, Dict, List, Optional, Union
7+
8+
from haystack import component, default_to_dict, logging
9+
from haystack.components.generators.chat import OpenAIChatGenerator
10+
from haystack.dataclasses import StreamingCallbackT
11+
from haystack.tools import Tool, Toolset, serialize_tools_or_toolset
12+
from haystack.utils import serialize_callable
13+
from haystack.utils.auth import Secret
14+
15+
from haystack_integrations.utils.nvidia import DEFAULT_API_URL
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
@component
21+
class NvidiaChatGenerator(OpenAIChatGenerator):
22+
"""
23+
Enables text generation using NVIDIA generative models.
24+
For supported models, see [NVIDIA Docs](https://build.nvidia.com/models).
25+
26+
Users can pass any text generation parameters valid for the NVIDIA Chat Completion API
27+
directly to this component via the `generation_kwargs` parameter in `__init__` or the `generation_kwargs`
28+
parameter in `run` method.
29+
30+
This component uses the ChatMessage format for structuring both input and output,
31+
ensuring coherent and contextually relevant responses in chat-based text generation scenarios.
32+
Details on the ChatMessage format can be found in the
33+
[Haystack docs](https://docs.haystack.deepset.ai/docs/data-classes#chatmessage)
34+
35+
For more details on the parameters supported by the NVIDIA API, refer to the
36+
[NVIDIA Docs](https://build.nvidia.com/models).
37+
38+
Usage example:
39+
```python
40+
from haystack_integrations.components.generators.nvidia import NvidiaChatGenerator
41+
from haystack.dataclasses import ChatMessage
42+
43+
messages = [ChatMessage.from_user("What's Natural Language Processing?")]
44+
45+
client = NvidiaChatGenerator()
46+
response = client.run(messages)
47+
print(response)
48+
```
49+
"""
50+
51+
def __init__(
52+
self,
53+
*,
54+
api_key: Secret = Secret.from_env_var("NVIDIA_API_KEY"),
55+
model: str = "meta/llama-3.1-8b-instruct",
56+
streaming_callback: Optional[StreamingCallbackT] = None,
57+
api_base_url: Optional[str] = os.getenv("NVIDIA_API_URL", DEFAULT_API_URL),
58+
generation_kwargs: Optional[Dict[str, Any]] = None,
59+
tools: Optional[Union[List[Tool], Toolset]] = None,
60+
timeout: Optional[float] = None,
61+
max_retries: Optional[int] = None,
62+
http_client_kwargs: Optional[Dict[str, Any]] = None,
63+
):
64+
"""
65+
Creates an instance of NvidiaChatGenerator.
66+
67+
:param api_key:
68+
The NVIDIA API key.
69+
:param model:
70+
The name of the NVIDIA chat completion model to use.
71+
:param streaming_callback:
72+
A callback function that is called when a new token is received from the stream.
73+
The callback function accepts StreamingChunk as an argument.
74+
:param api_base_url:
75+
The NVIDIA API Base url.
76+
:param generation_kwargs:
77+
Other parameters to use for the model. These parameters are all sent directly to
78+
the NVIDIA API endpoint. See [NVIDIA API docs](https://docs.nvcf.nvidia.com/ai/generative-models/)
79+
for more details.
80+
Some of the supported parameters:
81+
- `max_tokens`: The maximum number of tokens the output text can have.
82+
- `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
83+
Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer.
84+
- `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
85+
considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
86+
comprising the top 10% probability mass are considered.
87+
- `stream`: Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent
88+
events as they become available, with the stream terminated by a data: [DONE] message.
89+
:param tools:
90+
A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
91+
list of `Tool` objects or a `Toolset` instance.
92+
:param timeout:
93+
The timeout for the NVIDIA API call.
94+
:param max_retries:
95+
Maximum number of retries to contact NVIDIA after an internal error.
96+
If not set, it defaults to either the `NVIDIA_MAX_RETRIES` environment variable, or set to 5.
97+
:param http_client_kwargs:
98+
A dictionary of keyword arguments to configure a custom `httpx.Client`or `httpx.AsyncClient`.
99+
For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
100+
"""
101+
super(NvidiaChatGenerator, self).__init__( # noqa: UP008
102+
api_key=api_key,
103+
model=model,
104+
streaming_callback=streaming_callback,
105+
api_base_url=api_base_url,
106+
generation_kwargs=generation_kwargs,
107+
tools=tools,
108+
timeout=timeout,
109+
max_retries=max_retries,
110+
http_client_kwargs=http_client_kwargs,
111+
)
112+
113+
def to_dict(self) -> Dict[str, Any]:
114+
"""
115+
Serialize this component to a dictionary.
116+
117+
:returns:
118+
The serialized component as a dictionary.
119+
"""
120+
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
121+
122+
return default_to_dict(
123+
self,
124+
model=self.model,
125+
streaming_callback=callback_name,
126+
api_base_url=self.api_base_url,
127+
generation_kwargs=self.generation_kwargs,
128+
api_key=self.api_key.to_dict(),
129+
tools=serialize_tools_or_toolset(self.tools),
130+
timeout=self.timeout,
131+
max_retries=self.max_retries,
132+
http_client_kwargs=self.http_client_kwargs,
133+
)

0 commit comments

Comments
 (0)