Skip to content

Commit 3b1449c

Browse files
authored
fix(cohere): support chat stream context managers (#418)
Delegate Cohere chat_stream enter/exit handling through the traced stream wrapper so native with-statement usage keeps working while spans still finish correctly. Add VCR regression coverage for ClientV2.chat_stream used as a context manager. resolves https://linear.app/braintrustdata/issue/BT-5179/braintrust-python-sdk-cohere-wrapper-chat-stream-doesnt-support-with
1 parent 55dafd0 commit 3b1449c

3 files changed

Lines changed: 174 additions & 0 deletions

File tree

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
interactions:
2+
- request:
3+
body: '{"model":"command-a-03-2025","messages":[{"role":"user","content":"Say
4+
hi in one word."}],"max_tokens":8,"stream":true}'
5+
headers:
6+
Accept:
7+
- '*/*'
8+
Accept-Encoding:
9+
- gzip, deflate
10+
Connection:
11+
- keep-alive
12+
Content-Length:
13+
- '119'
14+
Host:
15+
- api.cohere.com
16+
User-Agent:
17+
- cohere/6.1.0
18+
X-Fern-Language:
19+
- Python
20+
X-Fern-Platform:
21+
- darwin/25.2.0
22+
X-Fern-Runtime:
23+
- python/3.14.3
24+
X-Fern-SDK-Name:
25+
- cohere
26+
X-Fern-SDK-Version:
27+
- 6.1.0
28+
content-type:
29+
- application/json
30+
method: POST
31+
uri: https://api.cohere.com/v2/chat
32+
response:
33+
body:
34+
string: 'event: message-start
35+
36+
data: {"id":"b4256f33-1304-4943-b795-7f0d56d1fec9","type":"message-start","delta":{"message":{"role":"assistant","content":[],"tool_plan":"","tool_calls":[],"citations":[]}}}
37+
38+
39+
event: content-start
40+
41+
data: {"type":"content-start","index":0,"delta":{"message":{"content":{"type":"text","text":""}}}}
42+
43+
44+
event: content-delta
45+
46+
data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":"Hi"}}}}
47+
48+
49+
event: content-delta
50+
51+
data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":"!"}}}}
52+
53+
54+
event: content-end
55+
56+
data: {"type":"content-end","index":0}
57+
58+
59+
event: message-end
60+
61+
data: {"type":"message-end","delta":{"finish_reason":"COMPLETE","usage":{"billed_units":{"input_tokens":6,"output_tokens":2},"tokens":{"input_tokens":501,"output_tokens":4},"cached_tokens":0}}}
62+
63+
64+
data: [DONE]
65+
66+
67+
'
68+
headers:
69+
Alt-Svc:
70+
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
71+
Transfer-Encoding:
72+
- chunked
73+
Via:
74+
- 1.1 google
75+
access-control-expose-headers:
76+
- X-Debug-Trace-ID
77+
cache-control:
78+
- no-cache, no-store, no-transform, must-revalidate, private, max-age=0
79+
content-type:
80+
- text/event-stream
81+
date:
82+
- Thu, 16 Apr 2026 22:25:13 GMT
83+
expires:
84+
- Thu, 01 Jan 1970 00:00:00 GMT
85+
pragma:
86+
- no-cache
87+
server:
88+
- envoy
89+
vary:
90+
- Origin
91+
x-accel-expires:
92+
- '0'
93+
x-debug-trace-id:
94+
- 6ad8e9afe35e1e8627c7779791b16cc4
95+
x-endpoint-monthly-call-limit:
96+
- '1000'
97+
x-envoy-upstream-service-time:
98+
- '303'
99+
x-trial-endpoint-call-limit:
100+
- '20'
101+
x-trial-endpoint-call-remaining:
102+
- '17'
103+
status:
104+
code: 200
105+
message: OK
106+
version: 1

py/src/braintrust/integrations/cohere/test_cohere.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,34 @@ def test_wrap_cohere_chat_stream_v2_sync(memory_logger):
567567
assert metrics.get("completion_tokens", 0) > 0
568568

569569

570+
@pytest.mark.vcr
571+
def test_wrap_cohere_chat_stream_v2_sync_context_manager(memory_logger):
572+
assert not memory_logger.pop()
573+
client = wrap_cohere(_v2_client(require_methods=("chat_stream",)))
574+
575+
start = time.time()
576+
events = []
577+
with client.chat_stream(
578+
model=CHAT_MODEL,
579+
messages=[{"role": "user", "content": "Say hi in one word."}],
580+
max_tokens=8,
581+
) as stream:
582+
for event in stream:
583+
events.append(event)
584+
end = time.time()
585+
586+
assert events
587+
588+
spans = memory_logger.pop()
589+
assert len(spans) == 1
590+
span = spans[0]
591+
assert span["span_attributes"]["name"] == "cohere.chat_stream"
592+
assert span["metadata"]["provider"] == "cohere"
593+
assert span["metadata"]["model"] == CHAT_MODEL
594+
assert span["metrics"]["start"] <= span["metrics"]["end"]
595+
assert start <= span["metrics"]["start"] <= span["metrics"]["end"] <= end
596+
597+
570598
@pytest.mark.vcr
571599
def test_wrap_cohere_chat_stream_v2_rag_citations(memory_logger):
572600
if os.environ.get("BRAINTRUST_TEST_PACKAGE_VERSION") != "latest":

py/src/braintrust/integrations/cohere/tracing.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,26 @@ def _finish(self, error: BaseException | None = None) -> None:
840840
class _TracedChatStream(_ChatStreamTracker):
841841
"""Wrap a sync chat-stream iterator so exhaustion logs the aggregated span."""
842842

843+
def __enter__(self):
844+
context_manager = self._iterator
845+
enter = getattr(context_manager, "__enter__", None)
846+
if enter is not None:
847+
self._iterator = enter()
848+
self._context_manager = context_manager
849+
return self
850+
851+
def __exit__(self, exc_type, exc_value, traceback):
852+
suppress = False
853+
context_manager = getattr(self, "_context_manager", self._iterator)
854+
exit_method = getattr(context_manager, "__exit__", None)
855+
if exit_method is not None:
856+
suppress = bool(exit_method(exc_type, exc_value, traceback))
857+
if exc_value is not None:
858+
self._finish(error=exc_value)
859+
else:
860+
self._finish()
861+
return suppress
862+
843863
def __iter__(self):
844864
return self
845865

@@ -859,6 +879,26 @@ def __next__(self):
859879
class _AsyncTracedChatStream(_ChatStreamTracker):
860880
"""Async counterpart of :class:`_TracedChatStream`."""
861881

882+
async def __aenter__(self):
883+
context_manager = self._iterator
884+
aenter = getattr(context_manager, "__aenter__", None)
885+
if aenter is not None:
886+
self._iterator = await aenter()
887+
self._context_manager = context_manager
888+
return self
889+
890+
async def __aexit__(self, exc_type, exc_value, traceback):
891+
suppress = False
892+
context_manager = getattr(self, "_context_manager", self._iterator)
893+
aexit = getattr(context_manager, "__aexit__", None)
894+
if aexit is not None:
895+
suppress = bool(await aexit(exc_type, exc_value, traceback))
896+
if exc_value is not None:
897+
self._finish(error=exc_value)
898+
else:
899+
self._finish()
900+
return suppress
901+
862902
def __aiter__(self):
863903
return self
864904

0 commit comments

Comments
 (0)