Skip to content

Commit e6f11c4

Browse files
authored
perf(streaming): coalesce per-token publishes to Redis (50ms / 128-char window) (#333)
* perf(streaming): coalesce per-token publishes to Redis (50ms / 128-char window) Per-token Redis publishes from TemporalStreamingModel were adding ~45s (56-62%) overhead to agent response latency, mostly from head-of-line blocking on the model's event loop: each `await streaming_context.stream_update(...)` inside the OpenAI stream `async for` paused token consumption until the publish round-trip completed. This change introduces a `CoalescingBuffer` driven by an `asyncio.Event`, so the producer never awaits on Redis. Deltas are merged consecutive-only (preserving character order in every (type, index) channel) and flushed on a 50ms timer, on a 128-char size threshold, or immediately for the first delta to keep perceived responsiveness high. The buffer's `close()` drains remaining deltas before the DONE event, so consumers see the full sequence in order. A new `StreamingMode = Literal["off", "per_token", "coalesced"]` lives in `streaming.py` as the single source of truth and is plumbed through the adk streaming module, `StreamingService.streaming_task_message_context`, and `StreamingTaskMessageContext`. Default is `"coalesced"` everywhere, so all 13+ existing context callers (claude_agents, langgraph, litellm provider, openai sync provider, etc.) benefit automatically. * chore(streaming): fix import ordering (ruff I001) * fix(streaming): address greptile review findings - _run: when CancelledError is raised mid-flush in the for-loop, re-enqueue the in-flight item plus any remaining items in the local `drained` list back into self._buf so close()'s final drain can recover them. Previously the local `drained` list was unreachable after CancelledError exited the for-loop, causing the last coalesced batch to be silently dropped on close-during-flush races. Trade-off: the in-flight item may be duplicated on the consumer side (Redis pub may have completed before cancel was delivered), which is preferable to silent loss for streaming UX. - _merge_pair: replace `return b` fallback with AssertionError. All six current TaskMessageDelta variants have explicit isinstance branches, so the fallback is unreachable today. But _can_merge returns True for any same-type pair, so adding a 7th delta variant without updating _merge_pair would silently drop `a`'s accumulated content. Asserting turns a future silent data-loss into an immediate, diagnosable crash. * test(streaming): add coalescing-layer tests; loosen one model assertion After merging the test-suite repair from main (#334) into this branch, one model test (test_responses_api_streaming) regressed because its assert_called_with strict-matched all kwargs of streaming_task_message_context and didn't tolerate the new `streaming_mode='coalesced'` kwarg this PR adds. Switched to assert_called() + targeted kwarg checks so the test verifies what it cares about (task_id threading) without locking in implementation details. Replaced the ad-hoc smoke scripts that lived in conversation with a real pytest module at tests/lib/core/services/adk/test_streaming.py covering: - _delta_char_len, _can_merge, _merge_pair: per-channel correctness + None-handling - _merge_consecutive: pure-text collapse, cross-channel order preservation, per-channel reconstruction matches per-token semantics - CoalescingBuffer: first-delta-immediate flush within ~20ms, size-threshold flush before timer fires, multi-delta coalescing within one window, idle close, add-after-close no-op - CoalescingBuffer cancel-during-flush regression test for the P1 fix: five queued chunks must all surface across publishes when close() cancels mid-flush (asserts substring presence rather than exact ordering, since the documented trade-off allows duplicates of the in-flight item) - StreamingTaskMessageContext mode dispatch: "off" suppresses publishes but persists full content, "per_token" publishes each delta synchronously, "coalesced" batches and persists full content * chore(streaming): route TemporalStreamingModel logger through make_logger The model file used raw ``logging.getLogger("agentex.temporal.streaming")``, which returns a logger with no handler attached and no level configured — so the existing ``[TemporalStreamingModel] Initialized ... streaming_mode=...`` INFO log was silently dropped, making it impossible to verify at runtime that a coalesced (or any) streaming mode was actually wired. Switch to the SDK's ``make_logger`` helper (level=INFO, RichHandler in local mode, StreamHandler otherwise) used everywhere else in the SDK. The explicit logger name ``agentex.temporal.streaming`` is preserved so any external logging configuration targeting that name keeps working.
1 parent 9d80e0b commit e6f11c4

7 files changed

Lines changed: 782 additions & 18 deletions

File tree

src/agentex/lib/adk/_modules/streaming.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from agentex.lib.adk.utils._modules.client import create_async_agentex_client
88
from agentex.lib.core.adapters.streams.adapter_redis import RedisStreamRepository
99
from agentex.lib.core.services.adk.streaming import (
10+
StreamingMode,
1011
StreamingService,
1112
StreamingTaskMessageContext,
1213
)
@@ -50,6 +51,7 @@ def streaming_task_message_context(
5051
self,
5152
task_id: str,
5253
initial_content: TaskMessageContent,
54+
streaming_mode: StreamingMode = "coalesced",
5355
) -> StreamingTaskMessageContext:
5456
"""
5557
Create a streaming context for managing TaskMessage lifecycle.
@@ -60,7 +62,11 @@ def streaming_task_message_context(
6062
Args:
6163
task_id: The ID of the task
6264
initial_content: The initial content for the TaskMessage
63-
agentex_client: The agentex client for creating/updating messages
65+
streaming_mode: How per-delta updates are published. Defaults to
66+
"coalesced" (50ms / 128-char windowed batches with an immediate
67+
first-delta flush). Pass "per_token" for the legacy publish-every-
68+
delta behavior, or "off" to suppress per-delta publishes entirely
69+
while still recording the full message body on close.
6470
6571
Returns:
6672
StreamingTaskMessageContext: Context manager for streaming operations
@@ -76,4 +82,5 @@ def streaming_task_message_context(
7682
return self._streaming_service.streaming_task_message_context(
7783
task_id=task_id,
7884
initial_content=initial_content,
85+
streaming_mode=streaming_mode,
7986
)

src/agentex/lib/core/services/adk/streaming.py

Lines changed: 231 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

33
import json
4-
from typing import Literal
4+
import asyncio
5+
import contextlib
6+
from typing import Literal, Callable, Awaitable
57

68
from agentex import AsyncAgentex
79
from agentex.lib.utils.logging import make_logger
@@ -39,6 +41,198 @@ def _get_stream_topic(task_id: str) -> str:
3941
return f"task:{task_id}"
4042

4143

44+
StreamingMode = Literal["off", "per_token", "coalesced"]
45+
"""Controls how a StreamingTaskMessageContext publishes deltas.
46+
47+
- "off": Feed the accumulator (so the persisted message body is correct)
48+
but never publish per-delta events. Consumers see start + done
49+
only. Lowest latency.
50+
- "per_token": Publish every delta immediately. Highest UX fidelity for
51+
token-by-token rendering, highest Redis cost, and re-introduces
52+
head-of-line blocking on the producer's event loop.
53+
- "coalesced": Buffer deltas in a small time/size window and publish them as
54+
merged batches. The first delta flushes immediately for fast
55+
perceived responsiveness; subsequent deltas flush every 50ms or
56+
whenever 128 buffered chars accumulate, whichever comes first.
57+
Order within each (delta type, index) channel is preserved
58+
exactly; only granularity changes.
59+
"""
60+
61+
62+
def _delta_char_len(delta: TaskMessageDelta | None) -> int:
63+
if delta is None:
64+
return 0
65+
if isinstance(delta, TextDelta):
66+
return len(delta.text_delta or "")
67+
if isinstance(delta, DataDelta):
68+
return len(delta.data_delta or "")
69+
if isinstance(delta, ReasoningSummaryDelta):
70+
return len(delta.summary_delta or "")
71+
if isinstance(delta, ReasoningContentDelta):
72+
return len(delta.content_delta or "")
73+
if isinstance(delta, ToolRequestDelta):
74+
return len(delta.arguments_delta or "")
75+
if isinstance(delta, ToolResponseDelta):
76+
return len(delta.content_delta or "")
77+
return 0
78+
79+
80+
def _can_merge(a: TaskMessageDelta, b: TaskMessageDelta) -> bool:
81+
if type(a) is not type(b):
82+
return False
83+
if isinstance(a, ReasoningSummaryDelta) and isinstance(b, ReasoningSummaryDelta):
84+
return a.summary_index == b.summary_index
85+
if isinstance(a, ReasoningContentDelta) and isinstance(b, ReasoningContentDelta):
86+
return a.content_index == b.content_index
87+
if isinstance(a, ToolRequestDelta) and isinstance(b, ToolRequestDelta):
88+
return a.tool_call_id == b.tool_call_id
89+
if isinstance(a, ToolResponseDelta) and isinstance(b, ToolResponseDelta):
90+
return a.tool_call_id == b.tool_call_id
91+
return True
92+
93+
94+
def _merge_pair(a: TaskMessageDelta, b: TaskMessageDelta) -> TaskMessageDelta:
95+
if isinstance(a, TextDelta) and isinstance(b, TextDelta):
96+
return TextDelta(type="text", text_delta=(a.text_delta or "") + (b.text_delta or ""))
97+
if isinstance(a, DataDelta) and isinstance(b, DataDelta):
98+
return DataDelta(type="data", data_delta=(a.data_delta or "") + (b.data_delta or ""))
99+
if isinstance(a, ReasoningSummaryDelta) and isinstance(b, ReasoningSummaryDelta):
100+
return ReasoningSummaryDelta(
101+
type="reasoning_summary",
102+
summary_index=a.summary_index,
103+
summary_delta=(a.summary_delta or "") + (b.summary_delta or ""),
104+
)
105+
if isinstance(a, ReasoningContentDelta) and isinstance(b, ReasoningContentDelta):
106+
return ReasoningContentDelta(
107+
type="reasoning_content",
108+
content_index=a.content_index,
109+
content_delta=(a.content_delta or "") + (b.content_delta or ""),
110+
)
111+
if isinstance(a, ToolRequestDelta) and isinstance(b, ToolRequestDelta):
112+
return ToolRequestDelta(
113+
type="tool_request",
114+
tool_call_id=a.tool_call_id,
115+
name=a.name,
116+
arguments_delta=(a.arguments_delta or "") + (b.arguments_delta or ""),
117+
)
118+
if isinstance(a, ToolResponseDelta) and isinstance(b, ToolResponseDelta):
119+
return ToolResponseDelta(
120+
type="tool_response",
121+
tool_call_id=a.tool_call_id,
122+
name=a.name,
123+
content_delta=(a.content_delta or "") + (b.content_delta or ""),
124+
)
125+
raise AssertionError(
126+
f"_can_merge approved {type(a).__name__} pair but _merge_pair has no handler — "
127+
"a new TaskMessageDelta variant was added without updating both functions"
128+
)
129+
130+
131+
def _merge_consecutive(updates: list[StreamTaskMessageDelta]) -> list[StreamTaskMessageDelta]:
132+
"""Merge consecutive same-channel deltas. Order across channels is preserved exactly."""
133+
result: list[StreamTaskMessageDelta] = []
134+
for u in updates:
135+
if u.delta is None or not result:
136+
result.append(u)
137+
continue
138+
last = result[-1]
139+
if last.delta is not None and _can_merge(last.delta, u.delta):
140+
result[-1] = StreamTaskMessageDelta(
141+
parent_task_message=last.parent_task_message,
142+
delta=_merge_pair(last.delta, u.delta),
143+
type="delta",
144+
)
145+
else:
146+
result.append(u)
147+
return result
148+
149+
150+
class CoalescingBuffer:
151+
"""Time-and-size-windowed buffer that merges consecutive same-channel deltas.
152+
153+
Decouples the producer (model event loop) from the publisher (Redis): ``add``
154+
only enqueues and may signal an early flush; the actual publish always runs
155+
on a background ticker, so the producer never awaits on a Redis round-trip.
156+
"""
157+
158+
FLUSH_INTERVAL_S = 0.050
159+
MAX_BUFFERED_CHARS = 128
160+
161+
def __init__(self, on_flush: Callable[[StreamTaskMessageDelta], Awaitable[object]]):
162+
self._on_flush = on_flush
163+
self._buf: list[StreamTaskMessageDelta] = []
164+
self._buf_chars = 0
165+
self._first_flushed = False
166+
self._closed = False
167+
self._lock = asyncio.Lock()
168+
self._flush_signal = asyncio.Event()
169+
self._task: asyncio.Task[None] | None = None
170+
171+
def start(self) -> None:
172+
if self._task is None:
173+
self._task = asyncio.create_task(self._run(), name="coalescing-buffer")
174+
175+
async def add(self, update: StreamTaskMessageDelta) -> None:
176+
if self._closed:
177+
return
178+
async with self._lock:
179+
self._buf.append(update)
180+
self._buf_chars += _delta_char_len(update.delta)
181+
if not self._first_flushed or self._buf_chars >= self.MAX_BUFFERED_CHARS:
182+
self._first_flushed = True
183+
self._flush_signal.set()
184+
185+
async def _run(self) -> None:
186+
try:
187+
while not self._closed:
188+
try:
189+
await asyncio.wait_for(self._flush_signal.wait(), timeout=self.FLUSH_INTERVAL_S)
190+
except asyncio.TimeoutError:
191+
pass
192+
async with self._lock:
193+
self._flush_signal.clear()
194+
drained = self._drain_locked()
195+
for idx, u in enumerate(drained):
196+
try:
197+
await self._on_flush(u)
198+
except asyncio.CancelledError:
199+
# Re-enqueue the item being flushed plus any remaining so
200+
# close()'s final drain can recover them. May cause a
201+
# duplicate publish of the in-flight item, which is
202+
# preferable to silent loss for a streaming UX.
203+
async with self._lock:
204+
self._buf = drained[idx:] + self._buf
205+
raise
206+
except Exception as e:
207+
logger.exception(f"CoalescingBuffer flush failed: {e}")
208+
except asyncio.CancelledError:
209+
pass
210+
211+
async def close(self) -> None:
212+
self._closed = True
213+
if self._task is not None:
214+
self._flush_signal.set()
215+
self._task.cancel()
216+
with contextlib.suppress(asyncio.CancelledError):
217+
await self._task
218+
self._task = None
219+
async with self._lock:
220+
drained = self._drain_locked()
221+
for u in drained:
222+
try:
223+
await self._on_flush(u)
224+
except Exception as e:
225+
logger.exception(f"CoalescingBuffer final flush failed: {e}")
226+
227+
def _drain_locked(self) -> list[StreamTaskMessageDelta]:
228+
if not self._buf:
229+
return []
230+
merged = _merge_consecutive(self._buf)
231+
self._buf = []
232+
self._buf_chars = 0
233+
return merged
234+
235+
42236
class DeltaAccumulator:
43237
def __init__(self):
44238
self._accumulated_deltas: list[TaskMessageDelta] = []
@@ -176,6 +370,7 @@ def __init__(
176370
initial_content: TaskMessageContent,
177371
agentex_client: AsyncAgentex,
178372
streaming_service: "StreamingService",
373+
streaming_mode: StreamingMode = "coalesced",
179374
):
180375
self.task_id = task_id
181376
self.initial_content = initial_content
@@ -184,6 +379,8 @@ def __init__(
184379
self._streaming_service = streaming_service
185380
self._is_closed = False
186381
self._delta_accumulator = DeltaAccumulator()
382+
self._streaming_mode: StreamingMode = streaming_mode
383+
self._buffer: CoalescingBuffer | None = None
187384

188385
async def __aenter__(self) -> "StreamingTaskMessageContext":
189386
return await self.open()
@@ -208,6 +405,10 @@ async def open(self) -> "StreamingTaskMessageContext":
208405
)
209406
await self._streaming_service.stream_update(start_event)
210407

408+
if self._streaming_mode == "coalesced":
409+
self._buffer = CoalescingBuffer(on_flush=self._streaming_service.stream_update)
410+
self._buffer.start()
411+
211412
return self
212413

213414
async def close(self) -> TaskMessage:
@@ -218,6 +419,12 @@ async def close(self) -> TaskMessage:
218419
if self._is_closed:
219420
return self.task_message # Already done
220421

422+
# Drain any buffered deltas before announcing DONE so consumers see the
423+
# full sequence in order.
424+
if self._buffer is not None:
425+
await self._buffer.close()
426+
self._buffer = None
427+
221428
# Send the DONE event
222429
done_event = StreamTaskMessageDone(
223430
parent_task_message=self.task_message,
@@ -227,8 +434,8 @@ async def close(self) -> TaskMessage:
227434

228435
# Update the task message with the final content
229436
has_deltas = (
230-
self._delta_accumulator._accumulated_deltas or
231-
self._delta_accumulator._reasoning_summaries or
437+
self._delta_accumulator._accumulated_deltas or
438+
self._delta_accumulator._reasoning_summaries or
232439
self._delta_accumulator._reasoning_contents
233440
)
234441
if has_deltas:
@@ -248,7 +455,20 @@ async def close(self) -> TaskMessage:
248455
async def stream_update(
249456
self, update: TaskMessageUpdate
250457
) -> TaskMessageUpdate | None:
251-
"""Stream an update to the repository."""
458+
"""Stream an update to the repository.
459+
460+
Behavior depends on the context's ``streaming_mode``:
461+
- "off": delta updates feed the accumulator (so the persisted message
462+
body is correct) but are never published.
463+
- "per_token": delta updates are published immediately.
464+
- "coalesced": delta updates are queued in a 50ms / 128-char window and
465+
flushed as merged batches on a background ticker; the first delta
466+
flushes immediately for fast perceived responsiveness.
467+
468+
``StreamTaskMessageDone`` and ``StreamTaskMessageFull`` updates always
469+
publish synchronously regardless of mode so consumers and persistence
470+
stay in sync.
471+
"""
252472
if self._is_closed:
253473
raise ValueError("Context is already done")
254474

@@ -258,6 +478,11 @@ async def stream_update(
258478
if isinstance(update, StreamTaskMessageDelta):
259479
if update.delta is not None:
260480
self._delta_accumulator.add_delta(update.delta)
481+
if self._streaming_mode == "off":
482+
return update
483+
if self._streaming_mode == "coalesced" and self._buffer is not None:
484+
await self._buffer.add(update)
485+
return update
261486

262487
result = await self._streaming_service.stream_update(update)
263488

@@ -288,12 +513,14 @@ def streaming_task_message_context(
288513
self,
289514
task_id: str,
290515
initial_content: TaskMessageContent,
516+
streaming_mode: StreamingMode = "coalesced",
291517
) -> StreamingTaskMessageContext:
292518
return StreamingTaskMessageContext(
293519
task_id=task_id,
294520
initial_content=initial_content,
295521
agentex_client=self._agentex_client,
296522
streaming_service=self,
523+
streaming_mode=streaming_mode,
297524
)
298525

299526
async def stream_update(

0 commit comments

Comments
 (0)