-
Notifications
You must be signed in to change notification settings - Fork 2.8k
feat: Add async streaming support in HuggingFaceLocalChatGenerator
#9405
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
c45df5c
77f1478
5cdb483
9faa706
9e2fbcd
60c907d
c966c81
84a911a
1648694
d1e4209
33f11a0
cbd5682
f1c2f7c
98f74f8
e581cad
6458845
c42038b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,7 +10,8 @@ | |
| from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast | ||
|
|
||
| from haystack import component, default_from_dict, default_to_dict, logging | ||
| from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback | ||
| from haystack.dataclasses import ChatMessage, ToolCall, select_streaming_callback | ||
| from haystack.dataclasses.streaming_chunk import StreamingCallbackT | ||
| from haystack.lazy_imports import LazyImport | ||
| from haystack.tools import ( | ||
| Tool, | ||
|
|
@@ -36,6 +37,7 @@ | |
| from transformers.tokenization_utils_fast import PreTrainedTokenizerFast | ||
|
|
||
| from haystack.utils.hf import ( # pylint: disable=ungrouped-imports | ||
| AsyncHFTokenStreamingHandler, | ||
| HFTokenStreamingHandler, | ||
| StopWordsCriteria, | ||
| convert_message_to_hf_format, | ||
|
|
@@ -130,7 +132,7 @@ def __init__( # pylint: disable=too-many-positional-arguments | |
| generation_kwargs: Optional[Dict[str, Any]] = None, | ||
| huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None, | ||
| stop_words: Optional[List[str]] = None, | ||
| streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, | ||
| streaming_callback: Optional[StreamingCallbackT] = None, | ||
| tools: Optional[Union[List[Tool], Toolset]] = None, | ||
| tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None, | ||
| async_executor: Optional[ThreadPoolExecutor] = None, | ||
|
|
@@ -330,7 +332,7 @@ def run( | |
| self, | ||
| messages: List[ChatMessage], | ||
| generation_kwargs: Optional[Dict[str, Any]] = None, | ||
| streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, | ||
| streaming_callback: Optional[StreamingCallbackT] = None, | ||
| tools: Optional[Union[List[Tool], Toolset]] = None, | ||
| ): | ||
| """ | ||
|
|
@@ -492,7 +494,7 @@ async def run_async( | |
| self, | ||
| messages: List[ChatMessage], | ||
| generation_kwargs: Optional[Dict[str, Any]] = None, | ||
| streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, | ||
| streaming_callback: Optional[StreamingCallbackT] = None, | ||
| tools: Optional[Union[List[Tool], Toolset]] = None, | ||
| ): | ||
| """ | ||
|
|
@@ -546,7 +548,7 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments | |
| tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"], | ||
| generation_kwargs: Dict[str, Any], | ||
| stop_words: Optional[List[str]], | ||
| streaming_callback: Callable[[StreamingChunk], None], | ||
| streaming_callback: StreamingCallbackT, | ||
| ): | ||
| """ | ||
| Handles async streaming generation of responses. | ||
|
|
@@ -566,7 +568,8 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments | |
| ) | ||
|
|
||
| # Set up streaming handler | ||
| generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words) | ||
| assert asyncio.iscoroutinefunction(streaming_callback), "Streaming callback must be asynchronous" | ||
| generation_kwargs["streamer"] = AsyncHFTokenStreamingHandler(tokenizer, streaming_callback, stop_words) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. To make mypy happy, could we add an assert here asserting that `streaming_callback` is asynchronous, or update […] (inline code lost in extraction) |
||
|
|
||
| # Generate responses asynchronously | ||
| output = await asyncio.get_running_loop().run_in_executor( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| --- | ||
| features: | ||
| - | | ||
| Add `AsyncHFTokenStreamingHandler` for async streaming support in `HuggingFaceLocalChatGenerator` |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,7 +12,7 @@ | |
|
|
||
| from haystack.components.generators.chat import HuggingFaceLocalChatGenerator | ||
| from haystack.dataclasses import ChatMessage, ChatRole, ToolCall | ||
| from haystack.dataclasses.streaming_chunk import StreamingChunk | ||
| from haystack.dataclasses.streaming_chunk import StreamingChunk, AsyncStreamingCallbackT | ||
| from haystack.tools import Tool | ||
| from haystack.utils import ComponentDevice | ||
| from haystack.utils.auth import Secret | ||
|
|
@@ -485,6 +485,11 @@ def test_default_tool_parser(self, model_info_mock, tools): | |
|
|
||
| # Async tests | ||
|
|
||
|
|
||
| class TestHuggingFaceLocalChatGeneratorAsync: | ||
| """Async tests for HuggingFaceLocalChatGenerator""" | ||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_run_async(self, model_info_mock, mock_pipeline_tokenizer, chat_messages): | ||
| """Test basic async functionality""" | ||
| generator = HuggingFaceLocalChatGenerator(model="mocked-model") | ||
|
|
@@ -498,6 +503,7 @@ async def test_run_async(self, model_info_mock, mock_pipeline_tokenizer, chat_me | |
| assert chat_message.is_from(ChatRole.ASSISTANT) | ||
| assert chat_message.text == "Berlin is cool" | ||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_run_async_with_tools(self, model_info_mock, mock_pipeline_tokenizer, tools): | ||
| """Test async functionality with tools""" | ||
| generator = HuggingFaceLocalChatGenerator(model="mocked-model", tools=tools) | ||
|
|
@@ -516,6 +522,7 @@ async def test_run_async_with_tools(self, model_info_mock, mock_pipeline_tokeniz | |
| assert tool_call.tool_name == "weather" | ||
| assert tool_call.arguments == {"city": "Berlin"} | ||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_concurrent_async_requests(self, model_info_mock, mock_pipeline_tokenizer, chat_messages): | ||
| """Test handling of multiple concurrent async requests""" | ||
| generator = HuggingFaceLocalChatGenerator(model="mocked-model") | ||
|
|
@@ -530,6 +537,7 @@ async def test_concurrent_async_requests(self, model_info_mock, mock_pipeline_to | |
| assert isinstance(result["replies"][0], ChatMessage) | ||
| assert result["replies"][0].text == "Berlin is cool" | ||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_async_error_handling(self, model_info_mock, mock_pipeline_tokenizer): | ||
| """Test error handling in async context""" | ||
| generator = HuggingFaceLocalChatGenerator(model="mocked-model") | ||
|
|
@@ -608,3 +616,52 @@ def test_to_dict_with_toolset(self, model_info_mock, mock_pipeline_tokenizer, to | |
| }, | ||
| } | ||
| assert data["init_parameters"]["tools"] == expected_tools_data | ||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_run_async_with_streaming_callback(self, model_info_mock): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could we also add an integration test for this? So an async version of
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I would also simplify this, removing […] (inline code lost in extraction):

```python
@pytest.mark.asyncio
async def test_run_async_with_streaming_callback(self, model_info_mock, mock_pipeline_with_tokenizer):
    streaming_chunks = []

    async def streaming_callback(chunk: StreamingChunk) -> None:
        streaming_chunks.append(chunk)

    # Create a mock that simulates streaming behavior
    def mock_pipeline_call(*args, **kwargs):
        streamer = kwargs.get("streamer")
        if streamer:
            # Simulate streaming chunks
            streamer.on_finalized_text("Berlin", stream_end=False)
            streamer.on_finalized_text(" is cool", stream_end=True)
        return [{"generated_text": "Berlin is cool"}]

    # Setup the mock pipeline with streaming simulation
    mock_pipeline_with_tokenizer.side_effect = mock_pipeline_call
    generator = HuggingFaceLocalChatGenerator(model="test-model", streaming_callback=streaming_callback)
    generator.pipeline = mock_pipeline_with_tokenizer
    messages = [ChatMessage.from_user("Test message")]
    response = await generator.run_async(messages)

    # Verify streaming chunks were collected
    assert len(streaming_chunks) == 2
    assert streaming_chunks[0].content == "Berlin"
    assert streaming_chunks[1].content == " is cool\n"

    # Verify the final response
    assert isinstance(response, dict)
    assert "replies" in response
    assert len(response["replies"]) == 1
    assert isinstance(response["replies"][0], ChatMessage)
    assert response["replies"][0].text == "Berlin is cool"
```

WDYT? |
||
| """Test that async streaming works correctly with HuggingFaceLocalChatGenerator.""" | ||
| streaming_chunks = [] | ||
| streaming_complete = asyncio.Event() | ||
| loop = asyncio.get_running_loop() | ||
|
|
||
| async def streaming_callback(chunk: StreamingChunk) -> None: | ||
| streaming_chunks.append(chunk) | ||
| if chunk.content.endswith("\n"): | ||
| streaming_complete.set() | ||
|
|
||
| # Create a mock pipeline that simulates streaming | ||
| mock_pipeline = Mock() | ||
| mock_tokenizer = Mock() | ||
| mock_tokenizer.apply_chat_template.return_value = "Test prompt" | ||
| mock_tokenizer.encode.return_value = [1, 2, 3] # Return a list with a length | ||
| mock_pipeline.tokenizer = mock_tokenizer | ||
|
|
||
| # Mock the pipeline to return a stream of responses | ||
| def mock_generate(*args, **kwargs): | ||
| streamer = kwargs.get("streamer") | ||
| if streamer: | ||
| # Schedule the streaming callbacks in the main event loop | ||
| loop.call_soon_threadsafe(lambda: streamer.on_finalized_text("Hello", stream_end=False)) | ||
| loop.call_soon_threadsafe(lambda: streamer.on_finalized_text(" world", stream_end=True)) | ||
| return [{"generated_text": "Hello world"}] | ||
|
|
||
| mock_pipeline.side_effect = mock_generate | ||
|
|
||
| generator = HuggingFaceLocalChatGenerator(model="test-model", streaming_callback=streaming_callback) | ||
| generator.pipeline = mock_pipeline | ||
|
|
||
| messages = [ChatMessage.from_user("Test message")] | ||
| response = await generator.run_async(messages) | ||
|
|
||
| # Wait for streaming to complete | ||
| await streaming_complete.wait() | ||
|
|
||
| assert len(streaming_chunks) == 2 | ||
| assert streaming_chunks[0].content == "Hello" | ||
| assert streaming_chunks[1].content == " world\n" | ||
|
|
||
| assert isinstance(response, dict) | ||
| assert "replies" in response | ||
| assert len(response["replies"]) == 1 | ||
| assert isinstance(response["replies"][0], ChatMessage) | ||
| assert response["replies"][0].text == "Hello world" | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we use the `select_streaming_callback` utility here? (We used it in other generators.)
You can get it from `haystack.dataclasses` (`from haystack.dataclasses import select_streaming_callback`),
so we can avoid `assert` usage!