diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index a174bfb930..578037dc76 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -7,10 +7,12 @@ import re import sys from concurrent.futures import ThreadPoolExecutor +from contextlib import suppress from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast from haystack import component, default_from_dict, default_to_dict, logging -from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, ToolCall, select_streaming_callback +from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, ToolCall +from haystack.dataclasses.streaming_chunk import select_streaming_callback from haystack.lazy_imports import LazyImport from haystack.tools import ( Tool, @@ -37,6 +39,7 @@ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast from haystack.utils.hf import ( # pylint: disable=ungrouped-imports + AsyncHFTokenStreamingHandler, HFTokenStreamingHandler, StopWordsCriteria, convert_message_to_hf_format, @@ -422,8 +425,9 @@ def run( replies = [o.get("generated_text", "") for o in output] # Remove stop words from replies if present - for stop_word in stop_words or []: - replies = [reply.replace(stop_word, "").rstrip() for reply in replies] + if stop_words: + for stop_word in stop_words: + replies = [reply.replace(stop_word, "").rstrip() for reply in replies] chat_messages = [ self.create_message( @@ -583,28 +587,43 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments # get the component name and type component_info = ComponentInfo.from_component(self) - generation_kwargs["streamer"] = HFTokenStreamingHandler( - tokenizer, streaming_callback, stop_words, component_info - ) - # Generate responses asynchronously - output = await 
asyncio.get_running_loop().run_in_executor( - self.executor, - lambda: self.pipeline(prepared_prompt, **generation_kwargs), # type: ignore # if self.executor was not passed it was initialized with max_workers=1 in init - ) + async_handler = AsyncHFTokenStreamingHandler(tokenizer, streaming_callback, stop_words, component_info) # type: ignore + generation_kwargs["streamer"] = async_handler - replies = [o.get("generated_text", "") for o in output] + # Start queue processing in the background + queue_processor = asyncio.create_task(async_handler.process_queue()) - # Remove stop words from replies if present - for stop_word in stop_words or []: - replies = [reply.replace(stop_word, "").rstrip() for reply in replies] + try: + # Generate responses asynchronously + output = await asyncio.get_running_loop().run_in_executor( + self.executor, + lambda: self.pipeline(prepared_prompt, **generation_kwargs), # type: ignore # if self.executor was not passed it was initialized with max_workers=1 in init + ) - chat_messages = [ - self.create_message(reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=False) - for r_index, reply in enumerate(replies) - ] + replies = [o.get("generated_text", "") for o in output] - return {"replies": chat_messages} + # Remove stop words from replies if present + if stop_words: + for stop_word in stop_words: + replies = [reply.replace(stop_word, "").rstrip() for reply in replies] + + chat_messages = [ + self.create_message( + reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=False + ) + for r_index, reply in enumerate(replies) + ] + + return {"replies": chat_messages} + + finally: + try: + await asyncio.wait_for(queue_processor, timeout=0.1) + except asyncio.TimeoutError: + queue_processor.cancel() + with suppress(asyncio.CancelledError): + await queue_processor async def _run_non_streaming_async( # pylint: disable=too-many-positional-arguments self, @@ -648,8 +667,9 @@ async def 
class AsyncHFTokenStreamingHandler(TextStreamer):
    """
    Async streaming handler for HuggingFaceLocalGenerator and HuggingFaceLocalChatGenerator.

    Note: This is a helper class for HuggingFaceLocalGenerator & HuggingFaceLocalChatGenerator enabling
    async streaming of generated text via Haystack Callable[StreamingChunk, Awaitable[None]] callbacks.

    The streamer callback (`on_finalized_text`) is synchronous and, in the chat generator, is invoked from
    the executor thread that runs the pipeline. Chunks are therefore handed over to the event loop through
    `loop.call_soon_threadsafe` and consumed by `process_queue`, which awaits the user's async callback.

    A handler instance is meant to serve a single generation: once the streamer reports `stream_end=True`,
    `process_queue` finishes after delivering the last chunk.

    Do not use this class directly.
    """

    # Private marker enqueued after the final text so that `process_queue` can return on its own
    # instead of having to be cancelled by the caller.
    _STREAM_END = object()

    def __init__(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        stream_handler: Callable[[StreamingChunk], Awaitable[None]],
        stop_words: Optional[List[str]] = None,
        component_info: Optional[ComponentInfo] = None,
    ):
        """
        Create the handler.

        :param tokenizer: Tokenizer forwarded to `TextStreamer`.
        :param stream_handler: Async callback awaited once per streamed chunk.
        :param stop_words: Words that must not be forwarded to the callback.
        :param component_info: Component metadata attached to every emitted chunk.
        """
        super().__init__(tokenizer=tokenizer, skip_prompt=True)  # type: ignore
        self.token_handler = stream_handler
        self.stop_words = stop_words or []
        self.component_info = component_info
        self._queue: "asyncio.Queue[Any]" = asyncio.Queue()
        # Counts `on_finalized_text` invocations so the first chunk can be flagged with `start=True`,
        # matching the fields the synchronous handler in this module puts on its chunks.
        # NOTE(review): the sync handler's counter increment is not visible here - confirm placement.
        self._call_counter = 0
        # Remember the loop that owns the queue. When the handler is built outside of a running loop
        # the loop is bound later, when `process_queue` starts.
        self._loop: Optional[asyncio.AbstractEventLoop]
        try:
            self._loop = asyncio.get_running_loop()
        except RuntimeError:
            self._loop = None

    def _enqueue(self, item: Any) -> None:
        """Put `item` on the queue, hopping onto the owning event loop when called from another thread."""
        loop = self._loop
        if loop is None or loop.is_closed():
            # No loop to hand over to: fall back to a direct put.
            self._queue.put_nowait(item)
            return

        try:
            on_loop_thread = asyncio.get_running_loop() is loop
        except RuntimeError:
            # No loop is running in the current thread, i.e. we are in a worker thread.
            on_loop_thread = False

        if on_loop_thread:
            self._queue.put_nowait(item)
        else:
            # asyncio.Queue is not thread-safe; schedule the put on the loop so waiting getters wake up.
            loop.call_soon_threadsafe(self._queue.put_nowait, item)

    def on_finalized_text(self, word: str, stream_end: bool = False) -> None:
        """
        Synchronous callback that puts chunks in a queue.

        :param word: Finalized text produced by the streamer.
        :param stream_end: Whether this is the last text of the generation; a newline is appended to it.
        """
        self._call_counter += 1
        word_to_send = word + "\n" if stream_end else word
        if word_to_send.strip() not in self.stop_words:
            self._enqueue(
                StreamingChunk(
                    content=word_to_send,
                    index=0,
                    start=self._call_counter == 1,
                    component_info=self.component_info,
                )
            )
        if stream_end:
            # Always signal the end, even when the final text was filtered out as a stop word.
            self._enqueue(self._STREAM_END)

    async def _deliver(self, item: Any) -> bool:
        """Forward one queue item to the callback; return False once the end-of-stream marker is seen."""
        try:
            if item is self._STREAM_END:
                return False
            await self.token_handler(item)
            return True
        finally:
            # Keep the queue's unfinished-task accounting correct even if the callback raises.
            self._queue.task_done()

    async def process_queue(self) -> None:
        """
        Process the queue of streaming chunks.

        Returns after the end-of-stream marker has been consumed. If the task is cancelled first, the
        chunks that are already queued are still delivered and the cancellation is then re-raised.
        Exceptions raised by the callback propagate to whoever awaits this coroutine.
        """
        # The consumer's loop is the one producers must hand chunks over to.
        self._loop = asyncio.get_running_loop()
        try:
            while await self._deliver(await self._queue.get()):
                pass
        except asyncio.CancelledError:
            # Flush what was produced before the cancellation so no streamed text is silently lost.
            while not self._queue.empty():
                if not await self._deliver(self._queue.get_nowait()):
                    break
            raise
AsyncHFTokenStreamingHandler import pytest from transformers import PreTrainedTokenizer from haystack.components.generators.chat import HuggingFaceLocalChatGenerator from haystack.dataclasses import ChatMessage, ChatRole, ToolCall -from haystack.dataclasses.streaming_chunk import StreamingChunk +from haystack.dataclasses.streaming_chunk import StreamingChunk, AsyncStreamingCallbackT from haystack.tools import Tool from haystack.utils import ComponentDevice from haystack.utils.auth import Secret @@ -486,6 +487,11 @@ def test_default_tool_parser(self, model_info_mock, tools): # Async tests + +class TestHuggingFaceLocalChatGeneratorAsync: + """Async tests for HuggingFaceLocalChatGenerator""" + + @pytest.mark.asyncio async def test_run_async(self, model_info_mock, mock_pipeline_with_tokenizer, chat_messages): """Test basic async functionality""" generator = HuggingFaceLocalChatGenerator(model="mocked-model") @@ -499,6 +505,7 @@ async def test_run_async(self, model_info_mock, mock_pipeline_with_tokenizer, ch assert chat_message.is_from(ChatRole.ASSISTANT) assert chat_message.text == "Berlin is cool" + @pytest.mark.asyncio async def test_run_async_with_tools(self, model_info_mock, mock_pipeline_with_tokenizer, tools): """Test async functionality with tools""" generator = HuggingFaceLocalChatGenerator(model="mocked-model", tools=tools) @@ -517,6 +524,7 @@ async def test_run_async_with_tools(self, model_info_mock, mock_pipeline_with_to assert tool_call.tool_name == "weather" assert tool_call.arguments == {"city": "Berlin"} + @pytest.mark.asyncio async def test_concurrent_async_requests(self, model_info_mock, mock_pipeline_with_tokenizer, chat_messages): """Test handling of multiple concurrent async requests""" generator = HuggingFaceLocalChatGenerator(model="mocked-model") @@ -531,6 +539,7 @@ async def test_concurrent_async_requests(self, model_info_mock, mock_pipeline_wi assert isinstance(result["replies"][0], ChatMessage) assert result["replies"][0].text == "Berlin is 
@pytest.mark.asyncio
async def test_run_async_with_streaming_callback(self, model_info_mock, mock_pipeline_with_tokenizer):
    """Chunks emitted through the pipeline's streamer must reach the async streaming callback."""
    received = []

    async def streaming_callback(chunk: StreamingChunk) -> None:
        received.append(chunk)

    def fake_pipeline(*args, **kwargs):
        # Emulate token streaming: two pieces of text, the second one closing the stream.
        streamer = kwargs.get("streamer")
        if streamer:
            for text, is_last in (("Berlin", False), (" is cool", True)):
                streamer.on_finalized_text(text, stream_end=is_last)
        return [{"generated_text": "Berlin is cool"}]

    mock_pipeline_with_tokenizer.side_effect = fake_pipeline

    generator = HuggingFaceLocalChatGenerator(model="test-model", streaming_callback=streaming_callback)
    generator.pipeline = mock_pipeline_with_tokenizer

    response = await generator.run_async([ChatMessage.from_user("Test message")])

    # Both streamed pieces arrived, in order; the last one carries the trailing newline.
    assert [chunk.content for chunk in received] == ["Berlin", " is cool\n"]

    # The aggregated reply is returned as usual.
    assert isinstance(response, dict)
    assert "replies" in response
    replies = response["replies"]
    assert len(replies) == 1
    only_reply = replies[0]
    assert isinstance(only_reply, ChatMessage)
    assert only_reply.text == "Berlin is cool"


@pytest.mark.integration
@pytest.mark.slow
@pytest.mark.flaky(reruns=3, reruns_delay=10)
@pytest.mark.asyncio
async def test_live_run_async_with_streaming(self, monkeypatch):
    """Test async streaming with a live model."""
    monkeypatch.delenv("HF_API_TOKEN", raising=False)

    received = []

    async def streaming_callback(chunk: StreamingChunk) -> None:
        received.append(chunk)

    llm = HuggingFaceLocalChatGenerator(
        model="Qwen/Qwen2.5-0.5B-Instruct",
        generation_kwargs={"max_new_tokens": 50},
        streaming_callback=streaming_callback,
    )
    llm.warm_up()

    prompt = ChatMessage.from_user("Please create a summary about the following topic: Capital of France")
    response = await llm.run_async(messages=[prompt])

    # Something was streamed and a reply came back.
    assert len(received) > 0
    assert "replies" in response
    first_reply = response["replies"][0]
    assert isinstance(first_reply, ChatMessage)
    assert first_reply.text is not None

    # The reply mentions the expected capital.
    assert "Paris" in first_reply.text

    # The streamed pieces carry real content, including the expected capital.
    streamed_text = "".join(chunk.content for chunk in received)
    assert len(streamed_text.strip()) > 0
    assert "Paris" in streamed_text