Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions haystack/components/generators/chat/hugging_face_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,16 @@

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message
from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingChunk, ToolCall, select_streaming_callback
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
from haystack.dataclasses import (
AsyncStreamingCallbackT,
ChatMessage,
ComponentInfo,
StreamingCallbackT,
StreamingChunk,
SyncStreamingCallbackT,
ToolCall,
select_streaming_callback,
)
from haystack.lazy_imports import LazyImport
from haystack.tools import (
Tool,
Expand Down Expand Up @@ -437,7 +445,10 @@ async def run_async(
return await self._run_non_streaming_async(formatted_messages, generation_kwargs, hf_tools)

def _run_streaming(
self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any], streaming_callback: StreamingCallbackT
self,
messages: List[Dict[str, str]],
generation_kwargs: Dict[str, Any],
streaming_callback: SyncStreamingCallbackT,
):
api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
messages,
Expand Down Expand Up @@ -501,7 +512,10 @@ def _run_non_streaming(
return {"replies": [message]}

async def _run_streaming_async(
self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any], streaming_callback: StreamingCallbackT
self,
messages: List[Dict[str, str]],
generation_kwargs: Dict[str, Any],
streaming_callback: AsyncStreamingCallbackT,
):
api_output: AsyncIterable[ChatCompletionStreamOutput] = await self._async_client.chat_completion(
messages,
Expand Down
8 changes: 5 additions & 3 deletions haystack/components/generators/chat/hugging_face_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, ToolCall
from haystack.dataclasses import AsyncStreamingCallbackT, ChatMessage, ComponentInfo, StreamingCallbackT, ToolCall
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a heads-up, I refactored this component in #9455, so I'd appreciate it if we could leave out the changes here; I can add them to my PR once this one is merged.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, actually it's fine to leave these in. I'll handle the merge conflict after this PR is merged instead.

from haystack.dataclasses.streaming_chunk import select_streaming_callback
from haystack.lazy_imports import LazyImport
from haystack.tools import (
Expand Down Expand Up @@ -566,7 +566,7 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments
tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
generation_kwargs: Dict[str, Any],
stop_words: Optional[List[str]],
streaming_callback: StreamingCallbackT,
streaming_callback: AsyncStreamingCallbackT,
):
"""
Handles async streaming generation of responses.
Expand All @@ -588,7 +588,9 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments
# get the component name and type
component_info = ComponentInfo.from_component(self)

async_handler = AsyncHFTokenStreamingHandler(tokenizer, streaming_callback, stop_words, component_info) # type: ignore
async_handler = AsyncHFTokenStreamingHandler(
tokenizer=tokenizer, stream_handler=streaming_callback, stop_words=stop_words, component_info=component_info
)
generation_kwargs["streamer"] = async_handler

# Start queue processing in the background
Expand Down
8 changes: 6 additions & 2 deletions haystack/components/generators/chat/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,10 @@ def run(

if streaming_callback is not None:
completions = self._handle_stream_response(
# we cannot check isinstance(chat_completion, Stream) because some observability tools wrap Stream
# and return a different type. See https://github.com/deepset-ai/haystack/issues/9014.
chat_completion, # type: ignore
streaming_callback, # type: ignore
streaming_callback,
)

else:
Expand Down Expand Up @@ -356,8 +358,10 @@ async def run_async(

if streaming_callback is not None:
completions = await self._handle_async_stream_response(
# we cannot check isinstance(chat_completion, AsyncStream) because some observability tools wrap
# AsyncStream and return a different type. See https://github.com/deepset-ai/haystack/issues/9014.
chat_completion, # type: ignore
streaming_callback, # type: ignore
streaming_callback,
)

else:
Expand Down
12 changes: 9 additions & 3 deletions haystack/components/generators/hugging_face_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
from typing import Any, Dict, Iterable, List, Optional, Union, cast

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import ComponentInfo, StreamingCallbackT, StreamingChunk, select_streaming_callback
from haystack.dataclasses import (
ComponentInfo,
StreamingCallbackT,
StreamingChunk,
SyncStreamingCallbackT,
select_streaming_callback,
)
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
Expand Down Expand Up @@ -214,13 +220,13 @@ def run(
)

if streaming_callback is not None:
return self._stream_and_build_response(hf_output, streaming_callback)
return self._stream_and_build_response(hf_output=hf_output, streaming_callback=streaming_callback)

# mypy doesn't know that hf_output is a TextGenerationOutput, so we cast it
return self._build_non_streaming_response(cast(TextGenerationOutput, hf_output))

def _stream_and_build_response(
self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: StreamingCallbackT
self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: SyncStreamingCallbackT
):
chunks: List[StreamingChunk] = []
first_chunk_time = None
Expand Down
12 changes: 9 additions & 3 deletions haystack/components/generators/hugging_face_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,11 @@ def run(
"The component HuggingFaceLocalGenerator was not warmed up. Please call warm_up() before running."
)

# at this point, we know that the pipeline has been initialized
assert self.pipeline is not None
# text-generation and text2text-generation pipelines always have a non-None tokenizer
assert self.pipeline.tokenizer is not None

if not prompt:
return {"replies": []}

Expand All @@ -254,15 +259,16 @@ def run(
)
logger.warning(msg, num_responses=num_responses)
updated_generation_kwargs["num_return_sequences"] = 1

# streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
updated_generation_kwargs["streamer"] = HFTokenStreamingHandler(
tokenizer=self.pipeline.tokenizer, # type: ignore
tokenizer=self.pipeline.tokenizer,
stream_handler=streaming_callback,
stop_words=self.stop_words, # type: ignore
stop_words=self.stop_words,
component_info=ComponentInfo.from_component(self),
)

output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **updated_generation_kwargs) # type: ignore
output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **updated_generation_kwargs)
replies = [o["generated_text"] for o in output if "generated_text" in o]

if self.stop_words:
Expand Down
4 changes: 2 additions & 2 deletions haystack/components/tools/tool_invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,11 +681,11 @@ async def run_async(
start=True,
meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call},
)
) # type: ignore[misc] # we have checked that streaming_callback is not None and async
)

# We stream one more chunk that contains a finish_reason if tool_messages were generated
if len(tool_messages) > 0 and streaming_callback is not None:
await streaming_callback(StreamingChunk(content="", meta={"finish_reason": "tool_call_results"})) # type: ignore[misc] # we have checked that streaming_callback is not None and async
await streaming_callback(StreamingChunk(content="", meta={"finish_reason": "tool_call_results"}))

return {"tool_messages": tool_messages, "state": state}

Expand Down
16 changes: 15 additions & 1 deletion haystack/dataclasses/streaming_chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, Dict, Optional, Union
from typing import Any, Awaitable, Callable, Dict, Literal, Optional, Union, overload

from haystack.core.component import Component
from haystack.dataclasses.chat_message import ToolCallResult
Expand Down Expand Up @@ -104,6 +104,20 @@ def __post_init__(self):
StreamingCallbackT = Union[SyncStreamingCallbackT, AsyncStreamingCallbackT]


# Typing-only overload (body is `...`): when `requires_async` is the literal
# `False`, the declared return type is `Optional[SyncStreamingCallbackT]`.
@overload
def select_streaming_callback(
init_callback: Optional[StreamingCallbackT],
runtime_callback: Optional[StreamingCallbackT],
requires_async: Literal[False],
) -> Optional[SyncStreamingCallbackT]: ...
# Typing-only overload (body is `...`): when `requires_async` is the literal
# `True`, the declared return type is `Optional[AsyncStreamingCallbackT]`.
@overload
def select_streaming_callback(
init_callback: Optional[StreamingCallbackT],
runtime_callback: Optional[StreamingCallbackT],
requires_async: Literal[True],
) -> Optional[AsyncStreamingCallbackT]: ...


def select_streaming_callback(
init_callback: Optional[StreamingCallbackT], runtime_callback: Optional[StreamingCallbackT], requires_async: bool
) -> Optional[StreamingCallbackT]:
Expand Down
16 changes: 11 additions & 5 deletions haystack/utils/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,16 @@
import asyncio
import copy
from enum import Enum
from typing import Any, Awaitable, Callable, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

from haystack import logging
from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, StreamingChunk
from haystack.dataclasses import (
AsyncStreamingCallbackT,
ChatMessage,
ComponentInfo,
StreamingChunk,
SyncStreamingCallbackT,
)
from haystack.lazy_imports import LazyImport
from haystack.utils.auth import Secret
from haystack.utils.device import ComponentDevice
Expand Down Expand Up @@ -350,15 +356,15 @@ class HFTokenStreamingHandler(TextStreamer):
Streaming handler for HuggingFaceLocalGenerator and HuggingFaceLocalChatGenerator.

Note: This is a helper class for HuggingFaceLocalGenerator & HuggingFaceLocalChatGenerator enabling streaming
of generated text via Haystack StreamingCallbackT callbacks.
of generated text via Haystack SyncStreamingCallbackT callbacks.

Do not use this class directly.
"""

def __init__(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
stream_handler: StreamingCallbackT,
stream_handler: SyncStreamingCallbackT,
stop_words: Optional[List[str]] = None,
component_info: Optional[ComponentInfo] = None,
):
Expand Down Expand Up @@ -392,7 +398,7 @@ class AsyncHFTokenStreamingHandler(TextStreamer):
def __init__(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
stream_handler: Callable[[StreamingChunk], Awaitable[None]],
stream_handler: AsyncStreamingCallbackT,
stop_words: Optional[List[str]] = None,
component_info: Optional[ComponentInfo] = None,
):
Expand Down
Loading