Skip to content

Commit c8f0ca4

Browse files
feat: add run_async for VertexAIGeminiChatGenerator (#1574)
* Added async calls for VertexAIGeminiChatGenerator
* Linter and typing fixes
* Replace Iterable with AsyncIterable type
* Pass tool_config to send_message_async

---------

Co-authored-by: Julian Risch <julianrisch@gmx.de>
1 parent f4c6b15 commit c8f0ca4

3 files changed

Lines changed: 280 additions & 12 deletions

File tree

integrations/google_vertex/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ installer = "uv"
4646
dependencies = [
4747
"coverage[toml]>=6.5",
4848
"pytest",
49+
"pytest-asyncio",
4950
"pytest-rerunfailures",
5051
"haystack-pydoc-tools",
5152
]

integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py

Lines changed: 122 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import json
2-
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
2+
from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Union
33

44
from haystack import logging
55
from haystack.core.component import component
66
from haystack.core.serialization import default_from_dict, default_to_dict
7-
from haystack.dataclasses import StreamingChunk
7+
from haystack.dataclasses import AsyncStreamingCallbackT, StreamingCallbackT, StreamingChunk, select_streaming_callback
88
from haystack.dataclasses.chat_message import ChatMessage, ChatRole, ToolCall
99
from haystack.tools import Tool, _check_duplicate_tool_names
1010
from haystack.utils import deserialize_callable, serialize_callable
@@ -150,7 +150,7 @@ def __init__(
150150
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
151151
tools: Optional[List[Tool]] = None,
152152
tool_config: Optional[ToolConfig] = None,
153-
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
153+
streaming_callback: Optional[StreamingCallbackT] = None,
154154
):
155155
"""
156156
`VertexAIGeminiChatGenerator` enables chat completion using Google Gemini models.
@@ -300,7 +300,7 @@ def _convert_to_vertex_tools(tools: List[Tool]) -> List[VertexTool]:
300300
def run(
301301
self,
302302
messages: List[ChatMessage],
303-
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
303+
streaming_callback: Optional[StreamingCallbackT] = None,
304304
*,
305305
tools: Optional[List[Tool]] = None,
306306
):
@@ -355,6 +355,69 @@ def run(
355355

356356
return {"replies": replies}
357357

358+
@component.output_types(replies=List[ChatMessage])
async def run_async(
    self,
    messages: List[ChatMessage],
    streaming_callback: Optional[StreamingCallbackT] = None,
    *,
    tools: Optional[List[Tool]] = None,
):
    """
    Async version of the run method. Generates text based on the provided messages.

    :param messages:
        A list of `ChatMessage` instances, representing the input messages. A leading system message is
        used as the model's system instruction and is not sent as part of the conversation.
    :param streaming_callback:
        A callback that is awaited for every chunk received from the stream. Because the callback is
        selected with `requires_async=True` and awaited, it must be an async callable. If set, it takes
        precedence over the callback configured at initialization (resolved by `select_streaming_callback`).
    :param tools:
        A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
        during component initialization.
    :returns:
        A dictionary containing the following key:
        - `replies`: A list containing the generated responses as `ChatMessage` instances.
    :raises ValueError:
        If there is no non-system message to send, or if streaming is requested while the generation
        config asks for more than one candidate.
    """
    streaming_callback = select_streaming_callback(
        self._streaming_callback, streaming_callback, requires_async=True
    )

    tools = tools or self._tools
    _check_duplicate_tool_names(tools)
    google_tools = self._convert_to_vertex_tools(tools) if tools else None

    # Split off an optional leading system message without touching the model yet, so that an
    # invalid request fails before any shared state is modified.
    has_system_message = bool(messages) and messages[0].is_from(ChatRole.SYSTEM)
    conversation = messages[1:] if has_system_message else messages
    if not conversation:
        msg = "At least one non-system message is required to generate a reply."
        raise ValueError(msg)

    candidate_count = 1
    if self._generation_config:
        config_dict = self._generation_config_to_dict(self._generation_config)
        candidate_count = config_dict.get("candidate_count", 1)

    if streaming_callback and candidate_count > 1:
        msg = "Streaming is not supported with multiple candidates. Set candidate_count to 1."
        raise ValueError(msg)

    if has_system_message:
        # NOTE(review): this sets a private attribute on the shared model object, so the instruction
        # persists for later calls; concurrent `run_async` calls with different system messages
        # could presumably interfere with each other - confirm.
        self._model._system_instruction = Part.from_text(messages[0].text)

    google_messages = [_convert_chatmessage_to_google_content(m) for m in conversation]

    # Everything but the last message becomes chat history; the last message is the one being sent.
    session = self._model.start_chat(history=google_messages[:-1])

    res = await session.send_message_async(
        content=google_messages[-1],
        generation_config=self._generation_config,
        safety_settings=self._safety_settings,
        stream=streaming_callback is not None,
        tools=google_tools,
        tool_config=self._tool_config,
    )

    replies = (
        await self._stream_response_and_convert_to_messages_async(res, streaming_callback)
        if streaming_callback
        else self._convert_response_to_messages(res)
    )

    return {"replies": replies}
420+
358421
@staticmethod
359422
def _convert_response_to_messages(response_body: GenerationResponse) -> List[ChatMessage]:
360423
"""
@@ -395,7 +458,7 @@ def _convert_response_to_messages(response_body: GenerationResponse) -> List[Cha
395458
return replies
396459

397460
def _stream_response_and_convert_to_messages(
398-
self, stream: Iterable[GenerationResponse], streaming_callback: Callable[[StreamingChunk], None]
461+
self, stream: Iterable[GenerationResponse], streaming_callback: StreamingCallbackT
399462
) -> List[ChatMessage]:
400463
"""
401464
Streams the Google Vertex AI response and converts it to a list of `ChatMessage` instances.
@@ -446,3 +509,57 @@ def _stream_response_and_convert_to_messages(
446509
meta["usage"] = openai_usage
447510

448511
return [ChatMessage.from_assistant(text=text or None, meta=meta, tool_calls=tool_calls)]
512+
513+
@staticmethod
async def _stream_response_and_convert_to_messages_async(
    stream: AsyncIterable[GenerationResponse], streaming_callback: AsyncStreamingCallbackT
) -> List[ChatMessage]:
    """
    Streams the Google Vertex AI response and converts it to a list of `ChatMessage` instances.

    Every chunk is forwarded to `streaming_callback` as a `StreamingChunk`, while text parts and
    function calls are accumulated across all chunks into a single assistant message.

    :param stream: The streaming response from the Google AI request.
    :param streaming_callback: The async handler for the streaming response; awaited once per chunk.
    :returns: List with a single assistant `ChatMessage`. Its `meta` is the dict form of the last
        chunk, with `usage_metadata` replaced by an OpenAI-compatible `usage` entry.
    """
    text_fragments: List[str] = []
    tool_calls: List[ToolCall] = []
    last_chunk: Dict[str, Any] = {}

    async for chunk in stream:
        last_chunk = chunk.to_dict()
        streamed_fragments: List[str] = []

        # Only one candidate is supported with streaming.
        # NOTE(review): a chunk whose dict form lacks candidates/content/parts (presumably e.g. a
        # trailing usage-only chunk) is forwarded with empty content instead of raising.
        candidates = last_chunk.get("candidates") or [{}]
        parts = (candidates[0].get("content") or {}).get("parts") or []

        for part in parts:
            if new_text := part.get("text"):
                streamed_fragments.append(new_text)
                text_fragments.append(new_text)
            elif new_function_call := part.get("function_call"):
                streamed_fragments.append(json.dumps(dict(new_function_call)))
                tool_calls.append(
                    ToolCall(
                        tool_name=new_function_call["name"],
                        # a function call without arguments may come without an "args" key
                        arguments=new_function_call.get("args") or {},
                    )
                )

        await streaming_callback(StreamingChunk(content="".join(streamed_fragments), meta=last_chunk))

    # Keep the last chunk's metadata, but work on a copy: the original dict was already handed to
    # the callback inside the last StreamingChunk and must not change after the fact.
    meta = dict(last_chunk)

    # format the usage metadata to be compatible with OpenAI
    usage_metadata = meta.pop("usage_metadata", {})
    meta["usage"] = {
        "prompt_tokens": usage_metadata.get("prompt_token_count", 0),
        "completion_tokens": usage_metadata.get("candidates_token_count", 0),
        "total_tokens": usage_metadata.get("total_token_count", 0),
    }

    return [ChatMessage.from_assistant(text="".join(text_fragments) or None, meta=meta, tool_calls=tool_calls)]

integrations/google_vertex/tests/chat/test_gemini.py

Lines changed: 157 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
from typing import Annotated, Literal
3-
from unittest.mock import MagicMock, Mock, patch
3+
from unittest.mock import AsyncMock, MagicMock, Mock, patch
44

55
import pytest
66
from haystack import Pipeline
@@ -224,8 +224,7 @@ def test_from_dict(self, _mock_vertexai_init, _mock_generative_model):
224224
gemini = VertexAIGeminiChatGenerator.from_dict(
225225
{
226226
"type": (
227-
"haystack_integrations.components.generators.google_vertex.chat.gemini."
228-
"VertexAIGeminiChatGenerator"
227+
"haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator"
229228
),
230229
"init_parameters": {
231230
"project_id": None,
@@ -253,8 +252,7 @@ def test_from_dict_with_param(self, _mock_vertexai_init, _mock_generative_model)
253252
gemini = VertexAIGeminiChatGenerator.from_dict(
254253
{
255254
"type": (
256-
"haystack_integrations.components.generators.google_vertex.chat.gemini."
257-
"VertexAIGeminiChatGenerator"
255+
"haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator"
258256
),
259257
"init_parameters": {
260258
"project_id": "TestID123",
@@ -513,6 +511,159 @@ def streaming_callback(chunk: StreamingChunk) -> None:
513511
assert reply.tool_calls[1].arguments == {"city": "Munich"}
514512
assert reply.meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
515513

514+
@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
@pytest.mark.asyncio
async def test_run_async(self, mock_generative_model):
    """`run_async` awaits `send_message_async` once and returns the text as an assistant reply."""
    # Canned, non-streaming model response with a single text part.
    text_part = Part.from_text("This is a generated response.")
    candidate = MagicMock(content=Content(parts=[text_part], role="model"))
    canned_response = MagicMock(spec=GenerationResponse, candidates=[candidate])

    # One double plays both the model and the chat session returned by `start_chat`.
    chat_double = Mock()
    chat_double.start_chat.return_value = chat_double
    chat_double.send_message_async = AsyncMock(return_value=canned_response)
    mock_generative_model.return_value = chat_double

    generator = VertexAIGeminiChatGenerator()
    result = await generator.run_async(
        messages=[
            ChatMessage.from_system("You are a helpful assistant"),
            ChatMessage.from_user("What's the capital of France?"),
        ]
    )

    chat_double.send_message_async.assert_called_once()
    assert "replies" in result
    first_reply = result["replies"][0]
    assert first_reply.role == ChatRole.ASSISTANT
    assert first_reply.text == "This is a generated response."
539+
540+
@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
@pytest.mark.asyncio
async def test_run_with_tools_async(self, mock_generative_model, tools):
    """`run_async` forwards tools to the model and converts a function call into a `ToolCall`."""
    # Canned, non-streaming model response carrying a single function call.
    function_call_part = Part.from_dict(
        {"function_call": {"name": "get_current_weather", "args": {"city": "Paris", "unit": "Celsius"}}}
    )
    candidate = MagicMock(content=Content(parts=[function_call_part], role="model"))
    canned_response = MagicMock(spec=GenerationResponse, candidates=[candidate])

    # One double plays both the model and the chat session returned by `start_chat`.
    chat_double = Mock()
    chat_double.start_chat.return_value = chat_double
    chat_double.send_message_async = AsyncMock(return_value=canned_response)
    mock_generative_model.return_value = chat_double

    generator = VertexAIGeminiChatGenerator(tools=tools)
    result = await generator.run_async(messages=[ChatMessage.from_user("What's the weather in Paris?")])

    chat_double.send_message_async.assert_called_once()
    sent_kwargs = chat_double.send_message_async.call_args.kwargs
    assert "tools" in sent_kwargs

    assert "replies" in result
    first_reply = result["replies"][0]
    assert first_reply.role == ChatRole.ASSISTANT
    assert not first_reply.texts
    assert not first_reply.text
    assert len(first_reply.tool_calls) == 1

    only_call = first_reply.tool_calls[0]
    assert only_call.tool_name == "get_current_weather"
    assert only_call.arguments == {"city": "Paris", "unit": "Celsius"}
579+
580+
@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
@pytest.mark.asyncio
async def test_run_with_muliple_tools_and_streaming_async(self, mock_generative_model, tools):
    """
    Check that async streaming collects function calls spread over several chunks.

    Note: the scenario is artificial; in practice multiple function calls have always been observed
    within a single streaming chunk.
    """

    def population(city: Annotated[str, "the city for which to get the population, e.g. 'Munich'"] = "Munich"):
        """A simple function to get the population for a location."""
        return f"Population of {city}: 1,000,000"

    def weather_chunk_dict():
        # first streamed chunk: weather function call, no usage metadata
        return {
            "candidates": [
                {
                    "content": {
                        "parts": [
                            {
                                "function_call": {
                                    "name": "get_current_weather",
                                    "args": {"city": "Munich", "unit": "Farenheit"},
                                }
                            }
                        ]
                    }
                }
            ]
        }

    def population_chunk_dict():
        # second (last) streamed chunk: population function call plus usage metadata
        return {
            "candidates": [
                {"content": {"parts": [{"function_call": {"name": "population", "args": {"city": "Munich"}}}]}}
            ],
            "usage_metadata": {"prompt_token_count": 10, "candidates_token_count": 5, "total_token_count": 15},
        }

    async def stream_chunks():
        yield MagicMock(spec=GenerationResponse, to_dict=weather_chunk_dict)
        yield MagicMock(spec=GenerationResponse, to_dict=population_chunk_dict)

    # One double plays both the model and the chat session returned by `start_chat`.
    chat_double = Mock()
    chat_double.start_chat.return_value = chat_double
    chat_double.send_message_async = AsyncMock(return_value=stream_chunks())
    mock_generative_model.return_value = chat_double

    seen_chunks = []

    async def record_chunk(chunk: StreamingChunk) -> None:
        seen_chunks.append(chunk)

    generator = VertexAIGeminiChatGenerator(
        tools=[tools[0], create_tool_from_function(population)],
        streaming_callback=record_chunk,
    )
    result = await generator.run_async(
        messages=[
            ChatMessage.from_user("What's the weather in Munich (in Farenheit) and how many people live there?"),
        ]
    )

    # Each streamed chunk reaches the callback with the JSON-encoded function call as content.
    assert len(seen_chunks) == 2
    assert json.loads(seen_chunks[0].content) == {
        "name": "get_current_weather",
        "args": {"city": "Munich", "unit": "Farenheit"},
    }
    assert json.loads(seen_chunks[1].content) == {"name": "population", "args": {"city": "Munich"}}

    assert "replies" in result
    first_reply = result["replies"][0]
    assert first_reply.role == ChatRole.ASSISTANT
    assert not first_reply.texts
    assert not first_reply.text
    assert len(first_reply.tool_calls) == 2

    weather_call, population_call = first_reply.tool_calls
    assert weather_call.tool_name == "get_current_weather"
    assert weather_call.arguments == {"city": "Munich", "unit": "Farenheit"}
    assert population_call.tool_name == "population"
    assert population_call.arguments == {"city": "Munich"}
    assert first_reply.meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
666+
516667
def test_serde_in_pipeline(self):
517668
tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)
518669

@@ -538,8 +689,7 @@ def test_serde_in_pipeline(self):
538689
"components": {
539690
"generator": {
540691
"type": (
541-
"haystack_integrations.components.generators.google_vertex.chat.gemini."
542-
"VertexAIGeminiChatGenerator"
692+
"haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator"
543693
),
544694
"init_parameters": {
545695
"project_id": "TestID123",

0 commit comments

Comments
 (0)