Commit a28b285

feat: Add async streaming support in HuggingFaceLocalChatGenerator (#9405)
* feat: Add async streaming support in hugging face generator
* enforce streamingcallback to be async
* refactor
* fix: schedule and await async task in Event Loop
* unenforce typecheck
* add integration test
* After merge fixes:
  - fix breaking tests
  - added component_info to AsyncHFTokenStreamingHandler
* fix integration test
* refactor: improve async handling in HuggingFaceLocalChatGenerator and update tests
* fix typo
* address review comments
* refactors
* typo
* refactor
1 parent f8155e1 commit a28b285

4 files changed

Lines changed: 172 additions & 24 deletions
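
With this change, `HuggingFaceLocalChatGenerator.run_async` can stream tokens through an async callback while generation runs in a thread pool. A minimal usage sketch, modeled on the integration test added in this commit (the model name and generation kwargs come from that test; the prompt is illustrative):

```python
import asyncio

from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk


async def main() -> None:
    async def streaming_callback(chunk: StreamingChunk) -> None:
        # Invoked on the event loop for every generated chunk.
        print(chunk.content, end="", flush=True)

    llm = HuggingFaceLocalChatGenerator(
        model="Qwen/Qwen2.5-0.5B-Instruct",
        generation_kwargs={"max_new_tokens": 50},
        streaming_callback=streaming_callback,
    )
    llm.warm_up()

    response = await llm.run_async(messages=[ChatMessage.from_user("What is the capital of France?")])
    print("\n", response["replies"][0].text)


asyncio.run(main())
```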

haystack/components/generators/chat/hugging_face_local.py

Lines changed: 42 additions & 22 deletions
```diff
@@ -7,10 +7,12 @@
 import re
 import sys
 from concurrent.futures import ThreadPoolExecutor
+from contextlib import suppress
 from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast
 
 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, ToolCall, select_streaming_callback
+from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, ToolCall
+from haystack.dataclasses.streaming_chunk import select_streaming_callback
 from haystack.lazy_imports import LazyImport
 from haystack.tools import (
     Tool,
@@ -37,6 +39,7 @@
 from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 
 from haystack.utils.hf import (  # pylint: disable=ungrouped-imports
+    AsyncHFTokenStreamingHandler,
     HFTokenStreamingHandler,
     StopWordsCriteria,
     convert_message_to_hf_format,
@@ -422,8 +425,9 @@ def run(
         replies = [o.get("generated_text", "") for o in output]
 
         # Remove stop words from replies if present
-        for stop_word in stop_words or []:
-            replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
+        if stop_words:
+            for stop_word in stop_words:
+                replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
 
         chat_messages = [
             self.create_message(
@@ -583,28 +587,43 @@ async def _run_streaming_async(  # pylint: disable=too-many-positional-arguments
 
         # get the component name and type
         component_info = ComponentInfo.from_component(self)
-        generation_kwargs["streamer"] = HFTokenStreamingHandler(
-            tokenizer, streaming_callback, stop_words, component_info
-        )
 
-        # Generate responses asynchronously
-        output = await asyncio.get_running_loop().run_in_executor(
-            self.executor,
-            lambda: self.pipeline(prepared_prompt, **generation_kwargs),  # type: ignore # if self.executor was not passed it was initialized with max_workers=1 in init
-        )
+        async_handler = AsyncHFTokenStreamingHandler(tokenizer, streaming_callback, stop_words, component_info)  # type: ignore
+        generation_kwargs["streamer"] = async_handler
 
-        replies = [o.get("generated_text", "") for o in output]
+        # Start queue processing in the background
+        queue_processor = asyncio.create_task(async_handler.process_queue())
 
-        # Remove stop words from replies if present
-        for stop_word in stop_words or []:
-            replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
+        try:
+            # Generate responses asynchronously
+            output = await asyncio.get_running_loop().run_in_executor(
+                self.executor,
+                lambda: self.pipeline(prepared_prompt, **generation_kwargs),  # type: ignore # if self.executor was not passed it was initialized with max_workers=1 in init
+            )
 
-        chat_messages = [
-            self.create_message(reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=False)
-            for r_index, reply in enumerate(replies)
-        ]
+            replies = [o.get("generated_text", "") for o in output]
 
-        return {"replies": chat_messages}
+            # Remove stop words from replies if present
+            if stop_words:
+                for stop_word in stop_words:
+                    replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
+
+            chat_messages = [
+                self.create_message(
+                    reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=False
+                )
+                for r_index, reply in enumerate(replies)
+            ]
+
+            return {"replies": chat_messages}
+
+        finally:
+            try:
+                await asyncio.wait_for(queue_processor, timeout=0.1)
+            except asyncio.TimeoutError:
+                queue_processor.cancel()
+                with suppress(asyncio.CancelledError):
+                    await queue_processor
 
     async def _run_non_streaming_async(  # pylint: disable=too-many-positional-arguments
         self,
@@ -648,8 +667,9 @@ async def _run_non_streaming_async(  # pylint: disable=too-many-positional-arguments
         replies = [o.get("generated_text", "") for o in output]
 
         # Remove stop words from replies if present
-        for stop_word in stop_words or []:
-            replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
+        if stop_words:
+            for stop_word in stop_words:
+                replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
 
         chat_messages = [
             self.create_message(
```

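The core of the change in `_run_streaming_async` is a producer/consumer handoff: the blocking `transformers` pipeline runs in an executor thread and enqueues chunks, while an asyncio task drains the queue and awaits the user's async callback. A self-contained sketch of that pattern (all names here are illustrative, not from the commit):

```python
import asyncio


async def stream_with_background_consumer() -> None:
    queue: asyncio.Queue[str] = asyncio.Queue()

    def blocking_producer() -> None:
        # Stands in for self.pipeline(prepared_prompt, **generation_kwargs):
        # the streamer enqueues chunks from the worker thread as they arrive.
        for token in ["Berlin", " is", " cool"]:
            queue.put_nowait(token)

    async def consumer() -> None:
        # Stands in for AsyncHFTokenStreamingHandler.process_queue(): it hands
        # each chunk to the user's async streaming callback.
        while True:
            token = await queue.get()
            print(token, end="", flush=True)
            queue.task_done()

    consumer_task = asyncio.create_task(consumer())
    try:
        # Run the blocking generation in an executor so the event loop stays
        # free to run the consumer, as the commit does.
        await asyncio.get_running_loop().run_in_executor(None, blocking_producer)
    finally:
        # Shutdown mirrors the commit: grant a short grace period to flush
        # remaining chunks, then cancel the consumer.
        try:
            await asyncio.wait_for(consumer_task, timeout=0.1)
        except asyncio.TimeoutError:
            pass  # wait_for already cancelled the task on timeout


asyncio.run(stream_with_background_consumer())
```

The `finally` block in the commit serves the same purpose: give the queue processor a short window to flush, then cancel it and suppress the resulting `CancelledError`.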
haystack/utils/hf.py

Lines changed: 41 additions & 1 deletion
```diff
@@ -2,9 +2,10 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+import asyncio
 import copy
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Awaitable, Callable, Dict, List, Optional, Union
 
 from haystack import logging
 from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, StreamingChunk
@@ -377,3 +378,42 @@ def on_finalized_text(self, word: str, stream_end: bool = False) -> None:
                 content=word_to_send, index=0, start=self._call_counter == 1, component_info=self.component_info
             )
         )
+
+class AsyncHFTokenStreamingHandler(TextStreamer):
+    """
+    Async streaming handler for HuggingFaceLocalGenerator and HuggingFaceLocalChatGenerator.
+
+    Note: This is a helper class for HuggingFaceLocalGenerator & HuggingFaceLocalChatGenerator enabling
+    async streaming of generated text via Haystack Callable[StreamingChunk, Awaitable[None]] callbacks.
+
+    Do not use this class directly.
+    """
+
+    def __init__(
+        self,
+        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+        stream_handler: Callable[[StreamingChunk], Awaitable[None]],
+        stop_words: Optional[List[str]] = None,
+        component_info: Optional[ComponentInfo] = None,
+    ):
+        super().__init__(tokenizer=tokenizer, skip_prompt=True)  # type: ignore
+        self.token_handler = stream_handler
+        self.stop_words = stop_words or []
+        self.component_info = component_info
+        self._queue: asyncio.Queue[StreamingChunk] = asyncio.Queue()
+
+    def on_finalized_text(self, word: str, stream_end: bool = False) -> None:
+        """Synchronous callback that puts chunks in a queue."""
+        word_to_send = word + "\n" if stream_end else word
+        if word_to_send.strip() not in self.stop_words:
+            self._queue.put_nowait(StreamingChunk(content=word_to_send, component_info=self.component_info))
+
+    async def process_queue(self) -> None:
+        """Process the queue of streaming chunks."""
+        while True:
+            try:
+                chunk = await self._queue.get()
+                await self.token_handler(chunk)
+                self._queue.task_done()
+            except asyncio.CancelledError:
+                break
```
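
`AsyncHFTokenStreamingHandler` bridges two worlds: `transformers` calls `on_finalized_text` synchronously from the generation thread, which only enqueues a `StreamingChunk`, while `process_queue` runs on the event loop and awaits the async callback for each chunk. A hypothetical micro-demo of that handoff, driving the hooks by hand as the unit test's mocked pipeline does (it touches the private `_queue` purely for illustration, assumes the `gpt2` tokenizer can be loaded, and sidesteps the class's "do not use directly" note for demonstration):

```python
import asyncio

from transformers import AutoTokenizer

from haystack.dataclasses import StreamingChunk
from haystack.utils.hf import AsyncHFTokenStreamingHandler


async def main() -> None:
    chunks: list[StreamingChunk] = []

    async def on_chunk(chunk: StreamingChunk) -> None:
        chunks.append(chunk)

    handler = AsyncHFTokenStreamingHandler(AutoTokenizer.from_pretrained("gpt2"), on_chunk)
    drainer = asyncio.create_task(handler.process_queue())

    # In real use, transformers calls on_finalized_text from the generation
    # thread; here we drive it by hand, like the unit test's mocked pipeline.
    handler.on_finalized_text("Berlin", stream_end=False)
    handler.on_finalized_text(" is cool", stream_end=True)

    await handler._queue.join()  # wait until both chunks reached the callback
    drainer.cancel()
    await drainer  # process_queue catches the CancelledError and returns

    assert [c.content for c in chunks] == ["Berlin", " is cool\n"]


asyncio.run(main())
```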
Release note (new file)

Lines changed: 4 additions & 0 deletions
```diff
@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    Add `AsyncHFTokenStreamingHandler` for async streaming support in `HuggingFaceLocalChatGenerator`
```

test/components/generators/chat/test_hugging_face_local.py

Lines changed: 85 additions & 1 deletion
```diff
@@ -7,12 +7,13 @@
 from typing import Optional, List
 from unittest.mock import Mock, patch
 
+from haystack.utils.hf import AsyncHFTokenStreamingHandler
 import pytest
 from transformers import PreTrainedTokenizer
 
 from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
 from haystack.dataclasses import ChatMessage, ChatRole, ToolCall
-from haystack.dataclasses.streaming_chunk import StreamingChunk
+from haystack.dataclasses.streaming_chunk import StreamingChunk, AsyncStreamingCallbackT
 from haystack.tools import Tool
 from haystack.utils import ComponentDevice
 from haystack.utils.auth import Secret
@@ -486,6 +487,11 @@ def test_default_tool_parser(self, model_info_mock, tools):
 
     # Async tests
 
+
+class TestHuggingFaceLocalChatGeneratorAsync:
+    """Async tests for HuggingFaceLocalChatGenerator"""
+
+    @pytest.mark.asyncio
     async def test_run_async(self, model_info_mock, mock_pipeline_with_tokenizer, chat_messages):
         """Test basic async functionality"""
         generator = HuggingFaceLocalChatGenerator(model="mocked-model")
@@ -499,6 +505,7 @@ async def test_run_async(self, model_info_mock, mock_pipeline_with_tokenizer, chat_messages):
         assert chat_message.is_from(ChatRole.ASSISTANT)
         assert chat_message.text == "Berlin is cool"
 
+    @pytest.mark.asyncio
     async def test_run_async_with_tools(self, model_info_mock, mock_pipeline_with_tokenizer, tools):
         """Test async functionality with tools"""
         generator = HuggingFaceLocalChatGenerator(model="mocked-model", tools=tools)
@@ -517,6 +524,7 @@ async def test_run_async_with_tools(self, model_info_mock, mock_pipeline_with_tokenizer, tools):
         assert tool_call.tool_name == "weather"
         assert tool_call.arguments == {"city": "Berlin"}
 
+    @pytest.mark.asyncio
     async def test_concurrent_async_requests(self, model_info_mock, mock_pipeline_with_tokenizer, chat_messages):
         """Test handling of multiple concurrent async requests"""
         generator = HuggingFaceLocalChatGenerator(model="mocked-model")
@@ -531,6 +539,7 @@ async def test_concurrent_async_requests(self, model_info_mock, mock_pipeline_with_tokenizer, chat_messages):
         assert isinstance(result["replies"][0], ChatMessage)
         assert result["replies"][0].text == "Berlin is cool"
 
+    @pytest.mark.asyncio
     async def test_async_error_handling(self, model_info_mock, mock_pipeline_with_tokenizer):
         """Test error handling in async context"""
         generator = HuggingFaceLocalChatGenerator(model="mocked-model")
@@ -609,3 +618,78 @@ def test_to_dict_with_toolset(self, model_info_mock, mock_pipeline_with_tokenizer):
             },
         }
         assert data["init_parameters"]["tools"] == expected_tools_data
+
+    @pytest.mark.asyncio
+    async def test_run_async_with_streaming_callback(self, model_info_mock, mock_pipeline_with_tokenizer):
+        streaming_chunks = []
+
+        async def streaming_callback(chunk: StreamingChunk) -> None:
+            streaming_chunks.append(chunk)
+
+        # Create a mock that simulates streaming behavior
+        def mock_pipeline_call(*args, **kwargs):
+            streamer = kwargs.get("streamer")
+            if streamer:
+                # Simulate streaming chunks
+                streamer.on_finalized_text("Berlin", stream_end=False)
+                streamer.on_finalized_text(" is cool", stream_end=True)
+            return [{"generated_text": "Berlin is cool"}]
+
+        # Setup the mock pipeline with streaming simulation
+        mock_pipeline_with_tokenizer.side_effect = mock_pipeline_call
+
+        generator = HuggingFaceLocalChatGenerator(model="test-model", streaming_callback=streaming_callback)
+        generator.pipeline = mock_pipeline_with_tokenizer
+
+        messages = [ChatMessage.from_user("Test message")]
+        response = await generator.run_async(messages)
+
+        # Verify streaming chunks were collected
+        assert len(streaming_chunks) == 2
+        assert streaming_chunks[0].content == "Berlin"
+        assert streaming_chunks[1].content == " is cool\n"
+
+        # Verify the final response
+        assert isinstance(response, dict)
+        assert "replies" in response
+        assert len(response["replies"]) == 1
+        assert isinstance(response["replies"][0], ChatMessage)
+        assert response["replies"][0].text == "Berlin is cool"
+
+    @pytest.mark.integration
+    @pytest.mark.slow
+    @pytest.mark.flaky(reruns=3, reruns_delay=10)
+    @pytest.mark.asyncio
+    async def test_live_run_async_with_streaming(self, monkeypatch):
+        """Test async streaming with a live model."""
+        monkeypatch.delenv("HF_API_TOKEN", raising=False)
+
+        streaming_chunks = []
+
+        async def streaming_callback(chunk: StreamingChunk) -> None:
+            streaming_chunks.append(chunk)
+
+        llm = HuggingFaceLocalChatGenerator(
+            model="Qwen/Qwen2.5-0.5B-Instruct",
+            generation_kwargs={"max_new_tokens": 50},
+            streaming_callback=streaming_callback,
+        )
+        llm.warm_up()
+
+        response = await llm.run_async(
+            messages=[ChatMessage.from_user("Please create a summary about the following topic: Capital of France")]
+        )
+
+        # Verify that the response is not None
+        assert len(streaming_chunks) > 0
+        assert "replies" in response
+        assert isinstance(response["replies"][0], ChatMessage)
+        assert response["replies"][0].text is not None
+
+        # Verify that the response contains the word "Paris"
+        assert "Paris" in response["replies"][0].text
+
+        # Verify streaming chunks contain actual content
+        total_streamed_content = "".join(chunk.content for chunk in streaming_chunks)
+        assert len(total_streamed_content.strip()) > 0
+        assert "Paris" in total_streamed_content
```
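
Note: these async tests presumably rely on pytest-asyncio (for `@pytest.mark.asyncio`) and pytest-rerunfailures (for `@pytest.mark.flaky`); the live streaming test also downloads `Qwen/Qwen2.5-0.5B-Instruct` from the Hugging Face Hub on first run.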
