feat: Add async streaming support in HuggingFaceLocalChatGenerator#9405
mpangrazzi merged 17 commits into deepset-ai:main

Conversation
Pull Request Test Coverage Report for Build 15588007009

Warning: This coverage report may be inaccurate. This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

💛 - Coveralls
```diff
 # Set up streaming handler
-generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words)
+generation_kwargs["streamer"] = AsyncHFTokenStreamingHandler(tokenizer, streaming_callback, stop_words)
```
To make mypy happy, could we add an assert here asserting that streaming_callback is of type AsyncStreamingCallbackT? Or update AsyncHFTokenStreamingHandler so that the type hint for stream_handler is StreamingCallbackT.
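The assert-based option can be sketched as follows. The type aliases here are simplified stand-ins for Haystack's StreamingCallbackT and AsyncStreamingCallbackT, not the library's actual definitions; recent typeshed stubs annotate asyncio.iscoroutinefunction with a TypeGuard, which is what makes the assert useful to mypy:

```python
import asyncio
from typing import Awaitable, Callable, Union

# Simplified stand-ins for StreamingCallbackT / AsyncStreamingCallbackT
SyncStreamingCallbackT = Callable[[str], None]
AsyncStreamingCallbackT = Callable[[str], Awaitable[None]]
StreamingCallbackT = Union[SyncStreamingCallbackT, AsyncStreamingCallbackT]


def require_async_callback(callback: StreamingCallbackT) -> AsyncStreamingCallbackT:
    # Runtime check that doubles as a narrowing hint for the type checker:
    # after the assert, the callback is known to be a coroutine function.
    assert asyncio.iscoroutinefunction(callback), "Streaming callback must be asynchronous"
    return callback
```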
mpangrazzi
left a comment
I've left a comment below!
```python
assert data["init_parameters"]["tools"] == expected_tools_data

@pytest.mark.asyncio
async def test_run_async_with_streaming_callback(self, model_info_mock):
```
Could we also add an integration test for this? So an async version of test_live_run with streaming?
I would also simplify this, removing the asyncio.Event usage and reusing an already available mock. Something like:

```python
@pytest.mark.asyncio
async def test_run_async_with_streaming_callback(self, model_info_mock, mock_pipeline_with_tokenizer):
    streaming_chunks = []

    async def streaming_callback(chunk: StreamingChunk) -> None:
        streaming_chunks.append(chunk)

    # Create a mock that simulates streaming behavior
    def mock_pipeline_call(*args, **kwargs):
        streamer = kwargs.get("streamer")
        if streamer:
            # Simulate streaming chunks
            streamer.on_finalized_text("Berlin", stream_end=False)
            streamer.on_finalized_text(" is cool", stream_end=True)
        return [{"generated_text": "Berlin is cool"}]

    # Set up the mock pipeline with streaming simulation
    mock_pipeline_with_tokenizer.side_effect = mock_pipeline_call

    generator = HuggingFaceLocalChatGenerator(model="test-model", streaming_callback=streaming_callback)
    generator.pipeline = mock_pipeline_with_tokenizer

    messages = [ChatMessage.from_user("Test message")]
    response = await generator.run_async(messages)

    # Verify streaming chunks were collected
    assert len(streaming_chunks) == 2
    assert streaming_chunks[0].content == "Berlin"
    assert streaming_chunks[1].content == " is cool\n"

    # Verify the final response
    assert isinstance(response, dict)
    assert "replies" in response
    assert len(response["replies"]) == 1
    assert isinstance(response["replies"][0], ChatMessage)
    assert response["replies"][0].text == "Berlin is cool"
```

WDYT?
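The mock above leans on the transformers streamer protocol, where generation pushes decoded text into the streamer through on_finalized_text(text, stream_end=...). A minimal self-contained illustration of that contract (FakeStreamer and fake_generate are hypothetical stand-ins, not part of transformers):

```python
class FakeStreamer:
    """Hypothetical streamer that records on_finalized_text calls."""

    def __init__(self) -> None:
        self.chunks = []

    def on_finalized_text(self, text: str, stream_end: bool = False) -> None:
        # transformers' TextStreamer subclasses override this hook;
        # stream_end=True marks the final chunk of the generation
        self.chunks.append(text)


def fake_generate(streamer: FakeStreamer) -> str:
    # Stand-in for the pipeline pushing decoded tokens to the streamer
    streamer.on_finalized_text("Berlin", stream_end=False)
    streamer.on_finalized_text(" is cool", stream_end=True)
    return "Berlin is cool"


streamer = FakeStreamer()
result = fake_generate(streamer)
print(streamer.chunks)  # → ['Berlin', ' is cool']
```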
- fix breaking tests
- added component_info to AsyncHFTokenStreamingHandler
@sjrl: added a live integration test
mpangrazzi
left a comment
I've added some comments for possible improvements. Let me know if they are clear enough!
```python
    for r_index, reply in enumerate(replies)
]

# Remove stop words from replies if present
for stop_word in stop_words or []:
```
What about adding a more explicit check here? (Can apply also on line 427.)

```python
if stop_words:
    for stop_word in stop_words:
        replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
```

```python
generation_kwargs["streamer"] = HFTokenStreamingHandler(
    tokenizer, streaming_callback, stop_words, component_info
)
```

```python
assert asyncio.iscoroutinefunction(streaming_callback), "Streaming callback must be asynchronous"
```
Can we use the select_streaming_callback utility here? (We used it in other generators.) You can get it from:

```python
from haystack.dataclasses.streaming_chunk import select_streaming_callback
```

so we can avoid the assert usage!
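For context, a helper along these lines picks the runtime callback over the init-time one and validates that it matches the sync/async mode of the calling method. This is a simplified sketch, not Haystack's actual implementation; the real utility's signature and error messages may differ:

```python
import asyncio
from typing import Callable, Optional


def select_streaming_callback(
    init_callback: Optional[Callable],
    runtime_callback: Optional[Callable],
    requires_async: bool,
) -> Optional[Callable]:
    # Prefer the callback passed at run time over the one set at init time
    callback = runtime_callback or init_callback
    if callback is not None:
        is_async = asyncio.iscoroutinefunction(callback)
        # Reject mismatches instead of relying on a bare assert
        if requires_async and not is_async:
            raise ValueError("run_async requires an async streaming callback")
        if not requires_async and is_async:
            raise ValueError("run requires a sync streaming callback")
    return callback
```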
```python
# Clean up the queue processor
queue_processor.cancel()
with suppress(asyncio.CancelledError):
    await queue_processor
```
This cleanup logic can be a bit more robust: we can add a short timeout so we can ensure the queue is drained:

```python
finally:
    try:
        await asyncio.wait_for(queue_processor, timeout=0.1)
    except asyncio.TimeoutError:
        queue_processor.cancel()
        with suppress(asyncio.CancelledError):
            await queue_processor
```

WDYT?
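The drain-then-cancel pattern can be demonstrated end to end with a toy queue-based streamer, independent of the Haystack classes (names here are illustrative only):

```python
import asyncio
from contextlib import suppress


async def demo() -> list:
    # Toy version of the handler's design: tokens flow through a queue,
    # a processor task drains them, and cleanup waits briefly before cancelling.
    queue: asyncio.Queue = asyncio.Queue()
    received = []

    async def process() -> None:
        while True:
            token = await queue.get()
            if token is None:  # sentinel: stream finished
                return
            received.append(token)

    processor = asyncio.create_task(process())
    for token in ["Berlin", " is cool"]:
        queue.put_nowait(token)
    queue.put_nowait(None)

    try:
        # Give the processor a short window to drain the queue...
        await asyncio.wait_for(processor, timeout=0.1)
    except asyncio.TimeoutError:
        # ...and cancel it only if it is still running
        processor.cancel()
        with suppress(asyncio.CancelledError):
            await processor
    return received


print(asyncio.run(demo()))  # → ['Berlin', ' is cool']
```

With the timeout path, chunks already in the queue are delivered before cancellation, which is what makes the cleanup "robust" compared with cancelling immediately.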
```python
@pytest.mark.slow
@pytest.mark.flaky(reruns=3, reruns_delay=10)
@pytest.mark.asyncio
async def test_live_run_async_with_streaming(self, monkeypatch):
```
I think that this test is a bit over-engineered. What about something simpler like the following? No need to use e.g. asyncio.Event to check when the streaming is done.

```python
@pytest.mark.integration
@pytest.mark.slow
@pytest.mark.flaky(reruns=3, reruns_delay=10)
@pytest.mark.asyncio
async def test_live_run_async_with_streaming(self, monkeypatch):
    monkeypatch.delenv("HF_API_TOKEN", raising=False)
    streaming_chunks = []

    async def streaming_callback(chunk: StreamingChunk) -> None:
        streaming_chunks.append(chunk)

    llm = HuggingFaceLocalChatGenerator(
        model="Qwen/Qwen2.5-0.5B-Instruct",
        generation_kwargs={"max_new_tokens": 50},
        streaming_callback=streaming_callback,
    )
    llm.warm_up()

    response = await llm.run_async(
        messages=[ChatMessage.from_user("Please create a summary about the following topic: Capital of France")]
    )

    # Verify that the response is not None
    assert len(streaming_chunks) > 0
    assert "replies" in response
    assert isinstance(response["replies"][0], ChatMessage)
    assert response["replies"][0].text is not None

    # Verify that the response contains the word "Paris"
    assert "Paris" in response["replies"][0].text

    # Verify streaming chunks contain actual content
    total_streamed_content = "".join(chunk.content for chunk in streaming_chunks)
    assert len(total_streamed_content.strip()) > 0
    assert "Paris" in total_streamed_content
```

WDYT?
mpangrazzi
left a comment
LGTM - we can address other minor nits later!
Related Issues
- HuggingFaceLocalChatGenerator: add an async version of HFTokenStreamHandler and update type signature for async streaming callback #9391

Proposed Changes:
How did you test it?
Notes for the reviewer
Checklist
fix:, feat:, build:, chore:, ci:, docs:, style:, refactor:, perf:, test:, and added ! in case the PR includes breaking changes.