Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions haystack/components/generators/chat/hugging_face_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, ToolCall, select_streaming_callback
from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, StreamingChunk, ToolCall
from haystack.dataclasses.streaming_chunk import select_streaming_callback
from haystack.lazy_imports import LazyImport
from haystack.tools import (
Tool,
Expand Down Expand Up @@ -424,8 +425,10 @@ def run(
replies = [o.get("generated_text", "") for o in output]

# Remove stop words from replies if present
for stop_word in stop_words:
replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
if stop_words:
for stop_word in stop_words:
if stop_word in replies[0]:
Comment thread
sjrl marked this conversation as resolved.
Outdated
replies = [reply.replace(stop_word, "").rstrip() for reply in replies]

chat_messages = [
self.create_message(
Expand Down Expand Up @@ -585,10 +588,11 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments

# get the component name and type
component_info = ComponentInfo.from_component(self)
assert asyncio.iscoroutinefunction(streaming_callback), "Streaming callback must be asynchronous"

async_handler = AsyncHFTokenStreamingHandler(tokenizer, streaming_callback, stop_words, component_info)
generation_kwargs["streamer"] = async_handler

# Start queue processing in the background
queue_processor = asyncio.create_task(async_handler.process_queue())

try:
Expand All @@ -601,8 +605,10 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments
replies = [o.get("generated_text", "") for o in output]

# Remove stop words from replies if present
for stop_word in stop_words or []:
replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
if stop_words:
for stop_word in stop_words:
if stop_word in replies[0]:
replies = [reply.replace(stop_word, "").rstrip() for reply in replies]

chat_messages = [
self.create_message(
Expand All @@ -614,10 +620,12 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments
return {"replies": chat_messages}

finally:
# Clean up the queue processor
queue_processor.cancel()
with suppress(asyncio.CancelledError):
await queue_processor
try:
await asyncio.wait_for(queue_processor, timeout=0.1)
except asyncio.TimeoutError:
queue_processor.cancel()
with suppress(asyncio.CancelledError):
await queue_processor

async def _run_non_streaming_async( # pylint: disable=too-many-positional-arguments
self,
Expand Down Expand Up @@ -661,8 +669,10 @@ async def _run_non_streaming_async( # pylint: disable=too-many-positional-argum
replies = [o.get("generated_text", "") for o in output]

# Remove stop words from replies if present
for stop_word in stop_words or []:
replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
if stop_words:
for stop_word in stop_words:
if stop_word in replies[0]:
replies = [reply.replace(stop_word, "").rstrip() for reply in replies]

chat_messages = [
self.create_message(
Expand Down
76 changes: 26 additions & 50 deletions test/components/generators/chat/test_hugging_face_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,53 +620,41 @@ def test_to_dict_with_toolset(self, model_info_mock, mock_pipeline_with_tokenize
assert data["init_parameters"]["tools"] == expected_tools_data

@pytest.mark.asyncio
async def test_run_async_with_streaming_callback(self, model_info_mock):
"""Test that async streaming works correctly with HuggingFaceLocalChatGenerator."""
async def test_run_async_with_streaming_callback(self, model_info_mock, mock_pipeline_with_tokenizer):
streaming_chunks = []
streaming_complete = asyncio.Event()
loop = asyncio.get_running_loop()

async def streaming_callback(chunk: StreamingChunk) -> None:
streaming_chunks.append(chunk)
if chunk.content.endswith("\n"):
streaming_complete.set()

# Create a mock pipeline that simulates streaming
mock_pipeline = Mock()
mock_tokenizer = Mock()
mock_tokenizer.apply_chat_template.return_value = "Test prompt"
mock_tokenizer.encode.return_value = [1, 2, 3] # Return a list with a length
mock_pipeline.tokenizer = mock_tokenizer

# Mock the pipeline to return a stream of responses
def mock_generate(*args, **kwargs):
# Create a mock that simulates streaming behavior
def mock_pipeline_call(*args, **kwargs):
streamer = kwargs.get("streamer")
if streamer:
# Schedule the streaming callbacks in the main event loop
loop.call_soon_threadsafe(lambda: streamer.on_finalized_text("Hello", stream_end=False))
loop.call_soon_threadsafe(lambda: streamer.on_finalized_text(" world", stream_end=True))
return [{"generated_text": "Hello world"}]
# Simulate streaming chunks
streamer.on_finalized_text("Berlin", stream_end=False)
streamer.on_finalized_text(" is cool", stream_end=True)
return [{"generated_text": "Berlin is cool"}]

mock_pipeline.side_effect = mock_generate
# Setup the mock pipeline with streaming simulation
mock_pipeline_with_tokenizer.side_effect = mock_pipeline_call

generator = HuggingFaceLocalChatGenerator(model="test-model", streaming_callback=streaming_callback)
generator.pipeline = mock_pipeline
generator.pipeline = mock_pipeline_with_tokenizer

messages = [ChatMessage.from_user("Test message")]
response = await generator.run_async(messages)

# Wait for streaming to complete
await streaming_complete.wait()

# Verify streaming chunks were collected
assert len(streaming_chunks) == 2
assert streaming_chunks[0].content == "Hello"
assert streaming_chunks[1].content == " world\n"
assert streaming_chunks[0].content == "Berlin"
assert streaming_chunks[1].content == " is cool\n"

# Verify the final response
assert isinstance(response, dict)
assert "replies" in response
assert len(response["replies"]) == 1
assert isinstance(response["replies"][0], ChatMessage)
assert response["replies"][0].text == "Hello world"
assert response["replies"][0].text == "Berlin is cool"

@pytest.mark.integration
@pytest.mark.slow
Expand All @@ -677,43 +665,31 @@ async def test_live_run_async_with_streaming(self, monkeypatch):
monkeypatch.delenv("HF_API_TOKEN", raising=False)

streaming_chunks = []
streaming_complete = asyncio.Event()

async def streaming_callback(chunk: StreamingChunk) -> None:
"""Async callback to collect streaming chunks."""
streaming_chunks.append(chunk)
# Check if this looks like the end of generation
# Most models will send a final chunk or the accumulated text will be substantial
if len(streaming_chunks) > 10 or (chunk.content and len("".join(c.content for c in streaming_chunks)) > 40):
streaming_complete.set()

messages = [ChatMessage.from_user("Please create a summary about the following topic: Capital of France")]

# Initialize the generator with streaming callback
llm = HuggingFaceLocalChatGenerator(
model="Qwen/Qwen2.5-0.5B-Instruct",
generation_kwargs={"max_new_tokens": 50},
streaming_callback=streaming_callback,
)
llm.warm_up()

response = await llm.run_async(messages=messages)

# Wait for either streaming to complete or a timeout
try:
await asyncio.wait_for(streaming_complete.wait(), timeout=30.0)
except asyncio.TimeoutError:
pass # We'll still check the results even if streaming takes longer
response = await llm.run_async(
messages=[ChatMessage.from_user("Please create a summary about the following topic: Capital of France")]
)

assert len(streaming_chunks) > 0, "Should have received at least one streaming chunk"
# Verify that the response is not None
assert len(streaming_chunks) > 0
assert "replies" in response
assert isinstance(response["replies"][0], ChatMessage)
assert "Paris" in response["replies"][0].text, "Response should mention Paris"
assert response["replies"][0].text is not None

# Verify that the response contains the word "Paris"
assert "Paris" in response["replies"][0].text

# Verify streaming chunks contain actual content
total_streamed_content = "".join(chunk.content for chunk in streaming_chunks)
assert len(total_streamed_content.strip()) > 0, "Streaming chunks should contain content"

print(f"Received {len(streaming_chunks)} streaming chunks")
print(f"Total streamed content length: {len(total_streamed_content)}")
print(f"Final response length: {len(response['replies'][0].text)}")
assert len(total_streamed_content.strip()) > 0
assert "Paris" in total_streamed_content