|
1 | 1 | import json |
2 | 2 | from collections.abc import AsyncIterator, Iterator |
3 | | -from typing import Any, Literal, Optional, Union, get_args |
| 3 | +from typing import Any, Literal, get_args |
4 | 4 |
|
5 | 5 | from haystack import component, default_from_dict, default_to_dict, logging |
6 | 6 | from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message |
@@ -77,7 +77,7 @@ def _format_tool(tool: Tool) -> dict[str, Any]: |
77 | 77 |
|
78 | 78 | def _format_message( |
79 | 79 | message: ChatMessage, |
80 | | -) -> Union[UserChatMessageV2, AssistantChatMessageV2, SystemChatMessageV2, ToolChatMessageV2]: |
| 80 | +) -> UserChatMessageV2 | AssistantChatMessageV2 | SystemChatMessageV2 | ToolChatMessageV2: |
81 | 81 | """ |
82 | 82 | Formats a Haystack ChatMessage into Cohere's chat format. |
83 | 83 |
|
@@ -147,7 +147,7 @@ def _format_message( |
147 | 147 | raise ValueError(msg) |
148 | 148 |
|
149 | 149 | # Build multimodal content following Cohere's API specification |
150 | | - content_parts: list[Union[CohereTextContent, ImageUrlContent]] = [] |
| 150 | + content_parts: list[CohereTextContent | ImageUrlContent] = [] |
151 | 151 | for part in message._content: |
152 | 152 | if isinstance(part, TextContent) and part.text: |
153 | 153 | text_content = CohereTextContent(text=part.text) |
@@ -234,7 +234,7 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage: |
234 | 234 | def _convert_cohere_chunk_to_streaming_chunk( |
235 | 235 | chunk: StreamedChatResponseV2, |
236 | 236 | model: str, |
237 | | - component_info: Optional[ComponentInfo] = None, |
| 237 | + component_info: ComponentInfo | None = None, |
238 | 238 | global_index: int = 0, |
239 | 239 | ) -> StreamingChunk: |
240 | 240 | """ |
@@ -518,10 +518,10 @@ def __init__( |
518 | 518 | self, |
519 | 519 | api_key: Secret = Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]), |
520 | 520 | model: str = "command-a-03-2025", |
521 | | - streaming_callback: Optional[StreamingCallbackT] = None, |
522 | | - api_base_url: Optional[str] = None, |
523 | | - generation_kwargs: Optional[dict[str, Any]] = None, |
524 | | - tools: Optional[ToolsType] = None, |
| 521 | + streaming_callback: StreamingCallbackT | None = None, |
| 522 | + api_base_url: str | None = None, |
| 523 | + generation_kwargs: dict[str, Any] | None = None, |
| 524 | + tools: ToolsType | None = None, |
525 | 525 | **kwargs: Any, |
526 | 526 | ): |
527 | 527 | """ |
@@ -618,9 +618,9 @@ def from_dict(cls, data: dict[str, Any]) -> "CohereChatGenerator": |
618 | 618 | def run( |
619 | 619 | self, |
620 | 620 | messages: list[ChatMessage], |
621 | | - generation_kwargs: Optional[dict[str, Any]] = None, |
622 | | - tools: Optional[ToolsType] = None, |
623 | | - streaming_callback: Optional[StreamingCallbackT] = None, |
| 621 | + generation_kwargs: dict[str, Any] | None = None, |
| 622 | + tools: ToolsType | None = None, |
| 623 | + streaming_callback: StreamingCallbackT | None = None, |
624 | 624 | ) -> dict[str, list[ChatMessage]]: |
625 | 625 | """ |
626 | 626 | Invoke the chat endpoint based on the provided messages and generation parameters. |
@@ -685,9 +685,9 @@ def run( |
685 | 685 | async def run_async( |
686 | 686 | self, |
687 | 687 | messages: list[ChatMessage], |
688 | | - generation_kwargs: Optional[dict[str, Any]] = None, |
689 | | - tools: Optional[ToolsType] = None, |
690 | | - streaming_callback: Optional[StreamingCallbackT] = None, |
| 688 | + generation_kwargs: dict[str, Any] | None = None, |
| 689 | + tools: ToolsType | None = None, |
| 690 | + streaming_callback: StreamingCallbackT | None = None, |
691 | 691 | ) -> dict[str, list[ChatMessage]]: |
692 | 692 | """ |
693 | 693 | Asynchronously invoke the chat endpoint based on the provided messages and generation parameters. |
|
0 commit comments