Skip to content

Commit 8cccbe3

Browse files
add model fallback
1 parent 0f7dcd2 commit 8cccbe3

6 files changed

Lines changed: 545 additions & 121 deletions

File tree

services/doc_agent_chat/agent.py

Lines changed: 14 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
from doc_agent_chat.prompt import build_system_prompt
88
from doc_agent_chat.tools import TOOL_DEFINITIONS, search_documents, format_search_results_as_documents
99
from doc_agent_chat.config_loader import ConfigLoader
10-
from models import preferred_chat_model
10+
from models import preferred_chat_model, call_with_model_fallback
1111

1212
logger = create_logger("agent")
1313

@@ -59,16 +59,19 @@ def run(
5959
for iteration in range(self.max_tool_calls):
6060
logger.info(f"Agentic loop iteration {iteration + 1}")
6161

62-
response = self.client.messages.create(
63-
model=self.model,
64-
max_tokens=self.max_tokens,
65-
system=system_prompt,
66-
messages=messages,
67-
tools=TOOL_DEFINITIONS,
68-
# Per-request timeout (same values as the SDK default):
69-
# required for non-streaming calls with max_tokens > ~21k,
70-
# which the SDK otherwise rejects.
71-
timeout=httpx.Timeout(600.0, connect=5.0),
62+
response = call_with_model_fallback(
63+
lambda m: self.client.messages.create(
64+
model=m,
65+
max_tokens=self.max_tokens,
66+
system=system_prompt,
67+
messages=messages,
68+
tools=TOOL_DEFINITIONS,
69+
# Per-request timeout (same values as the SDK default):
70+
# required for non-streaming calls with max_tokens > ~21k,
71+
# which the SDK otherwise rejects.
72+
timeout=httpx.Timeout(600.0, connect=5.0),
73+
),
74+
preferred=self.model,
7275
)
7376

7477
if hasattr(response, "usage"):

services/global_chat/planner.py

Lines changed: 49 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,11 @@
2424
STATUS_PLANNING,
2525
)
2626
from global_chat.config_loader import ConfigLoader
27-
from models import preferred_chat_model
27+
from models import (
28+
preferred_chat_model,
29+
call_with_model_fallback,
30+
stream_with_model_fallback,
31+
)
2832
from global_chat.tools.tool_definitions import TOOL_DEFINITIONS
2933
from global_chat.yaml_utils import stitch_job_code, redact_job_bodies, find_job_in_yaml
3034
from tools.search_documentation.search_documentation import search_documentation_tool
@@ -276,47 +280,56 @@ def _call_api(self, system_prompt, messages, stream):
276280
task-specific status messages sent before each tool execution.
277281
"""
278282
if stream:
279-
buffered_text = []
280-
281-
with self.client.messages.stream(
282-
model=self.model,
283-
max_tokens=self.max_tokens,
284-
system=system_prompt,
285-
messages=messages,
286-
tools=self.tools,
287-
thinking={"type": "adaptive"},
288-
output_config={"effort": "medium"},
289-
) as stream_obj:
283+
def _consume(stream_obj, commit):
284+
buffered_text = []
290285
for event in stream_obj:
291286
if event.type == "content_block_delta":
292287
if event.delta.type == "text_delta":
293288
buffered_text.append(event.delta.text)
289+
commit()
294290
return stream_obj.get_final_message(), buffered_text
291+
292+
return stream_with_model_fallback(
293+
lambda m: self.client.messages.stream(
294+
model=m,
295+
max_tokens=self.max_tokens,
296+
system=system_prompt,
297+
messages=messages,
298+
tools=self.tools,
299+
thinking={"type": "adaptive"},
300+
output_config={"effort": "medium"},
301+
),
302+
_consume,
303+
preferred=self.model,
304+
)
295305
else:
296-
response = self.client.beta.messages.create(
297-
model=self.model,
298-
max_tokens=self.max_tokens,
299-
system=system_prompt,
300-
messages=messages,
301-
tools=self.tools,
302-
thinking={"type": "adaptive"},
303-
output_config={"effort": "medium"},
304-
# Per-request timeout (same values as the SDK default):
305-
# required for non-streaming calls with max_tokens > ~21k,
306-
# which the SDK otherwise rejects.
307-
timeout=httpx.Timeout(600.0, connect=5.0),
308-
betas=["context-management-2025-06-27"],
309-
context_management={
310-
"edits": [
311-
{
312-
"type": "clear_tool_uses_20250919",
313-
"trigger": {"type": "tool_uses", "value": 20},
314-
"keep": {"type": "tool_uses", "value": 10},
315-
"exclude_tools": ["search_documentation"],
316-
"clear_tool_inputs": True,
317-
}
318-
]
319-
},
306+
response = call_with_model_fallback(
307+
lambda m: self.client.beta.messages.create(
308+
model=m,
309+
max_tokens=self.max_tokens,
310+
system=system_prompt,
311+
messages=messages,
312+
tools=self.tools,
313+
thinking={"type": "adaptive"},
314+
output_config={"effort": "medium"},
315+
# Per-request timeout (same values as the SDK default):
316+
# required for non-streaming calls with max_tokens > ~21k,
317+
# which the SDK otherwise rejects.
318+
timeout=httpx.Timeout(600.0, connect=5.0),
319+
betas=["context-management-2025-06-27"],
320+
context_management={
321+
"edits": [
322+
{
323+
"type": "clear_tool_uses_20250919",
324+
"trigger": {"type": "tool_uses", "value": 20},
325+
"keep": {"type": "tool_uses", "value": 10},
326+
"exclude_tools": ["search_documentation"],
327+
"clear_tool_inputs": True,
328+
}
329+
]
330+
},
331+
),
332+
preferred=self.model,
320333
)
321334
return response, []
322335

services/job_chat/job_chat.py

Lines changed: 57 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,11 @@
2828
STATUS_WORKING,
2929
STATUS_WRITING_CODE,
3030
)
31-
from models import preferred_chat_model
31+
from models import (
32+
preferred_chat_model,
33+
call_with_model_fallback,
34+
stream_with_model_fallback,
35+
)
3236

3337
_MODEL = preferred_chat_model("job_chat")
3438

@@ -231,26 +235,18 @@ def generate(
231235
with sentry_sdk.start_span(description="anthropic_api_call"):
232236
if stream:
233237
logger.info("Making streaming API call")
234-
text_started = False
235-
sent_length = 0
236-
accumulated_response = ""
237-
self._stream_applied = False
238-
self._stream_suggested_code = None
239-
self._stream_diff = None
240-
241238
original_code = context.get("expression") if context and isinstance(context, dict) else None
242239

243-
stream_kwargs = dict(
244-
max_tokens=self.config.max_tokens,
245-
messages=prompt,
246-
model=self.config.model,
247-
system=system_message,
248-
thinking={"type": "adaptive"},
249-
output_config=output_config,
250-
**tool_kwargs
251-
)
240+
def _consume(stream_obj, commit):
241+
# Reset per attempt so a model fallback never reuses a
242+
# prior (failed) stream's partial state.
243+
text_started = False
244+
sent_length = 0
245+
accumulated_response = ""
246+
self._stream_applied = False
247+
self._stream_suggested_code = None
248+
self._stream_diff = None
252249

253-
with self.client.messages.stream(**stream_kwargs) as stream_obj:
254250
for event in stream_obj:
255251
if event.type == "message_start":
256252
stream_manager.send_thinking(STATUS_WORKING)
@@ -268,20 +264,40 @@ def generate(
268264
original_code,
269265
content
270266
)
271-
message = stream_obj.get_final_message()
267+
# Once user-facing text has streamed, we can't cleanly
268+
# fall back to another model without re-sending it.
269+
if text_started:
270+
commit()
271+
272+
msg = stream_obj.get_final_message()
273+
274+
# Flush any remaining buffered text, stripping JSON closing chars
275+
if suggest_code and text_started:
276+
if sent_length < len(accumulated_response):
277+
remaining = accumulated_response[sent_length:]
278+
remaining = re.sub(r'"\s*}\s*$', '', remaining)
279+
if remaining:
280+
stream_manager.send_text(self._unescape_json_string(remaining))
281+
return msg
272282

273-
# Flush any remaining buffered text, stripping JSON closing chars
274-
if suggest_code and text_started:
275-
if sent_length < len(accumulated_response):
276-
remaining = accumulated_response[sent_length:]
277-
remaining = re.sub(r'"\s*}\s*$', '', remaining)
278-
if remaining:
279-
stream_manager.send_text(self._unescape_json_string(remaining))
283+
stream_kwargs = dict(
284+
max_tokens=self.config.max_tokens,
285+
messages=prompt,
286+
system=system_message,
287+
thinking={"type": "adaptive"},
288+
output_config=output_config,
289+
**tool_kwargs
290+
)
291+
message = stream_with_model_fallback(
292+
lambda m: self.client.messages.stream(model=m, **stream_kwargs),
293+
_consume,
294+
preferred=self.config.model,
295+
)
280296

281297
else:
282298
logger.info("Making non-streaming API call")
283299
create_kwargs = dict(
284-
max_tokens=self.config.max_tokens, messages=prompt, model=self.config.model, system=system_message,
300+
max_tokens=self.config.max_tokens, messages=prompt, system=system_message,
285301
thinking={"type": "adaptive"},
286302
output_config=output_config,
287303
# Per-request timeout (same values as the SDK default):
@@ -290,7 +306,10 @@ def generate(
290306
timeout=httpx.Timeout(600.0, connect=5.0),
291307
**tool_kwargs
292308
)
293-
message = self.client.messages.create(**create_kwargs)
309+
message = call_with_model_fallback(
310+
lambda m: self.client.messages.create(model=m, **create_kwargs),
311+
preferred=self.config.model,
312+
)
294313

295314
if hasattr(message, "usage"):
296315
if message.usage.cache_creation_input_tokens:
@@ -537,13 +556,16 @@ def try_error_correction(self, content: str, error_message: str, old_code: str,
537556
# structured outputs removed here too (see note in generate); the
538557
# correction prompt already instructs the {explanation, corrected_*}
539558
# JSON shape and json.loads below is wrapped in try/except.
540-
message = self.client.messages.create(
541-
max_tokens=16384,
542-
messages=prompt,
543-
model=self.config.model,
544-
system=system_message,
545-
output_config={"effort": "medium"},
546-
thinking={"type": "adaptive"}
559+
message = call_with_model_fallback(
560+
lambda m: self.client.messages.create(
561+
max_tokens=16384,
562+
messages=prompt,
563+
model=m,
564+
system=system_message,
565+
output_config={"effort": "medium"},
566+
thinking={"type": "adaptive"}
567+
),
568+
preferred=self.config.model,
547569
)
548570

549571
response = "\n\n".join([block.text for block in message.content if block.type == "text"])

0 commit comments

Comments
 (0)