Skip to content

Commit 8f5f852

Browse files
committed
Handle chat ProblemDetails SSE errors
1 parent c1f9a0a commit 8f5f852

3 files changed

Lines changed: 289 additions & 14 deletions

File tree

src/tests/test_chat_tool.py

Lines changed: 106 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -58,6 +58,7 @@ async def mock_aiter_lines():
5858
]
5959

6060
assert result == "Hello world"
61+
assert call_args.kwargs["headers"]["Accept"] == "text/event-stream, application/problem+json"
6162
assert call_args.kwargs["headers"]["X-CodeAlive-Tool"] == "chat"
6263

6364

@@ -139,20 +140,122 @@ async def mock_aiter_lines():
139140
result = await chat(
140141
ctx=ctx,
141142
question="Follow up",
142-
conversation_id="conv_123"
143+
conversation_id="69fceb3e7b2a6a7efdd18180"
143144
)
144145

145146
call_args = mock_client.post.call_args
146147
request_data = call_args.kwargs["json"]
147148

148149
# Should include conversation ID
149-
assert request_data["conversationId"] == "conv_123"
150+
assert request_data["conversationId"] == "69fceb3e7b2a6a7efdd18180"
150151
# Should not have explicit names when continuing conversation
151152
assert "names" not in request_data
152-
153153
assert result == "Continued"
154154

155155

156+
@pytest.mark.asyncio
157+
@patch('tools.chat.get_api_key_from_context')
158+
async def test_chat_rejects_non_objectid_conversation_id(mock_get_api_key):
159+
"""Invalid continuation IDs fail locally with an actionable ToolError."""
160+
mock_get_api_key.return_value = "test_key"
161+
162+
ctx = MagicMock(spec=Context)
163+
ctx.info = AsyncMock()
164+
ctx.warning = AsyncMock()
165+
ctx.error = AsyncMock()
166+
167+
with pytest.raises(ToolError) as exc:
168+
await chat(
169+
ctx=ctx,
170+
question="Follow up",
171+
conversation_id="conv_123",
172+
)
173+
174+
msg = str(exc.value)
175+
assert "24-character hex Mongo ObjectId" in msg
176+
assert "Retry: no" in msg
177+
178+
179+
@pytest.mark.asyncio
180+
@patch('tools.chat.get_api_key_from_context')
181+
async def test_chat_named_sse_error_raises_tool_error(mock_get_api_key):
182+
"""RFC 9457 `event: error` frames must not collapse to an empty answer."""
183+
mock_get_api_key.return_value = "test_key"
184+
185+
ctx = MagicMock(spec=Context)
186+
ctx.info = AsyncMock()
187+
ctx.warning = AsyncMock()
188+
ctx.error = AsyncMock()
189+
190+
mock_response = MagicMock()
191+
mock_response.raise_for_status = MagicMock()
192+
193+
async def mock_aiter_lines():
194+
yield 'event: error'
195+
yield 'data: {"title":"Bad request","status":400,"detail":"Message content violates our content policy","requestId":"req-1"}'
196+
yield ''
197+
198+
mock_response.aiter_lines = mock_aiter_lines
199+
200+
mock_client = AsyncMock()
201+
mock_client.post.return_value = mock_response
202+
203+
mock_codealive_context = MagicMock()
204+
mock_codealive_context.client = mock_client
205+
mock_codealive_context.base_url = "https://app.codealive.ai"
206+
207+
ctx.request_context.lifespan_context = mock_codealive_context
208+
209+
with pytest.raises(ToolError) as exc:
210+
await chat(ctx=ctx, question="Test question", data_sources=["repo123"])
211+
212+
msg = str(exc.value)
213+
assert "Message content violates our content policy" in msg
214+
assert "Code: 400" in msg
215+
assert "Retry: no" in msg
216+
assert "requestId=req-1" in msg
217+
218+
219+
@pytest.mark.asyncio
220+
@patch('tools.chat.get_api_key_from_context')
221+
async def test_chat_named_sse_rate_limit_error_is_retryable(mock_get_api_key):
222+
"""429 ProblemDetails frames should tell agents to back off, not fix input."""
223+
mock_get_api_key.return_value = "test_key"
224+
225+
ctx = MagicMock(spec=Context)
226+
ctx.info = AsyncMock()
227+
ctx.warning = AsyncMock()
228+
ctx.error = AsyncMock()
229+
230+
mock_response = MagicMock()
231+
mock_response.raise_for_status = MagicMock()
232+
233+
async def mock_aiter_lines():
234+
yield 'event: error'
235+
yield 'data: {"title":"Plan limit","status":429,"detail":"Chat completion rate limit exceeded","requestId":"req-429"}'
236+
yield ''
237+
238+
mock_response.aiter_lines = mock_aiter_lines
239+
240+
mock_client = AsyncMock()
241+
mock_client.post.return_value = mock_response
242+
243+
mock_codealive_context = MagicMock()
244+
mock_codealive_context.client = mock_client
245+
mock_codealive_context.base_url = "https://app.codealive.ai"
246+
247+
ctx.request_context.lifespan_context = mock_codealive_context
248+
249+
with pytest.raises(ToolError) as exc:
250+
await chat(ctx=ctx, question="Test question", data_sources=["repo123"])
251+
252+
msg = str(exc.value)
253+
assert "Chat completion rate limit exceeded" in msg
254+
assert "Retry: yes" in msg
255+
assert "back off" in msg
256+
assert "requestId=req-429" in msg
257+
258+
156259
@pytest.mark.asyncio
157260
@patch('tools.chat.get_api_key_from_context')
158261
async def test_chat_empty_question_validation(mock_get_api_key):

src/tests/test_e2e_tools.py

Lines changed: 116 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -988,7 +988,11 @@ def handler(req):
988988

989989
class TestChatE2E:
990990
@staticmethod
991-
def _sse_body(chunks: list[str], conv_id: str = "conv-42", msg_id: str = "msg-1") -> str:
991+
def _sse_body(
992+
chunks: list[str],
993+
conv_id: str = "69fceb3e7b2a6a7efdd18180",
994+
msg_id: str = "69fceb3e7b2a6a7efdd18181",
995+
) -> str:
992996
"""Build an SSE response body with metadata + content chunks + DONE."""
993997
lines = [
994998
"event: message",
@@ -1011,6 +1015,7 @@ def handler(req):
10111015
data = json.loads(req.content)
10121016
assert data["stream"] is True
10131017
assert data["messages"][0]["content"] == "How does auth work?"
1018+
assert req.headers["accept"] == "text/event-stream, application/problem+json"
10141019
return httpx.Response(200, text=body, headers={"content-type": "text/event-stream"})
10151020

10161021
mcp = _server({"/api/chat/completions": handler})
@@ -1023,27 +1028,42 @@ def handler(req):
10231028
text = _text(result)
10241029
assert "Hello world!" in text
10251030
# New conversation gets ID appended
1026-
assert "conv-42" in text
1031+
assert "69fceb3e7b2a6a7efdd18180" in text
10271032

10281033
@pytest.mark.asyncio
10291034
async def test_continuing_conversation(self):
1030-
body = self._sse_body(["Follow-up answer"], conv_id="conv-existing")
1035+
conversation_id = "69fceb3e7b2a6a7efdd18180"
1036+
body = self._sse_body(["Follow-up answer"], conv_id=conversation_id)
10311037

10321038
def handler(req):
10331039
data = json.loads(req.content)
1034-
assert data["conversationId"] == "conv-existing"
1040+
assert data["conversationId"] == conversation_id
10351041
return httpx.Response(200, text=body, headers={"content-type": "text/event-stream"})
10361042

10371043
mcp = _server({"/api/chat/completions": handler})
10381044
async with Client(mcp) as client:
10391045
result = await client.call_tool(
10401046
"chat",
1041-
{"question": "And the error handling?", "conversation_id": "conv-existing"},
1047+
{"question": "And the error handling?", "conversation_id": conversation_id},
10421048
)
10431049

10441050
text = _text(result)
10451051
assert "Follow-up answer" in text
10461052

1053+
@pytest.mark.asyncio
1054+
async def test_invalid_conversation_id_returns_actionable_tool_error(self):
1055+
mcp = _server({})
1056+
async with Client(mcp) as client:
1057+
result = await client.call_tool(
1058+
"chat",
1059+
{"question": "And the error handling?", "conversation_id": "conv-existing"},
1060+
raise_on_error=False,
1061+
)
1062+
1063+
text = _text(result)
1064+
assert "24-character hex Mongo ObjectId" in text
1065+
assert "Retry: no" in text
1066+
10471067
@pytest.mark.asyncio
10481068
async def test_empty_question_returns_error(self):
10491069
mcp = _server({})
@@ -1071,6 +1091,97 @@ async def test_backend_error_handled(self):
10711091
text = _text(result)
10721092
assert "401" in text or "auth" in text.lower()
10731093

1094+
@pytest.mark.asyncio
1095+
async def test_problem_details_backend_error_keeps_detail_and_request_id(self):
1096+
problem = {
1097+
"type": "https://app.codealive.ai/errors/bad-request",
1098+
"title": "Bad request",
1099+
"status": 400,
1100+
"detail": "Message content violates our content policy",
1101+
"requestId": "req-rest",
1102+
}
1103+
1104+
mcp = _server({
1105+
"/api/chat/completions": lambda r: httpx.Response(
1106+
400,
1107+
json=problem,
1108+
headers={"content-type": "application/problem+json"},
1109+
),
1110+
})
1111+
async with Client(mcp) as client:
1112+
result = await client.call_tool(
1113+
"chat",
1114+
{"question": "hello"},
1115+
raise_on_error=False,
1116+
)
1117+
1118+
text = _text(result)
1119+
assert "Message content violates our content policy" in text
1120+
assert "requestId=req-rest" in text
1121+
assert "Retry: no" in text
1122+
1123+
@pytest.mark.asyncio
1124+
async def test_named_sse_problem_details_error_returns_tool_error(self):
1125+
problem = json.dumps({
1126+
"type": "https://app.codealive.ai/errors/bad-request",
1127+
"title": "Bad request",
1128+
"status": 400,
1129+
"detail": "Message content violates our content policy",
1130+
"requestId": "req-sse",
1131+
})
1132+
body = f"event: error\ndata: {problem}\n\n"
1133+
1134+
mcp = _server({
1135+
"/api/chat/completions": lambda r: httpx.Response(
1136+
200,
1137+
text=body,
1138+
headers={"content-type": "text/event-stream"},
1139+
),
1140+
})
1141+
async with Client(mcp) as client:
1142+
result = await client.call_tool(
1143+
"chat",
1144+
{"question": "hello", "data_sources": ["backend"]},
1145+
raise_on_error=False,
1146+
)
1147+
1148+
text = _text(result)
1149+
assert "Message content violates our content policy" in text
1150+
assert "Code: 400" in text
1151+
assert "requestId=req-sse" in text
1152+
assert "Retry: no" in text
1153+
1154+
@pytest.mark.asyncio
1155+
async def test_named_sse_rate_limit_error_is_retryable(self):
1156+
problem = json.dumps({
1157+
"type": "https://app.codealive.ai/errors/plan-limit",
1158+
"title": "Plan limit",
1159+
"status": 429,
1160+
"detail": "Chat completion rate limit exceeded",
1161+
"requestId": "req-sse-429",
1162+
})
1163+
body = f"event: error\ndata: {problem}\n\n"
1164+
1165+
mcp = _server({
1166+
"/api/chat/completions": lambda r: httpx.Response(
1167+
200,
1168+
text=body,
1169+
headers={"content-type": "text/event-stream"},
1170+
),
1171+
})
1172+
async with Client(mcp) as client:
1173+
result = await client.call_tool(
1174+
"chat",
1175+
{"question": "hello", "data_sources": ["backend"]},
1176+
raise_on_error=False,
1177+
)
1178+
1179+
text = _text(result)
1180+
assert "Chat completion rate limit exceeded" in text
1181+
assert "Retry: yes" in text
1182+
assert "back off" in text
1183+
assert "requestId=req-sse-429" in text
1184+
10741185
@pytest.mark.asyncio
10751186
async def test_unicode_preserved_in_streamed_response(self):
10761187
"""Cyrillic chunks streamed via SSE must survive as UTF-8 in the final text."""

0 commit comments

Comments
 (0)