|
7 | 7 |
|
8 | 8 | import inspect |
9 | 9 | from dataclasses import dataclass |
10 | | -from typing import Any, Optional, Union |
| 10 | +from typing import Any |
11 | 11 |
|
12 | 12 | # Monkey patch Haystack's AgentSnapshot with our extended version |
13 | 13 | import haystack.dataclasses.breakpoints as hdb |
@@ -77,8 +77,8 @@ class _ExecutionContext(Haystack_ExecutionContext): |
77 | 77 | parameter in their `run()` and `run_async()` methods. |
78 | 78 | """ |
79 | 79 |
|
80 | | - tool_execution_decisions: Optional[list[ToolExecutionDecision]] = None |
81 | | - confirmation_strategy_context: Optional[dict[str, Any]] = None |
| 80 | + tool_execution_decisions: list[ToolExecutionDecision] | None = None |
| 81 | + confirmation_strategy_context: dict[str, Any] | None = None |
82 | 82 |
|
83 | 83 |
|
84 | 84 | class Agent(HaystackAgent): |
@@ -136,16 +136,16 @@ def __init__( |
136 | 136 | self, |
137 | 137 | *, |
138 | 138 | chat_generator: ChatGenerator, |
139 | | - tools: Optional[ToolsType] = None, |
140 | | - system_prompt: Optional[str] = None, |
141 | | - exit_conditions: Optional[list[str]] = None, |
142 | | - state_schema: Optional[dict[str, Any]] = None, |
| 139 | + tools: ToolsType | None = None, |
| 140 | + system_prompt: str | None = None, |
| 141 | + exit_conditions: list[str] | None = None, |
| 142 | + state_schema: dict[str, Any] | None = None, |
143 | 143 | max_agent_steps: int = 100, |
144 | | - streaming_callback: Optional[StreamingCallbackT] = None, |
| 144 | + streaming_callback: StreamingCallbackT | None = None, |
145 | 145 | raise_on_tool_invocation_failure: bool = False, |
146 | | - confirmation_strategies: Optional[dict[str, ConfirmationStrategy]] = None, |
147 | | - tool_invoker_kwargs: Optional[dict[str, Any]] = None, |
148 | | - chat_message_store: Optional[ChatMessageStore] = None, |
| 146 | + confirmation_strategies: dict[str, ConfirmationStrategy] | None = None, |
| 147 | + tool_invoker_kwargs: dict[str, Any] | None = None, |
| 148 | + chat_message_store: ChatMessageStore | None = None, |
149 | 149 | ) -> None: |
150 | 150 | """ |
151 | 151 | Initialize the agent component. |
@@ -190,14 +190,14 @@ def __init__( |
190 | 190 | def _initialize_fresh_execution( |
191 | 191 | self, |
192 | 192 | messages: list[ChatMessage], |
193 | | - streaming_callback: Optional[StreamingCallbackT], |
| 193 | + streaming_callback: StreamingCallbackT | None, |
194 | 194 | requires_async: bool, |
195 | 195 | *, |
196 | | - system_prompt: Optional[str] = None, |
197 | | - generation_kwargs: Optional[dict[str, Any]] = None, |
198 | | - tools: Optional[Union[ToolsType, list[str]]] = None, |
199 | | - confirmation_strategy_context: Optional[dict[str, Any]] = None, |
200 | | - chat_message_store_kwargs: Optional[dict[str, Any]] = None, |
| 196 | + system_prompt: str | None = None, |
| 197 | + generation_kwargs: dict[str, Any] | None = None, |
| 198 | + tools: ToolsType | list[str] | None = None, |
| 199 | + confirmation_strategy_context: dict[str, Any] | None = None, |
| 200 | + chat_message_store_kwargs: dict[str, Any] | None = None, |
201 | 201 | **kwargs: dict[str, Any], |
202 | 202 | ) -> _ExecutionContext: |
203 | 203 | """ |
@@ -264,12 +264,12 @@ def _initialize_fresh_execution( |
264 | 264 | def _initialize_from_snapshot( # type: ignore[override] |
265 | 265 | self, |
266 | 266 | snapshot: AgentSnapshot, |
267 | | - streaming_callback: Optional[StreamingCallbackT], |
| 267 | + streaming_callback: StreamingCallbackT | None, |
268 | 268 | requires_async: bool, |
269 | 269 | *, |
270 | | - generation_kwargs: Optional[dict[str, Any]] = None, |
271 | | - tools: Optional[Union[ToolsType, list[str]]] = None, |
272 | | - confirmation_strategy_context: Optional[dict[str, Any]] = None, |
| 270 | + generation_kwargs: dict[str, Any] | None = None, |
| 271 | + tools: ToolsType | list[str] | None = None, |
| 272 | + confirmation_strategy_context: dict[str, Any] | None = None, |
273 | 273 | ) -> _ExecutionContext: |
274 | 274 | """ |
275 | 275 | Initialize execution context from an AgentSnapshot. |
@@ -320,15 +320,15 @@ def _initialize_from_snapshot( # type: ignore[override] |
320 | 320 | def run( # type: ignore[override] # noqa: PLR0915 PLR0912 |
321 | 321 | self, |
322 | 322 | messages: list[ChatMessage], |
323 | | - streaming_callback: Optional[StreamingCallbackT] = None, |
| 323 | + streaming_callback: StreamingCallbackT | None = None, |
324 | 324 | *, |
325 | | - generation_kwargs: Optional[dict[str, Any]] = None, |
326 | | - break_point: Optional[AgentBreakpoint] = None, |
327 | | - snapshot: Optional[AgentSnapshot] = None, |
328 | | - system_prompt: Optional[str] = None, |
329 | | - tools: Optional[Union[ToolsType, list[str]]] = None, |
330 | | - confirmation_strategy_context: Optional[dict[str, Any]] = None, |
331 | | - chat_message_store_kwargs: Optional[dict[str, Any]] = None, |
| 325 | + generation_kwargs: dict[str, Any] | None = None, |
| 326 | + break_point: AgentBreakpoint | None = None, |
| 327 | + snapshot: AgentSnapshot | None = None, |
| 328 | + system_prompt: str | None = None, |
| 329 | + tools: ToolsType | list[str] | None = None, |
| 330 | + confirmation_strategy_context: dict[str, Any] | None = None, |
| 331 | + chat_message_store_kwargs: dict[str, Any] | None = None, |
332 | 332 | **kwargs: Any, |
333 | 333 | ) -> dict[str, Any]: |
334 | 334 | """ |
@@ -558,15 +558,15 @@ def run( # type: ignore[override] # noqa: PLR0915 PLR0912 |
558 | 558 | async def run_async( # type: ignore[override] # noqa: PLR0915 |
559 | 559 | self, |
560 | 560 | messages: list[ChatMessage], |
561 | | - streaming_callback: Optional[StreamingCallbackT] = None, |
| 561 | + streaming_callback: StreamingCallbackT | None = None, |
562 | 562 | *, |
563 | | - generation_kwargs: Optional[dict[str, Any]] = None, |
564 | | - break_point: Optional[AgentBreakpoint] = None, |
565 | | - snapshot: Optional[AgentSnapshot] = None, |
566 | | - system_prompt: Optional[str] = None, |
567 | | - tools: Optional[Union[ToolsType, list[str]]] = None, |
568 | | - confirmation_strategy_context: Optional[dict[str, Any]] = None, |
569 | | - chat_message_store_kwargs: Optional[dict[str, Any]] = None, |
| 563 | + generation_kwargs: dict[str, Any] | None = None, |
| 564 | + break_point: AgentBreakpoint | None = None, |
| 565 | + snapshot: AgentSnapshot | None = None, |
| 566 | + system_prompt: str | None = None, |
| 567 | + tools: ToolsType | list[str] | None = None, |
| 568 | + confirmation_strategy_context: dict[str, Any] | None = None, |
| 569 | + chat_message_store_kwargs: dict[str, Any] | None = None, |
570 | 570 | **kwargs: Any, |
571 | 571 | ) -> dict[str, Any]: |
572 | 572 | """ |
|
0 commit comments