Skip to content

Commit 843de55

Browse files
authored
feat(adapters): abstract streaming layer behind LLMProvider (#548)
## Summary

- Add `StreamChunk` dataclass as a normalized streaming type across all providers
- Add `async_stream()` and `supports()` to `LLMProvider` ABC
- `AnthropicProvider.async_stream()`: translates SDK events → StreamChunk; awaits `get_final_message()`; supports extended thinking via betas header
- `OpenAIProvider.async_stream()`: SSE → StreamChunk; deferred tool_use_start; exception handling; JSON decode warnings
- `MockProvider.async_stream()`: configurable chunks + call tracking
- `StreamingChatAdapter`: no `import anthropic`; passes `extended_thinking` flag via `supports()`; workspace path removed from system prompt
- `session_chat_ws.py`: explicit provider construction at composition root
- Fix e2b tests to skip when optional `e2b` package not installed

## Validation

- Review feedback: All addressed (claude-review + 2 CodeRabbit rounds, 3 commits)
- Demo: All 5 acceptance criteria verified via Showboat
- Tests: 87 passed (0 failures)
- CI: All checks green
- Linting: Clean

Closes #548
1 parent 21785ae commit 843de55

10 files changed

Lines changed: 840 additions & 440 deletions

File tree

codeframe/adapters/llm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
Message,
2727
ModelSelector,
2828
Purpose,
29+
StreamChunk,
2930
Tool,
3031
ToolCall,
3132
ToolResult,
@@ -40,6 +41,7 @@
4041
"Message",
4142
"ModelSelector",
4243
"Purpose",
44+
"StreamChunk",
4345
"Tool",
4446
"ToolCall",
4547
"ToolResult",

codeframe/adapters/llm/anthropic.py

Lines changed: 114 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,16 @@
33
Provides Claude model access via the Anthropic API.
44
"""
55

6+
import asyncio
67
import os
7-
from typing import TYPE_CHECKING, Iterator, Optional
8+
from typing import TYPE_CHECKING, AsyncIterator, Iterator, Optional
89

910
from codeframe.adapters.llm.base import (
1011
LLMProvider,
1112
LLMResponse,
1213
ModelSelector,
1314
Purpose,
15+
StreamChunk,
1416
Tool,
1517
ToolCall,
1618
)
@@ -172,6 +174,117 @@ async def async_complete(
172174
except APIConnectionError as exc:
173175
raise LLMConnectionError(str(exc)) from exc
174176

177+
def supports(self, capability: str) -> bool:
    """Report whether this provider implements the named optional capability.

    Args:
        capability: Capability name to look up.

    Returns:
        ``True`` only for ``"extended_thinking"``; ``False`` for any other name.
    """
    supported_capabilities = ("extended_thinking",)
    return capability in supported_capabilities
180+
181+
async def async_stream(
    self,
    messages: list[dict],
    system: str,
    tools: list[dict],
    model: str,
    max_tokens: int,
    interrupt_event: Optional[asyncio.Event] = None,
    extended_thinking: bool = False,
) -> AsyncIterator[StreamChunk]:
    """Stream a completion from Anthropic as normalized StreamChunk objects.

    SDK events are mapped onto StreamChunk types as follows:

    - ``content_block_start`` for a ``tool_use`` block -> ``tool_use_start``
    - ``content_block_delta`` with ``text_delta`` / ``thinking_delta``
      -> ``text_delta`` / ``thinking_delta``
    - ``content_block_stop`` while a tool block is open -> ``tool_use_stop``
    - ``message_stop`` -> ``message_stop`` carrying the stop reason, token
      usage and ``tool_inputs_by_id`` taken from the final message

    ``input_json_delta`` events are not forwarded; the complete tool inputs
    are read from the final message instead.

    Args:
        messages: Conversation messages; converted with
            ``self._convert_messages`` before being sent.
        system: System prompt string.
        tools: Tool definitions, passed through unchanged.
        model: Model identifier.
        max_tokens: Maximum output tokens.
        interrupt_event: Checked each time an SDK event arrives; when set,
            the generator returns without yielding further chunks.
        extended_thinking: When ``True``, adds the
            ``interleaved-thinking-2025-05-14`` entry under ``betas``. If
            building the stream then raises, the request is retried once
            without ``betas``.

    Yields:
        StreamChunk objects in the order the SDK events arrive.
    """
    from anthropic import AsyncAnthropic

    # Lazily create the async client on first use.
    if self._async_client is None:
        self._async_client = AsyncAnthropic(api_key=self.api_key)

    # _convert_messages handles tool_calls / tool_results translation.
    request: dict = {
        "model": model,
        "system": system,
        "messages": self._convert_messages(messages),
        "tools": tools,
        "max_tokens": max_tokens,
    }
    if extended_thinking:
        request["betas"] = ["interleaved-thinking-2025-05-14"]

    # NOTE(review): confirm that messages.stream() in the pinned SDK accepts
    # a ``betas`` keyword; if it does not, this fallback always strips the
    # flag and extended thinking is never actually requested.
    try:
        manager = self._async_client.messages.stream(**request)
    except Exception:  # pragma: no cover
        if not extended_thinking:
            raise
        request.pop("betas", None)
        manager = self._async_client.messages.stream(**request)

    # ID of the tool_use block currently being streamed, if any.
    open_tool_id: Optional[str] = None

    async with manager as stream:
        async for event in stream:
            if interrupt_event and interrupt_event.is_set():
                return

            kind = event.type

            if kind == "content_block_start":
                started = event.content_block
                if started.type == "tool_use":
                    open_tool_id = started.id
                    yield StreamChunk(
                        type="tool_use_start",
                        tool_id=started.id,
                        tool_name=started.name,
                        tool_input=getattr(started, "input", {}),
                    )
                continue

            if kind == "content_block_delta":
                delta = event.delta
                if delta.type == "text_delta":
                    yield StreamChunk(type="text_delta", text=delta.text)
                elif delta.type == "thinking_delta":
                    yield StreamChunk(type="thinking_delta", text=delta.thinking)
                continue

            # Both a block end and the message end close any open tool block.
            if kind in ("content_block_stop", "message_stop") and open_tool_id is not None:
                yield StreamChunk(type="tool_use_stop")
                open_tool_id = None

            if kind != "message_stop":
                continue

            final = await stream.get_final_message()
            reason = final.stop_reason or "end_turn"

            # Collect the final input of every tool_use block by its ID.
            final_inputs: dict = {}
            if hasattr(final, "content"):
                final_inputs = {
                    part.id: getattr(part, "input", {})
                    for part in final.content
                    if getattr(part, "type", None) == "tool_use" and hasattr(part, "id")
                }

            yield StreamChunk(
                type="message_stop",
                stop_reason=reason,
                input_tokens=final.usage.input_tokens,
                output_tokens=final.usage.output_tokens,
                tool_inputs_by_id=final_inputs,
            )
287+
175288
def stream(
176289
self,
177290
messages: list[dict],

codeframe/adapters/llm/base.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from abc import ABC, abstractmethod
1010
from dataclasses import dataclass, field
1111
from enum import Enum
12-
from typing import Iterator, Optional
12+
from typing import AsyncIterator, Iterator, Optional
1313

1414

1515
# ---------------------------------------------------------------------------
@@ -120,6 +120,53 @@ def for_purpose(self, purpose: Purpose) -> str:
120120
return self.execution_model # Default fallback
121121

122122

123+
@dataclass
class StreamChunk:
    """One provider-independent event from a streaming LLM response.

    Each :class:`LLMProvider` implementation converts its own streaming
    format into a sequence of these chunks.

    Attributes:
        type: Kind of event. One of ``"text_delta"``, ``"thinking_delta"``,
            ``"tool_use_start"``, ``"tool_use_stop"`` or ``"message_stop"``.
        text: Text payload; set for ``text_delta`` and ``thinking_delta``.
        tool_id: ID of the tool call; set for ``tool_use_start``.
        tool_name: Name of the tool; set for ``tool_use_start``.
        tool_input: Input dict for ``tool_use_start``. May be empty; the
            final inputs arrive with the ``message_stop`` chunk.
        input_tokens: Input token count; set for ``message_stop``.
        output_tokens: Output token count; set for ``message_stop``.
        stop_reason: Why generation stopped; set for ``message_stop``.
        tool_inputs_by_id: Map of tool_id → final input dict; set for
            ``message_stop``. Preferred over incremental input deltas.

    .. note:: The position of ``tool_use_stop`` depends on the provider:

        - **Anthropic**: sent as soon as a tool call's content block ends
          (``content_block_stop`` event), so the stream reads
          ``tool_use_start → [deltas] → tool_use_stop`` per tool.
        - **OpenAI-compatible**: sent only once the whole stream has ended
          (just before ``message_stop``), because the SSE protocol has no
          per-tool stop marker. All ``tool_use_stop`` chunks arrive in one
          batch at the end.

        Consumers MUST read final tool inputs from ``tool_inputs_by_id`` on
        the ``message_stop`` chunk and must not depend on where
        ``tool_use_stop`` appears.
    """

    # Event discriminator (always required).
    type: str

    # Payload for text_delta / thinking_delta.
    text: Optional[str] = None

    # Payload for tool_use_start.
    tool_id: Optional[str] = None
    tool_name: Optional[str] = None
    tool_input: Optional[dict] = None

    # Payload for message_stop.
    input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None
    stop_reason: Optional[str] = None
    tool_inputs_by_id: Optional[dict] = None
168+
169+
123170
@dataclass
124171
class ToolCall:
125172
"""Represents a tool call requested by the LLM.
@@ -332,6 +379,59 @@ async def async_complete(
332379
lambda: self.complete(messages, purpose, tools, max_tokens, temperature, system),
333380
)
334381

382+
def supports(self, capability: str) -> bool:
    """Tell callers whether an optional capability is available.

    This default implementation reports every capability as unsupported.

    Args:
        capability: Name of the capability being queried, for example
            ``"extended_thinking"``.

    Returns:
        Always ``False`` in this default implementation.
    """
    return False
392+
393+
# Deliberately left without @abstractmethod: a provider that only offers
# synchronous completion (e.g. a thin wrapper) can still be instantiated
# without implementing streaming. Because this is an async generator, the
# NotImplementedError surfaces when the returned stream is first iterated,
# not when the provider is constructed.
async def async_stream(
    self,
    messages: list[dict],
    system: str,
    tools: list[dict],
    model: str,
    max_tokens: int,
    interrupt_event: Optional[asyncio.Event] = None,
    extended_thinking: bool = False,
) -> AsyncIterator["StreamChunk"]:
    """Produce a completion as a stream of :class:`StreamChunk` objects.

    This base version has no streaming backend and always fails with
    :exc:`NotImplementedError`; provider subclasses are expected to
    replace it.

    Args:
        messages: Conversation messages in the provider's expected format.
        system: System prompt string.
        tools: Tool definitions, already serialized as a list of dicts.
        model: Identifier of the model to call.
        max_tokens: Upper bound on output tokens.
        interrupt_event: Once set, the stream should end at the next
            opportunity.
        extended_thinking: If ``True``, ask for extended thinking tokens
            on providers that offer them (see :meth:`supports`). Providers
            without that capability should silently ignore the flag.

    Yields:
        :class:`StreamChunk` objects in generation order.

    Raises:
        NotImplementedError: Always, on first iteration of the stream.
    """
    provider_name = type(self).__name__
    raise NotImplementedError(
        f"{provider_name} does not implement async_stream(). "
        "Override this method in your provider subclass."
    )
    # Never executed: the yield only exists so that Python compiles this
    # function as an async generator, matching the overriding signatures.
    yield  # type: ignore[misc]  # pragma: no cover
434+
335435
def get_model(self, purpose: Purpose) -> str:
336436
"""Get the model for a given purpose.
337437

codeframe/adapters/llm/mock.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
Supports configurable responses and call tracking.
55
"""
66

7-
from typing import Callable, Iterator, Optional
7+
import asyncio
8+
from typing import AsyncIterator, Callable, Iterator, Optional
89

910
from codeframe.adapters.llm.base import (
1011
LLMProvider,
1112
LLMResponse,
1213
ModelSelector,
1314
Purpose,
15+
StreamChunk,
1416
Tool,
1517
ToolCall,
1618
)
@@ -40,6 +42,8 @@ def __init__(
4042
self.responses: list[LLMResponse] = []
4143
self.response_index = 0
4244
self.response_handler: Optional[Callable[[list[dict]], LLMResponse]] = None
45+
self.stream_chunks: list[list[StreamChunk]] = []
46+
self.stream_index = 0
4347

4448
def add_response(self, response: LLMResponse) -> None:
4549
"""Add a canned response to the queue.
@@ -175,12 +179,83 @@ def stream(
175179
for word in response.content.split():
176180
yield word + " "
177181

182+
def add_stream_chunks(self, chunks: list[StreamChunk]) -> None:
    """Queue one scripted chunk sequence for a later async_stream() call.

    Sequences are stored in the order they are added; each async_stream()
    call that finds an unused sequence consumes the next one.

    Args:
        chunks: StreamChunk objects to yield, in order.
    """
    self.stream_chunks.extend((chunks,))
189+
190+
async def async_stream(
    self,
    messages: list[dict],
    system: str,
    tools: list[dict],
    model: str,
    max_tokens: int,
    interrupt_event: Optional[asyncio.Event] = None,
    extended_thinking: bool = False,
) -> AsyncIterator[StreamChunk]:
    """Replay scripted StreamChunks, or synthesize a minimal stream.

    Every call is recorded in :attr:`calls` so tests can assert on the
    arguments. If an unused sequence queued in ``stream_chunks`` remains,
    it is consumed and replayed. Otherwise a ``text_delta`` followed by a
    ``message_stop`` is built from the ordinary response sources, tried in
    this order: ``response_handler``, the ``responses`` queue, then
    ``default_response``.

    Iteration stops early, without yielding the pending chunk, as soon as
    ``interrupt_event`` is found set.
    """
    # Record the invocation before producing anything.
    call_record = {
        "messages": messages,
        "system": system,
        "tools": tools,
        "model": model,
        "max_tokens": max_tokens,
        "extended_thinking": extended_thinking,
    }
    self.calls.append(call_record)

    has_scripted_sequence = self.stream_index < len(self.stream_chunks)
    if has_scripted_sequence:
        pending = self.stream_chunks[self.stream_index]
        self.stream_index += 1
    else:
        # No scripted sequence left: take the text from the regular
        # handler / queue / default chain.
        if self.response_handler:
            handled = self.response_handler(messages)
            text = handled.content
        elif self.response_index < len(self.responses):
            queued = self.responses[self.response_index]
            self.response_index += 1
            text = queued.content
        else:
            text = self.default_response

        # Token counts are stand-ins derived from string lengths.
        closing = StreamChunk(
            type="message_stop",
            stop_reason="end_turn",
            input_tokens=len(str(messages)),
            output_tokens=len(text),
            tool_inputs_by_id={},
        )
        pending = [StreamChunk(type="text_delta", text=text), closing]

    for item in pending:
        if interrupt_event and interrupt_event.is_set():
            return
        yield item
250+
178251
def reset(self) -> None:
    """Return the mock to a pristine state.

    Empties the recorded calls, the canned responses and the scripted
    stream sequences in place, rewinds both queue cursors to zero and
    removes any installed response handler.
    """
    # clear() keeps the same list objects, so outside references stay valid.
    for store in (self.calls, self.responses, self.stream_chunks):
        store.clear()
    self.response_index = self.stream_index = 0
    self.response_handler = None
184259

185260
@property
186261
def call_count(self) -> int:

0 commit comments

Comments
 (0)