Skip to content

Commit 580683b

Browse files
authored
chore: improve select_streaming_callback type hints (#9513)
1 parent: a28b285 · commit: 580683b

8 files changed

Lines changed: 75 additions & 23 deletions

File tree

haystack/components/generators/chat/hugging_face_api.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,16 @@
88

99
from haystack import component, default_from_dict, default_to_dict, logging
1010
from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message
11-
from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingChunk, ToolCall, select_streaming_callback
12-
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
11+
from haystack.dataclasses import (
12+
AsyncStreamingCallbackT,
13+
ChatMessage,
14+
ComponentInfo,
15+
StreamingCallbackT,
16+
StreamingChunk,
17+
SyncStreamingCallbackT,
18+
ToolCall,
19+
select_streaming_callback,
20+
)
1321
from haystack.lazy_imports import LazyImport
1422
from haystack.tools import (
1523
Tool,
@@ -437,7 +445,10 @@ async def run_async(
437445
return await self._run_non_streaming_async(formatted_messages, generation_kwargs, hf_tools)
438446

439447
def _run_streaming(
440-
self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any], streaming_callback: StreamingCallbackT
448+
self,
449+
messages: List[Dict[str, str]],
450+
generation_kwargs: Dict[str, Any],
451+
streaming_callback: SyncStreamingCallbackT,
441452
):
442453
api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
443454
messages,
@@ -501,7 +512,10 @@ def _run_non_streaming(
501512
return {"replies": [message]}
502513

503514
async def _run_streaming_async(
504-
self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any], streaming_callback: StreamingCallbackT
515+
self,
516+
messages: List[Dict[str, str]],
517+
generation_kwargs: Dict[str, Any],
518+
streaming_callback: AsyncStreamingCallbackT,
505519
):
506520
api_output: AsyncIterable[ChatCompletionStreamOutput] = await self._async_client.chat_completion(
507521
messages,

haystack/components/generators/chat/hugging_face_local.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast
1212

1313
from haystack import component, default_from_dict, default_to_dict, logging
14-
from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, ToolCall
14+
from haystack.dataclasses import AsyncStreamingCallbackT, ChatMessage, ComponentInfo, StreamingCallbackT, ToolCall
1515
from haystack.dataclasses.streaming_chunk import select_streaming_callback
1616
from haystack.lazy_imports import LazyImport
1717
from haystack.tools import (
@@ -566,7 +566,7 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments
566566
tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
567567
generation_kwargs: Dict[str, Any],
568568
stop_words: Optional[List[str]],
569-
streaming_callback: StreamingCallbackT,
569+
streaming_callback: AsyncStreamingCallbackT,
570570
):
571571
"""
572572
Handles async streaming generation of responses.
@@ -588,7 +588,9 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments
588588
# get the component name and type
589589
component_info = ComponentInfo.from_component(self)
590590

591-
async_handler = AsyncHFTokenStreamingHandler(tokenizer, streaming_callback, stop_words, component_info) # type: ignore
591+
async_handler = AsyncHFTokenStreamingHandler(
592+
tokenizer=tokenizer, stream_handler=streaming_callback, stop_words=stop_words, component_info=component_info
593+
)
592594
generation_kwargs["streamer"] = async_handler
593595

594596
# Start queue processing in the background

haystack/components/generators/chat/openai.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,8 +280,10 @@ def run(
280280

281281
if streaming_callback is not None:
282282
completions = self._handle_stream_response(
283+
# we cannot check isinstance(chat_completion, Stream) because some observability tools wrap Stream
284+
# and return a different type. See https://github.com/deepset-ai/haystack/issues/9014.
283285
chat_completion, # type: ignore
284-
streaming_callback, # type: ignore
286+
streaming_callback,
285287
)
286288

287289
else:
@@ -356,8 +358,10 @@ async def run_async(
356358

357359
if streaming_callback is not None:
358360
completions = await self._handle_async_stream_response(
361+
# we cannot check isinstance(chat_completion, AsyncStream) because some observability tools wrap
362+
# AsyncStream and return a different type. See https://github.com/deepset-ai/haystack/issues/9014.
359363
chat_completion, # type: ignore
360-
streaming_callback, # type: ignore
364+
streaming_callback,
361365
)
362366

363367
else:

haystack/components/generators/hugging_face_api.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,13 @@
77
from typing import Any, Dict, Iterable, List, Optional, Union, cast
88

99
from haystack import component, default_from_dict, default_to_dict
10-
from haystack.dataclasses import ComponentInfo, StreamingCallbackT, StreamingChunk, select_streaming_callback
10+
from haystack.dataclasses import (
11+
ComponentInfo,
12+
StreamingCallbackT,
13+
StreamingChunk,
14+
SyncStreamingCallbackT,
15+
select_streaming_callback,
16+
)
1117
from haystack.lazy_imports import LazyImport
1218
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
1319
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
@@ -214,13 +220,13 @@ def run(
214220
)
215221

216222
if streaming_callback is not None:
217-
return self._stream_and_build_response(hf_output, streaming_callback)
223+
return self._stream_and_build_response(hf_output=hf_output, streaming_callback=streaming_callback)
218224

219225
# mypy doesn't know that hf_output is a TextGenerationOutput, so we cast it
220226
return self._build_non_streaming_response(cast(TextGenerationOutput, hf_output))
221227

222228
def _stream_and_build_response(
223-
self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: StreamingCallbackT
229+
self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: SyncStreamingCallbackT
224230
):
225231
chunks: List[StreamingChunk] = []
226232
first_chunk_time = None

haystack/components/generators/hugging_face_local.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,11 @@ def run(
233233
"The component HuggingFaceLocalGenerator was not warmed up. Please call warm_up() before running."
234234
)
235235

236+
# at this point, we know that the pipeline has been initialized
237+
assert self.pipeline is not None
238+
# text-generation and text2text-generation pipelines always have a non-None tokenizer
239+
assert self.pipeline.tokenizer is not None
240+
236241
if not prompt:
237242
return {"replies": []}
238243

@@ -254,15 +259,16 @@ def run(
254259
)
255260
logger.warning(msg, num_responses=num_responses)
256261
updated_generation_kwargs["num_return_sequences"] = 1
262+
257263
# streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
258264
updated_generation_kwargs["streamer"] = HFTokenStreamingHandler(
259-
tokenizer=self.pipeline.tokenizer, # type: ignore
265+
tokenizer=self.pipeline.tokenizer,
260266
stream_handler=streaming_callback,
261-
stop_words=self.stop_words, # type: ignore
267+
stop_words=self.stop_words,
262268
component_info=ComponentInfo.from_component(self),
263269
)
264270

265-
output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **updated_generation_kwargs) # type: ignore
271+
output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **updated_generation_kwargs)
266272
replies = [o["generated_text"] for o in output if "generated_text" in o]
267273

268274
if self.stop_words:

haystack/components/tools/tool_invoker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -681,11 +681,11 @@ async def run_async(
681681
start=True,
682682
meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call},
683683
)
684-
) # type: ignore[misc] # we have checked that streaming_callback is not None and async
684+
)
685685

686686
# We stream one more chunk that contains a finish_reason if tool_messages were generated
687687
if len(tool_messages) > 0 and streaming_callback is not None:
688-
await streaming_callback(StreamingChunk(content="", meta={"finish_reason": "tool_call_results"})) # type: ignore[misc] # we have checked that streaming_callback is not None and async
688+
await streaming_callback(StreamingChunk(content="", meta={"finish_reason": "tool_call_results"}))
689689

690690
return {"tool_messages": tool_messages, "state": state}
691691

haystack/dataclasses/streaming_chunk.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from dataclasses import dataclass, field
6-
from typing import Any, Awaitable, Callable, Dict, Optional, Union
6+
from typing import Any, Awaitable, Callable, Dict, Literal, Optional, Union, overload
77

88
from haystack.core.component import Component
99
from haystack.dataclasses.chat_message import ToolCallResult
@@ -104,6 +104,20 @@ def __post_init__(self):
104104
StreamingCallbackT = Union[SyncStreamingCallbackT, AsyncStreamingCallbackT]
105105

106106

107+
@overload
108+
def select_streaming_callback(
109+
init_callback: Optional[StreamingCallbackT],
110+
runtime_callback: Optional[StreamingCallbackT],
111+
requires_async: Literal[False],
112+
) -> Optional[SyncStreamingCallbackT]: ...
113+
@overload
114+
def select_streaming_callback(
115+
init_callback: Optional[StreamingCallbackT],
116+
runtime_callback: Optional[StreamingCallbackT],
117+
requires_async: Literal[True],
118+
) -> Optional[AsyncStreamingCallbackT]: ...
119+
120+
107121
def select_streaming_callback(
108122
init_callback: Optional[StreamingCallbackT], runtime_callback: Optional[StreamingCallbackT], requires_async: bool
109123
) -> Optional[StreamingCallbackT]:

haystack/utils/hf.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,16 @@
55
import asyncio
66
import copy
77
from enum import Enum
8-
from typing import Any, Awaitable, Callable, Dict, List, Optional, Union
8+
from typing import Any, Dict, List, Optional, Union
99

1010
from haystack import logging
11-
from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, StreamingChunk
11+
from haystack.dataclasses import (
12+
AsyncStreamingCallbackT,
13+
ChatMessage,
14+
ComponentInfo,
15+
StreamingChunk,
16+
SyncStreamingCallbackT,
17+
)
1218
from haystack.lazy_imports import LazyImport
1319
from haystack.utils.auth import Secret
1420
from haystack.utils.device import ComponentDevice
@@ -350,15 +356,15 @@ class HFTokenStreamingHandler(TextStreamer):
350356
Streaming handler for HuggingFaceLocalGenerator and HuggingFaceLocalChatGenerator.
351357
352358
Note: This is a helper class for HuggingFaceLocalGenerator & HuggingFaceLocalChatGenerator enabling streaming
353-
of generated text via Haystack StreamingCallbackT callbacks.
359+
of generated text via Haystack SyncStreamingCallbackT callbacks.
354360
355361
Do not use this class directly.
356362
"""
357363

358364
def __init__(
359365
self,
360366
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
361-
stream_handler: StreamingCallbackT,
367+
stream_handler: SyncStreamingCallbackT,
362368
stop_words: Optional[List[str]] = None,
363369
component_info: Optional[ComponentInfo] = None,
364370
):
@@ -392,7 +398,7 @@ class AsyncHFTokenStreamingHandler(TextStreamer):
392398
def __init__(
393399
self,
394400
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
395-
stream_handler: Callable[[StreamingChunk], Awaitable[None]],
401+
stream_handler: AsyncStreamingCallbackT,
396402
stop_words: Optional[List[str]] = None,
397403
component_info: Optional[ComponentInfo] = None,
398404
):

0 commit comments

Comments (0)