2828 STATUS_WORKING ,
2929 STATUS_WRITING_CODE ,
3030)
31- from models import preferred_chat_model
31+ from models import (
32+ preferred_chat_model ,
33+ call_with_model_fallback ,
34+ stream_with_model_fallback ,
35+ )
3236
3337_MODEL = preferred_chat_model ("job_chat" )
3438
@@ -231,26 +235,18 @@ def generate(
231235 with sentry_sdk .start_span (description = "anthropic_api_call" ):
232236 if stream :
233237 logger .info ("Making streaming API call" )
234- text_started = False
235- sent_length = 0
236- accumulated_response = ""
237- self ._stream_applied = False
238- self ._stream_suggested_code = None
239- self ._stream_diff = None
240-
241238 original_code = context .get ("expression" ) if context and isinstance (context , dict ) else None
242239
243- stream_kwargs = dict (
244- max_tokens = self . config . max_tokens ,
245- messages = prompt ,
246- model = self . config . model ,
247- system = system_message ,
248- thinking = { "type" : "adaptive" },
249- output_config = output_config ,
250- ** tool_kwargs
251- )
240+ def _consume ( stream_obj , commit ):
241+ # Reset per attempt so a model fallback never reuses a
242+ # prior (failed) stream's partial state.
243+ text_started = False
244+ sent_length = 0
245+ accumulated_response = ""
246+ self . _stream_applied = False
247+ self . _stream_suggested_code = None
248+ self . _stream_diff = None
252249
253- with self .client .messages .stream (** stream_kwargs ) as stream_obj :
254250 for event in stream_obj :
255251 if event .type == "message_start" :
256252 stream_manager .send_thinking (STATUS_WORKING )
@@ -268,20 +264,40 @@ def generate(
268264 original_code ,
269265 content
270266 )
271- message = stream_obj .get_final_message ()
267+ # Once user-facing text has streamed, we can't cleanly
268+ # fall back to another model without re-sending it.
269+ if text_started :
270+ commit ()
271+
272+ msg = stream_obj .get_final_message ()
273+
274+ # Flush any remaining buffered text, stripping JSON closing chars
275+ if suggest_code and text_started :
276+ if sent_length < len (accumulated_response ):
277+ remaining = accumulated_response [sent_length :]
278+ remaining = re .sub (r'"\s*}\s*$' , '' , remaining )
279+ if remaining :
280+ stream_manager .send_text (self ._unescape_json_string (remaining ))
281+ return msg
272282
273- # Flush any remaining buffered text, stripping JSON closing chars
274- if suggest_code and text_started :
275- if sent_length < len (accumulated_response ):
276- remaining = accumulated_response [sent_length :]
277- remaining = re .sub (r'"\s*}\s*$' , '' , remaining )
278- if remaining :
279- stream_manager .send_text (self ._unescape_json_string (remaining ))
283+ stream_kwargs = dict (
284+ max_tokens = self .config .max_tokens ,
285+ messages = prompt ,
286+ system = system_message ,
287+ thinking = {"type" : "adaptive" },
288+ output_config = output_config ,
289+ ** tool_kwargs
290+ )
291+ message = stream_with_model_fallback (
292+ lambda m : self .client .messages .stream (model = m , ** stream_kwargs ),
293+ _consume ,
294+ preferred = self .config .model ,
295+ )
280296
281297 else :
282298 logger .info ("Making non-streaming API call" )
283299 create_kwargs = dict (
284- max_tokens = self .config .max_tokens , messages = prompt , model = self . config . model , system = system_message ,
300+ max_tokens = self .config .max_tokens , messages = prompt , system = system_message ,
285301 thinking = {"type" : "adaptive" },
286302 output_config = output_config ,
287303 # Per-request timeout (same values as the SDK default):
@@ -290,7 +306,10 @@ def generate(
290306 timeout = httpx .Timeout (600.0 , connect = 5.0 ),
291307 ** tool_kwargs
292308 )
293- message = self .client .messages .create (** create_kwargs )
309+ message = call_with_model_fallback (
310+ lambda m : self .client .messages .create (model = m , ** create_kwargs ),
311+ preferred = self .config .model ,
312+ )
294313
295314 if hasattr (message , "usage" ):
296315 if message .usage .cache_creation_input_tokens :
@@ -537,13 +556,16 @@ def try_error_correction(self, content: str, error_message: str, old_code: str,
537556 # Structured outputs are not used here either (see note in generate):
538557 # the correction prompt already specifies the {explanation, corrected_*}
539558 # JSON shape, and the json.loads call below is wrapped in try/except.
540- message = self .client .messages .create (
541- max_tokens = 16384 ,
542- messages = prompt ,
543- model = self .config .model ,
544- system = system_message ,
545- output_config = {"effort" : "medium" },
546- thinking = {"type" : "adaptive" }
559+ message = call_with_model_fallback (
560+ lambda m : self .client .messages .create (
561+ max_tokens = 16384 ,
562+ messages = prompt ,
563+ model = m ,
564+ system = system_message ,
565+ output_config = {"effort" : "medium" },
566+ thinking = {"type" : "adaptive" }
567+ ),
568+ preferred = self .config .model ,
547569 )
548570
549571 response = "\n \n " .join ([block .text for block in message .content if block .type == "text" ])
0 commit comments