Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions haystack/components/generators/chat/hugging_face_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
from haystack.dataclasses import ChatMessage, ToolCall, select_streaming_callback
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
from haystack.lazy_imports import LazyImport
from haystack.tools import (
Tool,
Expand All @@ -36,6 +37,7 @@
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

from haystack.utils.hf import ( # pylint: disable=ungrouped-imports
AsyncHFTokenStreamingHandler,
HFTokenStreamingHandler,
StopWordsCriteria,
convert_message_to_hf_format,
Expand Down Expand Up @@ -130,7 +132,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
generation_kwargs: Optional[Dict[str, Any]] = None,
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
stop_words: Optional[List[str]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
streaming_callback: Optional[StreamingCallbackT] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
async_executor: Optional[ThreadPoolExecutor] = None,
Expand Down Expand Up @@ -330,7 +332,7 @@ def run(
self,
messages: List[ChatMessage],
generation_kwargs: Optional[Dict[str, Any]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
streaming_callback: Optional[StreamingCallbackT] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
):
"""
Expand Down Expand Up @@ -492,7 +494,7 @@ async def run_async(
self,
messages: List[ChatMessage],
generation_kwargs: Optional[Dict[str, Any]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
streaming_callback: Optional[StreamingCallbackT] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
):
"""
Expand Down Expand Up @@ -546,7 +548,7 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments
tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
generation_kwargs: Dict[str, Any],
stop_words: Optional[List[str]],
streaming_callback: Callable[[StreamingChunk], None],
streaming_callback: StreamingCallbackT,
):
"""
Handles async streaming generation of responses.
Expand All @@ -566,7 +568,8 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments
)

# Set up streaming handler
generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words)
assert asyncio.iscoroutinefunction(streaming_callback), "Streaming callback must be asynchronous"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use select_streaming_callback utility here? (we used it in other generators)

You can get it from:

from haystack.dataclasses.streaming_chunk import select_streaming_callback

so we can avoid assert usage!

generation_kwargs["streamer"] = AsyncHFTokenStreamingHandler(tokenizer, streaming_callback, stop_words)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make mypy happy could we add an assert here asserting that streaming_callback is of type AsyncStreamingCallbackT?

or update AsyncHFTokenStreamingHandler such that the type hint for stream_handler is StreamingCallbackT


# Generate responses asynchronously
output = await asyncio.get_running_loop().run_in_executor(
Expand Down
29 changes: 28 additions & 1 deletion haystack/utils/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
#
# SPDX-License-Identifier: Apache-2.0

import asyncio
import copy
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Awaitable, Callable, Dict, List, Optional, Union

from haystack import logging
from haystack.dataclasses import ChatMessage, StreamingChunk
Expand Down Expand Up @@ -369,3 +370,29 @@ def on_finalized_text(self, word: str, stream_end: bool = False):
word_to_send = word + "\n" if stream_end else word
if word_to_send.strip() not in self.stop_words:
self.token_handler(StreamingChunk(content=word_to_send))

class AsyncHFTokenStreamingHandler(TextStreamer):
"""
Async streaming handler for HuggingFaceLocalGenerator and HuggingFaceLocalChatGenerator.

Note: This is a helper class for HuggingFaceLocalGenerator & HuggingFaceLocalChatGenerator enabling
async streaming of generated text via Haystack Callable[StreamingChunk, Awaitable[None]] callbacks.

Do not use this class directly.
"""

def __init__(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
stream_handler: Callable[[StreamingChunk], Awaitable[None]],
stop_words: Optional[List[str]] = None,
):
super().__init__(tokenizer=tokenizer, skip_prompt=True) # type: ignore
self.token_handler = stream_handler
self.stop_words = stop_words or []

def on_finalized_text(self, word: str, stream_end: bool = False):
Comment thread
sjrl marked this conversation as resolved.
Outdated
"""Synchronous callback that schedules the async handler."""
word_to_send = word + "\n" if stream_end else word
if word_to_send.strip() not in self.stop_words:
asyncio.create_task(self.token_handler(StreamingChunk(content=word_to_send))) # type: ignore # token_handler returns Awaitable[None] which is compatible with create_task at runtime
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
features:
- |
Add `AsyncHFTokenStreamingHandler` for async streaming support in `HuggingFaceLocalChatGenerator`
59 changes: 58 additions & 1 deletion test/components/generators/chat/test_hugging_face_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
from haystack.dataclasses import ChatMessage, ChatRole, ToolCall
from haystack.dataclasses.streaming_chunk import StreamingChunk
from haystack.dataclasses.streaming_chunk import StreamingChunk, AsyncStreamingCallbackT
from haystack.tools import Tool
from haystack.utils import ComponentDevice
from haystack.utils.auth import Secret
Expand Down Expand Up @@ -485,6 +485,11 @@ def test_default_tool_parser(self, model_info_mock, tools):

# Async tests


class TestHuggingFaceLocalChatGeneratorAsync:
"""Async tests for HuggingFaceLocalChatGenerator"""

@pytest.mark.asyncio
async def test_run_async(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
"""Test basic async functionality"""
generator = HuggingFaceLocalChatGenerator(model="mocked-model")
Expand All @@ -498,6 +503,7 @@ async def test_run_async(self, model_info_mock, mock_pipeline_tokenizer, chat_me
assert chat_message.is_from(ChatRole.ASSISTANT)
assert chat_message.text == "Berlin is cool"

@pytest.mark.asyncio
async def test_run_async_with_tools(self, model_info_mock, mock_pipeline_tokenizer, tools):
"""Test async functionality with tools"""
generator = HuggingFaceLocalChatGenerator(model="mocked-model", tools=tools)
Expand All @@ -516,6 +522,7 @@ async def test_run_async_with_tools(self, model_info_mock, mock_pipeline_tokeniz
assert tool_call.tool_name == "weather"
assert tool_call.arguments == {"city": "Berlin"}

@pytest.mark.asyncio
async def test_concurrent_async_requests(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
"""Test handling of multiple concurrent async requests"""
generator = HuggingFaceLocalChatGenerator(model="mocked-model")
Expand All @@ -530,6 +537,7 @@ async def test_concurrent_async_requests(self, model_info_mock, mock_pipeline_to
assert isinstance(result["replies"][0], ChatMessage)
assert result["replies"][0].text == "Berlin is cool"

@pytest.mark.asyncio
async def test_async_error_handling(self, model_info_mock, mock_pipeline_tokenizer):
"""Test error handling in async context"""
generator = HuggingFaceLocalChatGenerator(model="mocked-model")
Expand Down Expand Up @@ -608,3 +616,52 @@ def test_to_dict_with_toolset(self, model_info_mock, mock_pipeline_tokenizer, to
},
}
assert data["init_parameters"]["tools"] == expected_tools_data

@pytest.mark.asyncio
async def test_run_async_with_streaming_callback(self, model_info_mock):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we also add an integration test for this? So an async version of test_live_run with streaming?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also simplify this, removing asyncio.Event usage and reuse an already available mock.
Something like:

@pytest.mark.asyncio
async def test_run_async_with_streaming_callback(self, model_info_mock, mock_pipeline_with_tokenizer):
    streaming_chunks = []

    async def streaming_callback(chunk: StreamingChunk) -> None:
        streaming_chunks.append(chunk)

    # Create a mock that simulates streaming behavior
    def mock_pipeline_call(*args, **kwargs):
        streamer = kwargs.get("streamer")
        if streamer:
            # Simulate streaming chunks
            streamer.on_finalized_text("Berlin", stream_end=False)
            streamer.on_finalized_text(" is cool", stream_end=True)
        return [{"generated_text": "Berlin is cool"}]

    # Setup the mock pipeline with streaming simulation
    mock_pipeline_with_tokenizer.side_effect = mock_pipeline_call

    generator = HuggingFaceLocalChatGenerator(model="test-model", streaming_callback=streaming_callback)
    generator.pipeline = mock_pipeline_with_tokenizer

    messages = [ChatMessage.from_user("Test message")]
    response = await generator.run_async(messages)

    # Verify streaming chunks were collected
    assert len(streaming_chunks) == 2
    assert streaming_chunks[0].content == "Berlin"
    assert streaming_chunks[1].content == " is cool\n"

    # Verify the final response
    assert isinstance(response, dict)
    assert "replies" in response
    assert len(response["replies"]) == 1
    assert isinstance(response["replies"][0], ChatMessage)
    assert response["replies"][0].text == "Berlin is cool"

WDYT?

"""Test that async streaming works correctly with HuggingFaceLocalChatGenerator."""
streaming_chunks = []
streaming_complete = asyncio.Event()
loop = asyncio.get_running_loop()

async def streaming_callback(chunk: StreamingChunk) -> None:
streaming_chunks.append(chunk)
if chunk.content.endswith("\n"):
streaming_complete.set()

# Create a mock pipeline that simulates streaming
mock_pipeline = Mock()
mock_tokenizer = Mock()
mock_tokenizer.apply_chat_template.return_value = "Test prompt"
mock_tokenizer.encode.return_value = [1, 2, 3] # Return a list with a length
mock_pipeline.tokenizer = mock_tokenizer

# Mock the pipeline to return a stream of responses
def mock_generate(*args, **kwargs):
streamer = kwargs.get("streamer")
if streamer:
# Schedule the streaming callbacks in the main event loop
loop.call_soon_threadsafe(lambda: streamer.on_finalized_text("Hello", stream_end=False))
loop.call_soon_threadsafe(lambda: streamer.on_finalized_text(" world", stream_end=True))
return [{"generated_text": "Hello world"}]

mock_pipeline.side_effect = mock_generate

generator = HuggingFaceLocalChatGenerator(model="test-model", streaming_callback=streaming_callback)
generator.pipeline = mock_pipeline

messages = [ChatMessage.from_user("Test message")]
response = await generator.run_async(messages)

# Wait for streaming to complete
await streaming_complete.wait()

assert len(streaming_chunks) == 2
assert streaming_chunks[0].content == "Hello"
assert streaming_chunks[1].content == " world\n"

assert isinstance(response, dict)
assert "replies" in response
assert len(response["replies"]) == 1
assert isinstance(response["replies"][0], ChatMessage)
assert response["replies"][0].text == "Hello world"
Loading