From c45df5cb3060162b95d02d3807015cecb8a55b63 Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Sun, 18 May 2025 22:27:42 +0530 Subject: [PATCH 01/14] feat: Add async streaming support in hugging face generator --- .../generators/chat/hugging_face_local.py | 12 ++-- haystack/utils/hf.py | 28 ++++++++- ...c-streaming-handling-463f3a6cbd6b6f8c.yaml | 4 ++ .../chat/test_hugging_face_local.py | 63 ++++++++++++++++++- 4 files changed, 100 insertions(+), 7 deletions(-) create mode 100644 releasenotes/notes/hugging-face-async-streaming-handling-463f3a6cbd6b6f8c.yaml diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index 85269c5128..ad5a9aa41a 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -11,6 +11,7 @@ from haystack import component, default_from_dict, default_to_dict, logging from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback +from haystack.dataclasses.streaming_chunk import AsyncStreamingCallbackT, StreamingCallbackT from haystack.lazy_imports import LazyImport from haystack.tools import ( Tool, @@ -36,6 +37,7 @@ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast from haystack.utils.hf import ( # pylint: disable=ungrouped-imports + AsyncHFTokenStreamingHandler, HFTokenStreamingHandler, StopWordsCriteria, convert_message_to_hf_format, @@ -130,7 +132,7 @@ def __init__( # pylint: disable=too-many-positional-arguments generation_kwargs: Optional[Dict[str, Any]] = None, huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None, stop_words: Optional[List[str]] = None, - streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + streaming_callback: Optional[StreamingCallbackT] = None, tools: Optional[Union[List[Tool], Toolset]] = None, tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None, async_executor: Optional[ThreadPoolExecutor] = None, @@ -330,7 +332,7 @@ def run( self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None, - streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + streaming_callback: Optional[StreamingCallbackT] = None, tools: Optional[Union[List[Tool], Toolset]] = None, ): """ @@ -492,7 +494,7 @@ async def run_async( self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None, - streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + streaming_callback: Optional[StreamingCallbackT] = None, tools: Optional[Union[List[Tool], Toolset]] = None, ): """ @@ -546,7 +548,7 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"], generation_kwargs: Dict[str, Any], stop_words: Optional[List[str]], - streaming_callback: Callable[[StreamingChunk], None], + streaming_callback: StreamingCallbackT, ): """ Handles async streaming generation of responses. @@ -566,7 +568,7 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments ) # Set up streaming handler - generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words) + generation_kwargs["streamer"] = AsyncHFTokenStreamingHandler(tokenizer, streaming_callback, stop_words) # Generate responses asynchronously output = await asyncio.get_running_loop().run_in_executor( diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py index fcb51a8a84..850aab514f 100644 --- a/haystack/utils/hf.py +++ b/haystack/utils/hf.py @@ -4,7 +4,7 @@ import copy from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Awaitable, Callable, Dict, List, Optional, Union from haystack import logging from haystack.dataclasses import ChatMessage, StreamingChunk @@ -369,3 +369,29 @@ def on_finalized_text(self, word: str, stream_end: bool = False): word_to_send = word + "\n" if stream_end else word if word_to_send.strip() not in self.stop_words: self.token_handler(StreamingChunk(content=word_to_send)) + + class AsyncHFTokenStreamingHandler(TextStreamer): + """ + Async streaming handler for HuggingFaceLocalGenerator and HuggingFaceLocalChatGenerator. + + Note: This is a helper class for HuggingFaceLocalGenerator & HuggingFaceLocalChatGenerator enabling + async streaming of generated text via Haystack Callable[StreamingChunk, Awaitable[None]] callbacks. + + Do not use this class directly. + """ + + def __init__( + self, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + stream_handler: Callable[[StreamingChunk], Awaitable[None]], + stop_words: Optional[List[str]] = None, + ): + super().__init__(tokenizer=tokenizer, skip_prompt=True) # type: ignore + self.token_handler = stream_handler + self.stop_words = stop_words or [] + + async def on_finalized_text(self, word: str, stream_end: bool = False): + """Async callback function for handling the generated text.""" + word_to_send = word + "\n" if stream_end else word + if word_to_send.strip() not in self.stop_words: + await self.token_handler(StreamingChunk(content=word_to_send)) diff --git a/releasenotes/notes/hugging-face-async-streaming-handling-463f3a6cbd6b6f8c.yaml b/releasenotes/notes/hugging-face-async-streaming-handling-463f3a6cbd6b6f8c.yaml new file mode 100644 index 0000000000..760ae97bd9 --- /dev/null +++ b/releasenotes/notes/hugging-face-async-streaming-handling-463f3a6cbd6b6f8c.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Add `AsyncHFTokenStreamingHandler` for async streaming support in `HuggingFaceLocalChatGenerator` diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py index 828a16789b..06b498b2f0 100644 --- a/test/components/generators/chat/test_hugging_face_local.py +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -12,7 +12,7 @@ from haystack.components.generators.chat import HuggingFaceLocalChatGenerator from haystack.dataclasses import ChatMessage, ChatRole, ToolCall -from haystack.dataclasses.streaming_chunk import StreamingChunk +from haystack.dataclasses.streaming_chunk import StreamingChunk, AsyncStreamingCallbackT from haystack.tools import Tool from haystack.utils import ComponentDevice from haystack.utils.auth import Secret @@ -485,6 +485,11 @@ def test_default_tool_parser(self, model_info_mock, tools): # Async tests + +class TestHuggingFaceLocalChatGeneratorAsync: + """Async tests for HuggingFaceLocalChatGenerator""" + + @pytest.mark.asyncio async def test_run_async(self, model_info_mock, mock_pipeline_tokenizer, chat_messages): """Test basic async functionality""" generator = HuggingFaceLocalChatGenerator(model="mocked-model") @@ -498,6 +503,7 @@ async def test_run_async(self, model_info_mock, mock_pipeline_tokenizer, chat_me assert chat_message.is_from(ChatRole.ASSISTANT) assert chat_message.text == "Berlin is cool" + @pytest.mark.asyncio async def test_run_async_with_tools(self, model_info_mock, mock_pipeline_tokenizer, tools): """Test async functionality with tools""" generator = HuggingFaceLocalChatGenerator(model="mocked-model", tools=tools) @@ -516,6 +522,7 @@ async def test_run_async_with_tools(self, model_info_mock, mock_pipeline_tokeniz assert tool_call.tool_name == "weather" assert tool_call.arguments == {"city": "Berlin"} + @pytest.mark.asyncio async def test_concurrent_async_requests(self, model_info_mock, mock_pipeline_tokenizer, chat_messages): """Test handling of multiple concurrent async requests""" generator = HuggingFaceLocalChatGenerator(model="mocked-model") @@ -530,6 +537,7 @@ async def test_concurrent_async_requests(self, model_info_mock, mock_pipeline_to assert isinstance(result["replies"][0], ChatMessage) assert result["replies"][0].text == "Berlin is cool" + @pytest.mark.asyncio async def test_async_error_handling(self, model_info_mock, mock_pipeline_tokenizer): """Test error handling in async context""" generator = HuggingFaceLocalChatGenerator(model="mocked-model") @@ -608,3 +616,56 @@ def test_to_dict_with_toolset(self, model_info_mock, mock_pipeline_tokenizer, to }, } assert data["init_parameters"]["tools"] == expected_tools_data + + @pytest.mark.asyncio + async def test_run_async_with_streaming_callback(self, model_info_mock): + """Test that async streaming works correctly with HuggingFaceLocalChatGenerator.""" + streaming_chunks = [] + streaming_complete = asyncio.Event() + loop = asyncio.get_running_loop() + + async def streaming_callback(chunk: StreamingChunk) -> None: + streaming_chunks.append(chunk) + if chunk.content.endswith("\n"): + streaming_complete.set() + + # Create a mock pipeline that simulates streaming + mock_pipeline = Mock() + mock_tokenizer = Mock() + mock_tokenizer.apply_chat_template.return_value = "Test prompt" + mock_tokenizer.encode.return_value = [1, 2, 3] # Return a list with a length + mock_pipeline.tokenizer = mock_tokenizer + + # Mock the pipeline to return a stream of responses + def mock_generate(*args, **kwargs): + streamer = kwargs.get("streamer") + if streamer: + # Schedule the streaming callbacks in the main event loop + loop.call_soon_threadsafe( + lambda: asyncio.create_task(streamer.on_finalized_text("Hello", stream_end=False)) + ) + loop.call_soon_threadsafe( + lambda: asyncio.create_task(streamer.on_finalized_text(" world", stream_end=True)) + ) + return [{"generated_text": "Hello world"}] + + mock_pipeline.side_effect = mock_generate + + generator = HuggingFaceLocalChatGenerator(model="test-model", streaming_callback=streaming_callback) + generator.pipeline = mock_pipeline + + messages = [ChatMessage.from_user("Test message")] + response = await generator.run_async(messages) + + # Wait for streaming to complete + await streaming_complete.wait() + + assert len(streaming_chunks) == 2 + assert streaming_chunks[0].content == "Hello" + assert streaming_chunks[1].content == " world\n" + + assert isinstance(response, dict) + assert "replies" in response + assert len(response["replies"]) == 1 + assert isinstance(response["replies"][0], ChatMessage) + assert response["replies"][0].text == "Hello world" From 77f14784df8b0a44be965b2eb7dc4033641559f4 Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Thu, 22 May 2025 19:56:21 +0530 Subject: [PATCH 02/14] enforce streamingcallback to be async --- haystack/components/generators/chat/hugging_face_local.py | 1 + 1 file changed, 1 insertion(+) diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index ad5a9aa41a..97112e19a3 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -568,6 +568,7 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments ) # Set up streaming handler + assert asyncio.iscoroutinefunction(streaming_callback), "Streaming callback must be asynchronous" generation_kwargs["streamer"] = AsyncHFTokenStreamingHandler(tokenizer, streaming_callback, stop_words) # Generate responses asynchronously From 5cdb48355cde774eda56d5894882e65d33f088e9 Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Thu, 22 May 2025 20:38:00 +0530 Subject: [PATCH 03/14] refactor --- haystack/components/generators/chat/hugging_face_local.py | 4 ++-- haystack/utils/hf.py | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index 97112e19a3..5e0f8624f1 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -10,8 +10,8 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast from haystack import component, default_from_dict, default_to_dict, logging -from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback -from haystack.dataclasses.streaming_chunk import AsyncStreamingCallbackT, StreamingCallbackT +from haystack.dataclasses import ChatMessage, ToolCall, select_streaming_callback +from haystack.dataclasses.streaming_chunk import StreamingCallbackT from haystack.lazy_imports import LazyImport from haystack.tools import ( Tool, diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py index 850aab514f..86f479fd9a 100644 --- a/haystack/utils/hf.py +++ b/haystack/utils/hf.py @@ -390,7 +390,11 @@ def __init__( self.token_handler = stream_handler self.stop_words = stop_words or [] - async def on_finalized_text(self, word: str, stream_end: bool = False): + def on_finalized_text(self, word: str, stream_end: bool = False): + """Synchronous callback that returns the async handler coroutine.""" + return self.on_finalized_text_async(word, stream_end) + + async def on_finalized_text_async(self, word: str, stream_end: bool = False): """Async callback function for handling the generated text.""" word_to_send = word + "\n" if stream_end else word if word_to_send.strip() not in self.stop_words: From 9faa706201d48378ad966165febade73ba2bfc5a Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Sat, 24 May 2025 00:21:20 +0530 Subject: [PATCH 04/14] fix: schedule and await async task in Event Loop --- haystack/utils/hf.py | 9 +++------ .../generators/chat/test_hugging_face_local.py | 8 ++------ 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py index 86f479fd9a..541c70cae1 100644 --- a/haystack/utils/hf.py +++ b/haystack/utils/hf.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import asyncio import copy from enum import Enum from typing import Any, Awaitable, Callable, Dict, List, Optional, Union @@ -391,11 +392,7 @@ def __init__( self.stop_words = stop_words or [] def on_finalized_text(self, word: str, stream_end: bool = False): - """Synchronous callback that returns the async handler coroutine.""" - return self.on_finalized_text_async(word, stream_end) - - async def on_finalized_text_async(self, word: str, stream_end: bool = False): - """Async callback function for handling the generated text.""" + """Synchronous callback that schedules the async handler.""" word_to_send = word + "\n" if stream_end else word if word_to_send.strip() not in self.stop_words: - await self.token_handler(StreamingChunk(content=word_to_send)) + asyncio.create_task(self.token_handler(StreamingChunk(content=word_to_send))) diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py index 06b498b2f0..6dd0e172da 100644 --- a/test/components/generators/chat/test_hugging_face_local.py +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -641,12 +641,8 @@ def mock_generate(*args, **kwargs): streamer = kwargs.get("streamer") if streamer: # Schedule the streaming callbacks in the main event loop - loop.call_soon_threadsafe( - lambda: asyncio.create_task(streamer.on_finalized_text("Hello", stream_end=False)) - ) - loop.call_soon_threadsafe( - lambda: asyncio.create_task(streamer.on_finalized_text(" world", stream_end=True)) - ) + loop.call_soon_threadsafe(lambda: streamer.on_finalized_text("Hello", stream_end=False)) + loop.call_soon_threadsafe(lambda: streamer.on_finalized_text(" world", stream_end=True)) return [{"generated_text": "Hello world"}] mock_pipeline.side_effect = mock_generate From 9e2fbcd3c49dc47a164f3575f5f258e0a809952f Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Sat, 24 May 2025 00:38:50 +0530 Subject: [PATCH 05/14] unenforce typecheck --- haystack/utils/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py index 541c70cae1..67c4df6815 100644 --- a/haystack/utils/hf.py +++ b/haystack/utils/hf.py @@ -395,4 +395,4 @@ def on_finalized_text(self, word: str, stream_end: bool = False): """Synchronous callback that schedules the async handler.""" word_to_send = word + "\n" if stream_end else word if word_to_send.strip() not in self.stop_words: - asyncio.create_task(self.token_handler(StreamingChunk(content=word_to_send))) + asyncio.create_task(self.token_handler(StreamingChunk(content=word_to_send))) # type: ignore # token_handler returns Awaitable[None] which is compatible with create_task at runtime From 60c907d78b0d179058aed2f790f91f6cae0d942e Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Thu, 29 May 2025 23:43:39 +0530 Subject: [PATCH 06/14] add integration test --- .../chat/test_hugging_face_local.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py index 6dd0e172da..884bd6ae8f 100644 --- a/test/components/generators/chat/test_hugging_face_local.py +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -665,3 +665,37 @@ def mock_generate(*args, **kwargs): assert len(response["replies"]) == 1 assert isinstance(response["replies"][0], ChatMessage) assert response["replies"][0].text == "Hello world" + + @pytest.mark.integration + @pytest.mark.slow + @pytest.mark.flaky(reruns=3, reruns_delay=10) + @pytest.mark.asyncio + async def test_live_run_async_with_streaming(self, monkeypatch): + """Test async streaming with a live model.""" + monkeypatch.delenv("HF_API_TOKEN", raising=False) + + streaming_chunks = [] + streaming_complete = asyncio.Event() + + async def streaming_callback(chunk: StreamingChunk) -> None: + streaming_chunks.append(chunk) + if chunk.content.endswith("\n"): + streaming_complete.set() + + messages = [ChatMessage.from_user("Please create a summary about the following topic: Climate change")] + + llm = HuggingFaceLocalChatGenerator( + model="Qwen/Qwen2.5-0.5B-Instruct", + generation_kwargs={"max_new_tokens": 50}, + streaming_callback=streaming_callback, + ) + llm.warm_up() + + response = await llm.run_async(messages=messages) + + await streaming_complete.wait() + + assert len(streaming_chunks) > 0 + assert "replies" in response + assert isinstance(response["replies"][0], ChatMessage) + assert "climate change" in response["replies"][0].text.lower() From 84a911a2e6f90584194c4a007e0e1297a4d6ef27 Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Fri, 30 May 2025 00:19:00 +0530 Subject: [PATCH 07/14] After merge fixes: - fix breaking tests - added component_info to AsyncHFTokenStreamingHandler --- haystack/utils/hf.py | 8 ++++++-- .../components/generators/chat/test_hugging_face_local.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py index 9548123be6..d9e13b0378 100644 --- a/haystack/utils/hf.py +++ b/haystack/utils/hf.py @@ -388,13 +388,17 @@ def __init__( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], stream_handler: Callable[[StreamingChunk], Awaitable[None]], stop_words: Optional[List[str]] = None, + component_info: Optional[ComponentInfo] = None, ): super().__init__(tokenizer=tokenizer, skip_prompt=True) # type: ignore self.token_handler = stream_handler self.stop_words = stop_words or [] + self.component_info = component_info - def on_finalized_text(self, word: str, stream_end: bool = False): + def on_finalized_text(self, word: str, stream_end: bool = False) -> None: """Synchronous callback that schedules the async handler.""" word_to_send = word + "\n" if stream_end else word if word_to_send.strip() not in self.stop_words: - asyncio.create_task(self.token_handler(StreamingChunk(content=word_to_send))) # type: ignore # token_handler returns Awaitable[None] which is compatible with create_task at runtime + asyncio.create_task( + self.token_handler(StreamingChunk(content=word_to_send, component_info=self.component_info)) + ) # type: ignore[arg-type] # token_handler returns Awaitable[None] which is compatible with create_task at runtime diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py index 7c1fdff584..f5e7d389d9 100644 --- a/test/components/generators/chat/test_hugging_face_local.py +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -491,7 +491,7 @@ class TestHuggingFaceLocalChatGeneratorAsync: """Async tests for HuggingFaceLocalChatGenerator""" @pytest.mark.asyncio - async def test_run_async(self, model_info_mock, mock_pipeline_tokenizer, chat_messages): + async def test_run_async(self, model_info_mock, mock_pipeline_with_tokenizer, chat_messages): """Test basic async functionality""" generator = HuggingFaceLocalChatGenerator(model="mocked-model") generator.pipeline = mock_pipeline_with_tokenizer From 16486941dbc444fbcc5a1667c155429989569cca Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Fri, 30 May 2025 00:50:48 +0530 Subject: [PATCH 08/14] fix integration test --- haystack/utils/hf.py | 17 +++++++++---- .../chat/test_hugging_face_local.py | 24 ++++++++++++------- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py index d9e13b0378..1815a023e0 100644 --- a/haystack/utils/hf.py +++ b/haystack/utils/hf.py @@ -394,11 +394,20 @@ def __init__( self.token_handler = stream_handler self.stop_words = stop_words or [] self.component_info = component_info + self._queue: asyncio.Queue[StreamingChunk] = asyncio.Queue() def on_finalized_text(self, word: str, stream_end: bool = False) -> None: - """Synchronous callback that schedules the async handler.""" + """Synchronous callback that puts chunks in a queue.""" word_to_send = word + "\n" if stream_end else word if word_to_send.strip() not in self.stop_words: - asyncio.create_task( - self.token_handler(StreamingChunk(content=word_to_send, component_info=self.component_info)) - ) # type: ignore[arg-type] # token_handler returns Awaitable[None] which is compatible with create_task at runtime + self._queue.put_nowait(StreamingChunk(content=word_to_send, component_info=self.component_info)) + + async def process_queue(self) -> None: + """Process the queue of streaming chunks.""" + while True: + try: + chunk = await self._queue.get() + await self.token_handler(chunk) + self._queue.task_done() + except asyncio.CancelledError: + break diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py index f5e7d389d9..7d83c7acbb 100644 --- a/test/components/generators/chat/test_hugging_face_local.py +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -692,11 +692,19 @@ async def streaming_callback(chunk: StreamingChunk) -> None: ) llm.warm_up() - response = await llm.run_async(messages=messages) - - await streaming_complete.wait() - - assert len(streaming_chunks) > 0 - assert "replies" in response - assert isinstance(response["replies"][0], ChatMessage) - assert "climate change" in response["replies"][0].text.lower() + # Start queue processing in the background + queue_processor = asyncio.create_task(llm.pipeline.streamer.process_queue()) + try: + response = await llm.run_async(messages=messages) + await streaming_complete.wait() + + assert len(streaming_chunks) > 0 + assert "replies" in response + assert isinstance(response["replies"][0], ChatMessage) + assert "climate change" in response["replies"][0].text.lower() + finally: + queue_processor.cancel() + try: + await queue_processor + except asyncio.CancelledError: + pass From d1e42092498d9724f7514bca969037c854a38247 Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Sat, 31 May 2025 13:52:40 +0530 Subject: [PATCH 09/14] refactor: improve async handling in HuggingFaceLocalChatGenerator and update tests --- .../generators/chat/hugging_face_local.py | 45 ++++++++++++------- .../chat/test_hugging_face_local.py | 44 +++++++++++------- 2 files changed, 55 insertions(+), 34 deletions(-) diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index f03af2fb16..6728a99165 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -7,6 +7,7 @@ import re import sys from concurrent.futures import ThreadPoolExecutor +from contextlib import suppress from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast from haystack import component, default_from_dict, default_to_dict, logging @@ -585,28 +586,38 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments # get the component name and type component_info = ComponentInfo.from_component(self) assert asyncio.iscoroutinefunction(streaming_callback), "Streaming callback must be asynchronous" - generation_kwargs["streamer"] = AsyncHFTokenStreamingHandler( - tokenizer, streaming_callback, stop_words, component_info - ) + async_handler = AsyncHFTokenStreamingHandler(tokenizer, streaming_callback, stop_words, component_info) + generation_kwargs["streamer"] = async_handler - # Generate responses asynchronously - output = await asyncio.get_running_loop().run_in_executor( - self.executor, - lambda: self.pipeline(prepared_prompt, **generation_kwargs), # type: ignore # if self.executor was not passed it was initialized with max_workers=1 in init - ) + queue_processor = asyncio.create_task(async_handler.process_queue()) - replies = [o.get("generated_text", "") for o in output] + try: + # Generate responses asynchronously + output = await asyncio.get_running_loop().run_in_executor( + self.executor, + lambda: self.pipeline(prepared_prompt, **generation_kwargs), # type: ignore # if self.executor was not passed it was initialized with max_workers=1 in init + ) - # Remove stop words from replies if present - for stop_word in stop_words or []: - replies = [reply.replace(stop_word, "").rstrip() for reply in replies] + replies = [o.get("generated_text", "") for o in output] - chat_messages = [ - self.create_message(reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=False) - for r_index, reply in enumerate(replies) - ] + # Remove stop words from replies if present + for stop_word in stop_words or []: + replies = [reply.replace(stop_word, "").rstrip() for reply in replies] - return {"replies": chat_messages} + chat_messages = [ + self.create_message( + reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=False + ) + for r_index, reply in enumerate(replies) + ] + + return {"replies": chat_messages} + + finally: + # Clean up the queue processor + queue_processor.cancel() + with suppress(asyncio.CancelledError): + await queue_processor async def _run_non_streaming_async( # pylint: disable=too-many-positional-arguments self, diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py index 7d83c7acbb..dbf9edaabb 100644 --- a/test/components/generators/chat/test_hugging_face_local.py +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -7,6 +7,7 @@ from typing import Optional, List from unittest.mock import Mock, patch +from haystack.utils.hf import AsyncHFTokenStreamingHandler import pytest from transformers import PreTrainedTokenizer @@ -679,12 +680,16 @@ async def test_live_run_async_with_streaming(self, monkeypatch): streaming_complete = asyncio.Event() async def streaming_callback(chunk: StreamingChunk) -> None: + """Async callback to collect streaming chunks.""" streaming_chunks.append(chunk) - if chunk.content.endswith("\n"): + # Check if this looks like the end of generation + # Most models will send a final chunk or the accumulated text will be substantial + if len(streaming_chunks) > 10 or (chunk.content and len("".join(c.content for c in streaming_chunks)) > 40): streaming_complete.set() - messages = [ChatMessage.from_user("Please create a summary about the following topic: Climate change")] + messages = [ChatMessage.from_user("Please create a summary about the following topic: Capital of France")] + # Initialize the generator with streaming callback llm = HuggingFaceLocalChatGenerator( model="Qwen/Qwen2.5-0.5B-Instruct", generation_kwargs={"max_new_tokens": 50}, @@ -692,19 +697,24 @@ async def streaming_callback(chunk: StreamingChunk) -> None: ) llm.warm_up() - # Start queue processing in the background - queue_processor = asyncio.create_task(llm.pipeline.streamer.process_queue()) + response = await llm.run_async(messages=messages) + + # Wait for either streaming to complete or a timeout try: - response = await llm.run_async(messages=messages) - await streaming_complete.wait() - - assert len(streaming_chunks) > 0 - assert "replies" in response - assert isinstance(response["replies"][0], ChatMessage) - assert "climate change" in response["replies"][0].text.lower() - finally: - queue_processor.cancel() - try: - await queue_processor - except asyncio.CancelledError: - pass + await asyncio.wait_for(streaming_complete.wait(), timeout=30.0) + except asyncio.TimeoutError: + pass # We'll still check the results even if streaming takes longer + + assert len(streaming_chunks) > 0, "Should have received at least one streaming chunk" + assert "replies" in response + assert isinstance(response["replies"][0], ChatMessage) + assert "climate change" in response["replies"][0].text.lower() + + # Verify streaming chunks contain actual content + total_streamed_content = "".join(chunk.content for chunk in streaming_chunks) + assert len(total_streamed_content.strip()) > 0, "Streaming chunks should contain content" + + print(f"Received {len(streaming_chunks)} streaming chunks") + print(f"Total streamed content length: {len(total_streamed_content)}") + print(f"Final response length: {len(response['replies'][0].text)}") + assert "Paris" in response["replies"][0].text, "Response should mention Paris" From 33f11a0db5081c71fecff423ef9994197825984e Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Sat, 31 May 2025 14:00:06 +0530 Subject: [PATCH 10/14] fix typo --- test/components/generators/chat/test_hugging_face_local.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py index dbf9edaabb..53b721e527 100644 --- a/test/components/generators/chat/test_hugging_face_local.py +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -708,7 +708,7 @@ async def streaming_callback(chunk: StreamingChunk) -> None: assert len(streaming_chunks) > 0, "Should have received at least one streaming chunk" assert "replies" in response assert isinstance(response["replies"][0], ChatMessage) - assert "climate change" in response["replies"][0].text.lower() + assert "Paris" in response["replies"][0].text, "Response should mention Paris" # Verify streaming chunks contain actual content total_streamed_content = "".join(chunk.content for chunk in streaming_chunks) @@ -717,4 +717,3 @@ async def streaming_callback(chunk: StreamingChunk) -> None: print(f"Received {len(streaming_chunks)} streaming chunks") print(f"Total streamed content length: {len(total_streamed_content)}") print(f"Final response length: {len(response['replies'][0].text)}") - assert "Paris" in response["replies"][0].text, "Response should mention Paris" From cbd5682d06ae5e2889b3cd4bcb97f8a6a23cfb10 Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Tue, 10 Jun 2025 09:18:19 +0530 Subject: [PATCH 11/14] address review comments --- .../generators/chat/hugging_face_local.py | 34 ++++++--- .../chat/test_hugging_face_local.py | 76 +++++++------------ 2 files changed, 48 insertions(+), 62 deletions(-) diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index 6728a99165..104190c274 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -11,7 +11,8 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast from haystack import component, default_from_dict, default_to_dict, logging -from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, ToolCall, select_streaming_callback +from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, StreamingChunk, ToolCall +from haystack.dataclasses.streaming_chunk import select_streaming_callback from haystack.lazy_imports import LazyImport from haystack.tools import ( Tool, @@ -424,8 +425,10 @@ def run( replies = [o.get("generated_text", "") for o in output] # Remove stop words from replies if present - for stop_word in stop_words: - replies = [reply.replace(stop_word, "").rstrip() for reply in replies] + if stop_words: + for stop_word in stop_words: + if stop_word in replies[0]: + replies = [reply.replace(stop_word, "").rstrip() for reply in replies] chat_messages = [ self.create_message( @@ -585,10 +588,11 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments # get the component name and type component_info = ComponentInfo.from_component(self) - assert asyncio.iscoroutinefunction(streaming_callback), "Streaming callback must be asynchronous" + async_handler = AsyncHFTokenStreamingHandler(tokenizer, streaming_callback, stop_words, component_info) generation_kwargs["streamer"] = async_handler + # Start queue processing in the background queue_processor = asyncio.create_task(async_handler.process_queue()) try: @@ -601,8 +605,10 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments replies = [o.get("generated_text", "") for o in output] # Remove stop words from replies if present - for stop_word in stop_words or []: - replies = [reply.replace(stop_word, "").rstrip() for reply in replies] + if stop_words: + for stop_word in stop_words: + if stop_word in replies[0]: + replies = [reply.replace(stop_word, "").rstrip() for reply in replies] chat_messages = [ self.create_message( @@ -614,10 +620,12 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments return {"replies": chat_messages} finally: - # Clean up the queue processor - queue_processor.cancel() - with suppress(asyncio.CancelledError): - await queue_processor + try: + await asyncio.wait_for(queue_processor, timeout=0.1) + except asyncio.TimeoutError: + queue_processor.cancel() + with suppress(asyncio.CancelledError): + await queue_processor async def _run_non_streaming_async( # pylint: disable=too-many-positional-arguments self, @@ -661,8 +669,10 @@ async def _run_non_streaming_async( # pylint: disable=too-many-positional-argum replies = [o.get("generated_text", "") for o in output] # Remove stop words from replies if present - for stop_word in stop_words or []: - replies = [reply.replace(stop_word, "").rstrip() for reply in replies] + if stop_words: + for stop_word in stop_words: + if stop_word in replies[0]: + replies = [reply.replace(stop_word, "").rstrip() for reply in replies] chat_messages = [ self.create_message( diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py index 53b721e527..cb56f640d3 100644 --- a/test/components/generators/chat/test_hugging_face_local.py +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -620,53 +620,41 @@ def test_to_dict_with_toolset(self, model_info_mock, mock_pipeline_with_tokenize assert data["init_parameters"]["tools"] == expected_tools_data @pytest.mark.asyncio - async def test_run_async_with_streaming_callback(self, model_info_mock): - """Test that async streaming works correctly with HuggingFaceLocalChatGenerator.""" + async def test_run_async_with_streaming_callback(self, model_info_mock, mock_pipeline_with_tokenizer): streaming_chunks = [] - streaming_complete = asyncio.Event() - loop = asyncio.get_running_loop() async def streaming_callback(chunk: StreamingChunk) -> None: streaming_chunks.append(chunk) - if chunk.content.endswith("\n"): - streaming_complete.set() - - # Create a mock pipeline that simulates streaming - mock_pipeline = Mock() - mock_tokenizer = Mock() - mock_tokenizer.apply_chat_template.return_value = "Test prompt" - mock_tokenizer.encode.return_value = [1, 2, 3] # Return a list with a length - mock_pipeline.tokenizer = mock_tokenizer - # Mock the pipeline to return a stream of responses - def mock_generate(*args, **kwargs): + # Create a mock that simulates streaming behavior + def mock_pipeline_call(*args, **kwargs): streamer = kwargs.get("streamer") if streamer: - # Schedule the streaming callbacks in the main event loop - loop.call_soon_threadsafe(lambda: streamer.on_finalized_text("Hello", stream_end=False)) - loop.call_soon_threadsafe(lambda: streamer.on_finalized_text(" world", stream_end=True)) - return [{"generated_text": "Hello world"}] + # Simulate streaming chunks + streamer.on_finalized_text("Berlin", stream_end=False) + streamer.on_finalized_text(" is cool", stream_end=True) + return [{"generated_text": "Berlin is cool"}] - mock_pipeline.side_effect = mock_generate + # Setup the mock pipeline with streaming simulation + mock_pipeline_with_tokenizer.side_effect = mock_pipeline_call generator = HuggingFaceLocalChatGenerator(model="test-model", streaming_callback=streaming_callback) - generator.pipeline = mock_pipeline + generator.pipeline = mock_pipeline_with_tokenizer messages = [ChatMessage.from_user("Test message")] response = await generator.run_async(messages) - # Wait for streaming to complete - await streaming_complete.wait() - + # Verify streaming chunks were collected assert len(streaming_chunks) == 2 - assert streaming_chunks[0].content == "Hello" - assert streaming_chunks[1].content == " world\n" + assert streaming_chunks[0].content == "Berlin" + assert streaming_chunks[1].content == " is cool\n" + # Verify the final response assert isinstance(response, dict) assert "replies" in response assert len(response["replies"]) == 1 assert isinstance(response["replies"][0], ChatMessage) - assert response["replies"][0].text == "Hello world" + assert response["replies"][0].text == "Berlin is cool" @pytest.mark.integration @pytest.mark.slow @@ -677,19 +665,10 @@ async def test_live_run_async_with_streaming(self, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) streaming_chunks = [] - streaming_complete = asyncio.Event() async def streaming_callback(chunk: StreamingChunk) -> None: - """Async callback to collect streaming chunks.""" streaming_chunks.append(chunk) - # Check if this looks like the end of generation - # Most models will send a final chunk or the accumulated text will be substantial - if len(streaming_chunks) > 10 or (chunk.content and len("".join(c.content for c in streaming_chunks)) > 40): - streaming_complete.set() - - messages = [ChatMessage.from_user("Please create a summary about the following topic: Capital of France")] - # Initialize the generator with streaming callback llm = HuggingFaceLocalChatGenerator( model="Qwen/Qwen2.5-0.5B-Instruct", generation_kwargs={"max_new_tokens": 50}, @@ -697,23 +676,20 @@ async def streaming_callback(chunk: StreamingChunk) -> None: ) llm.warm_up() - response = await llm.run_async(messages=messages) - - # Wait for either streaming to complete or a timeout - try: - await asyncio.wait_for(streaming_complete.wait(), timeout=30.0) - except asyncio.TimeoutError: - pass # We'll still check the results even if streaming takes longer + response = await llm.run_async( + messages=[ChatMessage.from_user("Please create a summary about the following topic: Capital of France")] + ) - assert len(streaming_chunks) > 0, "Should have received at least one streaming chunk" + # Verify that the response is not None + assert len(streaming_chunks) > 0 assert "replies" in response assert isinstance(response["replies"][0], ChatMessage) - assert "Paris" in response["replies"][0].text, "Response should mention Paris" + assert response["replies"][0].text is not None + + # Verify that the response contains the word "Paris" + assert "Paris" in response["replies"][0].text # Verify streaming chunks contain actual content total_streamed_content = "".join(chunk.content for chunk in streaming_chunks) - assert len(total_streamed_content.strip()) > 0, "Streaming chunks should contain content" - - print(f"Received {len(streaming_chunks)} streaming chunks") - print(f"Total streamed content length: {len(total_streamed_content)}") - print(f"Final response length: {len(response['replies'][0].text)}") + assert len(total_streamed_content.strip()) > 0 + assert "Paris" in total_streamed_content From 98f74f840a99c143ea198f3ffe0c7d3cd80daa64 Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Tue, 10 Jun 2025 19:47:11 +0530 Subject: [PATCH 12/14] refactors --- .../generators/chat/hugging_face_local.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index 92193cb9e2..9409c946af 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -426,9 +426,8 @@ def run( # Remove stop words from replies if present if stop_words: - for stop_word in stop_words or []: - if stop_word in replies[0]: - replies = [reply.replace(stop_word, "").rstrip() for reply in replies] + for stop_word in stop_words: + replies = [reply.replace(stop_word, "").rstrip() for reply in replies] chat_messages = [ self.create_message( @@ -589,7 +588,7 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments # get the component name and type component_info = ComponentInfo.from_component(self) - async_handler = AsyncHFTokenStreamingHandler(tokenizer, streaming_callback, stop_words, component_info) + async_handler = AsyncHFTokenStreamingHandler(tokenizer, streaming_callback, stop_words, component_info) # type: ignore[call-arg] generation_kwargs["streamer"] = async_handler # Start queue processing in the background @@ -607,8 +606,7 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments # Remove stop words from replies if present if stop_words: for stop_word in stop_words: - if stop_word in replies[0]: - replies = [reply.replace(stop_word, "").rstrip() for reply in replies] + replies = [reply.replace(stop_word, "").rstrip() for reply in replies] chat_messages = [ self.create_message( @@ -671,8 +669,7 @@ async def _run_non_streaming_async( # pylint: disable=too-many-positional-argum # Remove stop words from replies if present if stop_words: for stop_word in stop_words: - if stop_word in replies[0]: - replies = [reply.replace(stop_word, "").rstrip() for reply in replies] + replies = [reply.replace(stop_word, "").rstrip() for reply in replies] chat_messages = [ self.create_message( From e581cad23b53d161fa58d81cfa85a9ebbb1102f1 Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Tue, 10 Jun 2025 19:58:13 +0530 Subject: [PATCH 13/14] typo --- haystack/components/generators/chat/hugging_face_local.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index 9409c946af..27b69e1890 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -588,7 +588,7 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments # get the component name and type component_info = ComponentInfo.from_component(self) - async_handler = AsyncHFTokenStreamingHandler(tokenizer, streaming_callback, stop_words, component_info) # type: ignore[call-arg] + async_handler = AsyncHFTokenStreamingHandler(tokenizer, streaming_callback, stop_words, component_info) # type: ignore generation_kwargs["streamer"] = async_handler # Start queue processing in the background From 64588451a3593cba02cc8503605f878d8847265a Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Tue, 10 Jun 2025 21:00:00 +0530 Subject: [PATCH 14/14] refactor --- haystack/components/generators/chat/hugging_face_local.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index 27b69e1890..578037dc76 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -11,7 +11,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast from haystack import component, default_from_dict, default_to_dict, logging -from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, StreamingChunk, ToolCall +from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, ToolCall from haystack.dataclasses.streaming_chunk import select_streaming_callback from haystack.lazy_imports import LazyImport from haystack.tools import (