Skip to content

Commit 19c3efe

Browse files
committed
feat: Stream output text and tool calls
Signed-off-by: Marcel Klehr <mklehr@gmx.net>
1 parent fa4f978 commit 19c3efe

3 files changed

Lines changed: 294 additions & 51 deletions

File tree

ex_app/lib/agent.py

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
import os
55
import string
66
import random
7+
from collections.abc import Awaitable, Callable
78
from datetime import date
9+
from time import monotonic
10+
from typing import Any, cast
811

9-
from langchain_core.messages import ToolMessage, SystemMessage, AIMessage, HumanMessage
10-
from langchain_core.runnables import RunnableConfig, RunnableLambda
12+
from langchain_core.messages import ToolMessage, SystemMessage, AIMessage, HumanMessage, AIMessageChunk
13+
from langchain_core.runnables import RunnableConfig
1114
from nc_py_api import AsyncNextcloudApp
1215
from nc_py_api.ex_app import persistent_storage
1316

@@ -94,7 +97,11 @@ def export_conversation(checkpointer):
9497
conversation_token = add_signature(serialized_state.decode('utf-8'), key)
9598
return conversation_token
9699

97-
async def react(task, nc: AsyncNextcloudApp):
100+
async def react(
101+
task,
102+
nc: AsyncNextcloudApp,
103+
stream_output: Callable[[dict[str, Any]], Awaitable[None]] | None = None,
104+
):
98105
safe_tools, dangerous_tools = await get_tools(nc)
99106

100107
model.bind_nextcloud(nc)
@@ -183,14 +190,64 @@ async def call_model(
183190
else:
184191
new_input = {"messages": [("user", task['input']['input'])]}
185192

186-
async for event in graph.astream(new_input, thread, stream_mode="values"):
193+
snapshot_messages = state_snapshot.values.get('messages', [])
194+
last_message: AIMessage = AIMessage("")
195+
if len(snapshot_messages) > 0:
196+
last_message = cast(AIMessage, snapshot_messages[-1])
197+
source_list: list[str] = []
198+
known_sources: set[str] = set()
199+
streamed_output = ''
200+
last_stream_update = 0.0
201+
last_reported_stream_state: dict[str, Any] | None = None
202+
prefer_streaming = bool(task.get('preferStreaming'))
203+
stream_mode = ["messages", "values"] if prefer_streaming and stream_output is not None else "values"
204+
205+
async def report_stream_state(force: bool = False):
206+
nonlocal last_stream_update
207+
nonlocal last_reported_stream_state
208+
if stream_output is None:
209+
return
210+
stream_state = {'output': streamed_output, 'sources': source_list.copy()}
211+
if last_reported_stream_state == stream_state:
212+
return
213+
now = monotonic()
214+
if not force and last_reported_stream_state is not None and (now - last_stream_update) < 0.5:
215+
return
216+
await stream_output(stream_state)
217+
last_stream_update = now
218+
last_reported_stream_state = stream_state
219+
220+
async for event in graph.astream(new_input, thread, stream_mode=stream_mode):
221+
if isinstance(event, tuple):
222+
mode, payload = event
223+
else:
224+
mode, payload = "values", event
225+
226+
if mode == 'messages':
227+
message_chunk, metadata = payload
228+
if metadata.get('langgraph_node') != 'agent' or not isinstance(message_chunk, AIMessageChunk):
229+
continue
230+
chunk_content = message_chunk.content
231+
if isinstance(chunk_content, str) and chunk_content != '':
232+
streamed_output += chunk_content
233+
await report_stream_state()
234+
continue
235+
236+
event = payload
187237
last_message = event['messages'][-1]
188238
for message in event['messages']:
189239
if isinstance(message, HumanMessage):
190240
source_list = []
241+
known_sources = set()
191242
if isinstance(message, AIMessage) and message.tool_calls:
192243
for tool_call in message.tool_calls:
193-
source_list.append(tool_call['name'])
244+
tool_name = tool_call['name']
245+
if tool_name not in known_sources:
246+
known_sources.add(tool_name)
247+
source_list.append(tool_name)
248+
await report_stream_state(force=True)
249+
250+
await report_stream_state(force=True)
194251

195252
state_snapshot = graph.get_state(thread)
196253
actions = ''

ex_app/lib/main.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,26 @@ async def handle_task(task, nc: AsyncNextcloudApp):
176176
nextcloud = AsyncNextcloudApp()
177177
if task['userId']:
178178
await nextcloud.set_user(task['userId'])
179-
output = await react(task, nextcloud)
179+
180+
stream_updates_enabled = task.get('preferStreaming', None) is True
181+
stream_update_failed = False
182+
183+
async def stream_output(intermediate_output):
184+
nonlocal stream_update_failed
185+
if not stream_updates_enabled or stream_update_failed:
186+
return
187+
try:
188+
await nc.ocs(
189+
"POST",
190+
f"/ocs/v2.php/taskprocessing/tasks_provider/{task['id']}/stream-result",
191+
json={"output": intermediate_output},
192+
)
193+
except (NextcloudException, RequestException) as stream_err:
194+
stream_update_failed = True
195+
tb_str = ''.join(traceback.format_exception(stream_err))
196+
await log(nc, LogLvl.WARNING, "Error streaming intermediate task result: " + tb_str)
197+
198+
output = await react(task, nextcloud, stream_output=stream_output if stream_updates_enabled else None)
180199
except Exception as e: # noqa
181200
try:
182201
tb_str = ''.join(traceback.format_exception(e))

0 commit comments

Comments
 (0)