|
1 | 1 | import json |
2 | 2 | import logging |
3 | 3 | import os |
4 | | -from typing import List, Optional |
| 4 | +from typing import Any, Dict, List, Optional, Tuple, Union |
5 | 5 |
|
6 | 6 | from langchain_core.callbacks.base import BaseCallbackHandler |
7 | 7 | from langchain_core.messages import BaseMessage |
@@ -315,14 +315,14 @@ def _extract_graph_result(self, final_chunk, graph: CompiledStateGraph): |
315 | 315 | # Fallback for any other case |
316 | 316 | return final_chunk |
317 | 317 |
|
318 | | - def _pretty_print(self, stream_chunk: tuple): |
| 318 | + def _pretty_print(self, stream_chunk: Union[Tuple[Any, Any], Dict[str, Any], Any]): |
319 | 319 | """ |
320 | 320 | Pretty print a chunk from a LangGraph stream with stream_mode="updates" and subgraphs=True. |
321 | 321 |
|
322 | 322 | Args: |
323 | 323 | stream_chunk: A tuple of (namespace, updates) from graph.astream() |
324 | 324 | """ |
325 | | - if not stream_chunk or len(stream_chunk) < 2: |
| 325 | + if not isinstance(stream_chunk, tuple) or len(stream_chunk) < 2: |
326 | 326 | return |
327 | 327 |
|
328 | 328 | node_namespace = "" |
@@ -355,7 +355,7 @@ def _pretty_print(self, stream_chunk: tuple): |
355 | 355 | if isinstance(messages, list): |
356 | 356 | for message in messages: |
357 | 357 | if isinstance(message, BaseMessage): |
358 | | - logger.info("%s", message.pretty_print()) |
| 358 | + message.pretty_print() |
359 | 359 |
|
360 | 360 | # Exclude "messages" from node_result and pretty-print the rest |
361 | 361 | metadata = {k: v for k, v in node_result.items() if k != "messages"} |
|
0 commit comments