|
4 | 4 | import os |
5 | 5 | import string |
6 | 6 | import random |
| 7 | +from collections.abc import Awaitable, Callable |
7 | 8 | from datetime import date |
| 9 | +from time import monotonic |
| 10 | +from typing import Any, cast |
8 | 11 |
|
9 | | -from langchain_core.messages import ToolMessage, SystemMessage, AIMessage, HumanMessage |
10 | | -from langchain_core.runnables import RunnableConfig, RunnableLambda |
| 12 | +from langchain_core.messages import ToolMessage, SystemMessage, AIMessage, HumanMessage, AIMessageChunk |
| 13 | +from langchain_core.runnables import RunnableConfig |
11 | 14 | from nc_py_api import AsyncNextcloudApp |
12 | 15 | from nc_py_api.ex_app import persistent_storage |
13 | 16 |
|
@@ -94,7 +97,11 @@ def export_conversation(checkpointer): |
94 | 97 | conversation_token = add_signature(serialized_state.decode('utf-8'), key) |
95 | 98 | return conversation_token |
96 | 99 |
|
97 | | -async def react(task, nc: AsyncNextcloudApp): |
| 100 | +async def react( |
| 101 | + task, |
| 102 | + nc: AsyncNextcloudApp, |
| 103 | + stream_output: Callable[[dict[str, Any]], Awaitable[None]] | None = None, |
| 104 | +): |
98 | 105 | safe_tools, dangerous_tools = await get_tools(nc) |
99 | 106 |
|
100 | 107 | model.bind_nextcloud(nc) |
@@ -183,14 +190,64 @@ async def call_model( |
183 | 190 | else: |
184 | 191 | new_input = {"messages": [("user", task['input']['input'])]} |
185 | 192 |
|
186 | | - async for event in graph.astream(new_input, thread, stream_mode="values"): |
| 193 | + snapshot_messages = state_snapshot.values.get('messages', []) |
| 194 | + last_message: AIMessage = AIMessage("") |
| 195 | + if len(snapshot_messages) > 0: |
| 196 | + last_message = cast(AIMessage, snapshot_messages[-1]) |
| 197 | + source_list: list[str] = [] |
| 198 | + known_sources: set[str] = set() |
| 199 | + streamed_output = '' |
| 200 | + last_stream_update = 0.0 |
| 201 | + last_reported_stream_state: dict[str, Any] | None = None |
| 202 | + prefer_streaming = bool(task.get('preferStreaming')) |
| 203 | + stream_mode = ["messages", "values"] if prefer_streaming and stream_output is not None else "values" |
| 204 | + |
| 205 | + async def report_stream_state(force: bool = False): |
| 206 | + nonlocal last_stream_update |
| 207 | + nonlocal last_reported_stream_state |
| 208 | + if stream_output is None: |
| 209 | + return |
| 210 | + stream_state = {'output': streamed_output, 'sources': source_list.copy()} |
| 211 | + if last_reported_stream_state == stream_state: |
| 212 | + return |
| 213 | + now = monotonic() |
| 214 | + if not force and last_reported_stream_state is not None and (now - last_stream_update) < 0.5: |
| 215 | + return |
| 216 | + await stream_output(stream_state) |
| 217 | + last_stream_update = now |
| 218 | + last_reported_stream_state = stream_state |
| 219 | + |
| 220 | + async for event in graph.astream(new_input, thread, stream_mode=stream_mode): |
| 221 | + if isinstance(event, tuple): |
| 222 | + mode, payload = event |
| 223 | + else: |
| 224 | + mode, payload = "values", event |
| 225 | + |
| 226 | + if mode == 'messages': |
| 227 | + message_chunk, metadata = payload |
| 228 | + if metadata.get('langgraph_node') != 'agent' or not isinstance(message_chunk, AIMessageChunk): |
| 229 | + continue |
| 230 | + chunk_content = message_chunk.content |
| 231 | + if isinstance(chunk_content, str) and chunk_content != '': |
| 232 | + streamed_output += chunk_content |
| 233 | + await report_stream_state() |
| 234 | + continue |
| 235 | + |
| 236 | + event = payload |
187 | 237 | last_message = event['messages'][-1] |
188 | 238 | for message in event['messages']: |
189 | 239 | if isinstance(message, HumanMessage): |
190 | 240 | source_list = [] |
| 241 | + known_sources = set() |
191 | 242 | if isinstance(message, AIMessage) and message.tool_calls: |
192 | 243 | for tool_call in message.tool_calls: |
193 | | - source_list.append(tool_call['name']) |
| 244 | + tool_name = tool_call['name'] |
| 245 | + if tool_name not in known_sources: |
| 246 | + known_sources.add(tool_name) |
| 247 | + source_list.append(tool_name) |
| 248 | + await report_stream_state(force=True) |
| 249 | + |
| 250 | + await report_stream_state(force=True) |
194 | 251 |
|
195 | 252 | state_snapshot = graph.get_state(thread) |
196 | 253 | actions = '' |
|
0 commit comments