|
1 | 1 | import json |
2 | | -from typing import Any, Callable, Dict, Iterable, List, Optional, Union |
| 2 | +from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Union |
3 | 3 |
|
4 | 4 | from haystack import logging |
5 | 5 | from haystack.core.component import component |
6 | 6 | from haystack.core.serialization import default_from_dict, default_to_dict |
7 | | -from haystack.dataclasses import StreamingChunk |
| 7 | +from haystack.dataclasses import AsyncStreamingCallbackT, StreamingCallbackT, StreamingChunk, select_streaming_callback |
8 | 8 | from haystack.dataclasses.chat_message import ChatMessage, ChatRole, ToolCall |
9 | 9 | from haystack.tools import Tool, _check_duplicate_tool_names |
10 | 10 | from haystack.utils import deserialize_callable, serialize_callable |
@@ -150,7 +150,7 @@ def __init__( |
150 | 150 | safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, |
151 | 151 | tools: Optional[List[Tool]] = None, |
152 | 152 | tool_config: Optional[ToolConfig] = None, |
153 | | - streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, |
| 153 | + streaming_callback: Optional[StreamingCallbackT] = None, |
154 | 154 | ): |
155 | 155 | """ |
156 | 156 | `VertexAIGeminiChatGenerator` enables chat completion using Google Gemini models. |
@@ -300,7 +300,7 @@ def _convert_to_vertex_tools(tools: List[Tool]) -> List[VertexTool]: |
300 | 300 | def run( |
301 | 301 | self, |
302 | 302 | messages: List[ChatMessage], |
303 | | - streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, |
| 303 | + streaming_callback: Optional[StreamingCallbackT] = None, |
304 | 304 | *, |
305 | 305 | tools: Optional[List[Tool]] = None, |
306 | 306 | ): |
@@ -355,6 +355,69 @@ def run( |
355 | 355 |
|
356 | 356 | return {"replies": replies} |
357 | 357 |
|
@component.output_types(replies=List[ChatMessage])
async def run_async(
    self,
    messages: List[ChatMessage],
    streaming_callback: Optional[StreamingCallbackT] = None,
    *,
    tools: Optional[List[Tool]] = None,
):
    """
    Asynchronous counterpart of `run`: sends the conversation to the model and returns its replies.

    :param messages:
        The conversation as a list of `ChatMessage` instances. If the first message comes from the
        system role, its text is installed on the model as the system instruction and the message
        itself is not sent as chat content. The last remaining message is the one sent to the model;
        the ones before it are passed as chat history.
    :param streaming_callback:
        Callback invoked for every chunk received from the stream. It is resolved together with the
        callback given at initialization through `select_streaming_callback(..., requires_async=True)`.
    :param tools:
        Tools the model may prepare calls for. When provided (and non-empty), they take precedence over
        the tools configured at initialization.
    :returns:
        A dictionary with a single key:
        - `replies`: the generated responses as a list of `ChatMessage` instances.
    :raises ValueError:
        If a streaming callback is in effect and the generation config requests more than one candidate.
    """
    # Choose between the init-time and the call-time callback; this is the coroutine entry point,
    # hence `requires_async=True`.
    callback = select_streaming_callback(self._streaming_callback, streaming_callback, requires_async=True)

    # Call-time tools win over the ones stored on the component.
    active_tools = tools or self._tools
    _check_duplicate_tool_names(active_tools)
    vertex_tools = self._convert_to_vertex_tools(active_tools) if active_tools else None

    # A leading system message becomes the model's system instruction instead of chat content.
    # NOTE: this assigns to the model object held by the component.
    conversation = messages
    if conversation[0].is_from(ChatRole.SYSTEM):
        self._model._system_instruction = Part.from_text(conversation[0].text)
        conversation = conversation[1:]

    google_messages = [_convert_chatmessage_to_google_content(message) for message in conversation]

    # Everything except the final message seeds the chat session as history.
    chat = self._model.start_chat(history=google_messages[:-1])

    if self._generation_config:
        candidate_count = self._generation_config_to_dict(self._generation_config).get("candidate_count", 1)
    else:
        candidate_count = 1

    # The streaming path only reads the first candidate of each chunk, so reject multi-candidate configs.
    if callback and candidate_count > 1:
        msg = "Streaming is not supported with multiple candidates. Set candidate_count to 1."
        raise ValueError(msg)

    response = await chat.send_message_async(
        content=google_messages[-1],
        generation_config=self._generation_config,
        safety_settings=self._safety_settings,
        stream=callback is not None,
        tools=vertex_tools,
        tool_config=self._tool_config,
    )

    if callback:
        # Streaming: consume the async stream, forwarding each chunk to the callback.
        replies = await self._stream_response_and_convert_to_messages_async(response, callback)
    else:
        replies = self._convert_response_to_messages(response)

    return {"replies": replies}
| 420 | + |
358 | 421 | @staticmethod |
359 | 422 | def _convert_response_to_messages(response_body: GenerationResponse) -> List[ChatMessage]: |
360 | 423 | """ |
@@ -395,7 +458,7 @@ def _convert_response_to_messages(response_body: GenerationResponse) -> List[Cha |
395 | 458 | return replies |
396 | 459 |
|
397 | 460 | def _stream_response_and_convert_to_messages( |
398 | | - self, stream: Iterable[GenerationResponse], streaming_callback: Callable[[StreamingChunk], None] |
| 461 | + self, stream: Iterable[GenerationResponse], streaming_callback: StreamingCallbackT |
399 | 462 | ) -> List[ChatMessage]: |
400 | 463 | """ |
401 | 464 | Streams the Google Vertex AI response and converts it to a list of `ChatMessage` instances. |
@@ -446,3 +509,57 @@ def _stream_response_and_convert_to_messages( |
446 | 509 | meta["usage"] = openai_usage |
447 | 510 |
|
448 | 511 | return [ChatMessage.from_assistant(text=text or None, meta=meta, tool_calls=tool_calls)] |
| 512 | + |
@staticmethod
async def _stream_response_and_convert_to_messages_async(
    stream: AsyncIterable[GenerationResponse], streaming_callback: AsyncStreamingCallbackT
) -> List[ChatMessage]:
    """
    Consumes an async Google Vertex AI response stream and folds it into a single assistant `ChatMessage`.

    For every chunk, the text parts and the JSON-serialized function calls of the first candidate are
    concatenated and handed to `streaming_callback` (awaited) as a `StreamingChunk`. Text and tool calls
    are accumulated across chunks for the final message.

    :param stream: The streaming response from the Google AI request.
    :param streaming_callback: The async handler awaited once per received chunk.
    :returns:
        A one-element list holding the assistant `ChatMessage`. Its metadata is the dictionary form of the
        last chunk (an empty dict if the stream yielded nothing), with `usage_metadata` replaced by an
        OpenAI-style `usage` entry.
    """
    text_fragments = []
    collected_tool_calls = []
    last_chunk = {}

    async for response_chunk in stream:
        last_chunk = response_chunk.to_dict()
        streamed_pieces = []

        # Only one candidate is supported with streaming, so only the first one is read.
        first_candidate = last_chunk["candidates"][0]

        for part in first_candidate["content"]["parts"]:
            part_text = part.get("text")
            if part_text:
                streamed_pieces.append(part_text)
                text_fragments.append(part_text)
                continue

            function_call = part.get("function_call")
            if function_call:
                # Function calls are streamed to the callback as their JSON representation.
                streamed_pieces.append(json.dumps(dict(function_call)))
                collected_tool_calls.append(
                    ToolCall(tool_name=function_call["name"], arguments=function_call["args"])
                )

        await streaming_callback(StreamingChunk(content="".join(streamed_pieces), meta=last_chunk))

    # The metadata of the final chunk becomes the metadata of the reply.
    meta = last_chunk

    # Re-key the usage counters so they follow the OpenAI naming.
    raw_usage = meta.pop("usage_metadata", {})
    meta["usage"] = {
        "prompt_tokens": raw_usage.get("prompt_token_count", 0),
        "completion_tokens": raw_usage.get("candidates_token_count", 0),
        "total_tokens": raw_usage.get("total_token_count", 0),
    }

    full_text = "".join(text_fragments)
    return [ChatMessage.from_assistant(text=full_text or None, meta=meta, tool_calls=collected_tool_calls)]
0 commit comments