diff --git a/services/doc_agent_chat/agent.py b/services/doc_agent_chat/agent.py index 2e06e437..112b0870 100644 --- a/services/doc_agent_chat/agent.py +++ b/services/doc_agent_chat/agent.py @@ -7,7 +7,7 @@ from doc_agent_chat.prompt import build_system_prompt from doc_agent_chat.tools import TOOL_DEFINITIONS, search_documents, format_search_results_as_documents from doc_agent_chat.config_loader import ConfigLoader -from models import preferred_chat_model +from models import preferred_chat_model, call_with_model_fallback logger = create_logger("agent") @@ -59,16 +59,19 @@ def run( for iteration in range(self.max_tool_calls): logger.info(f"Agentic loop iteration {iteration + 1}") - response = self.client.messages.create( - model=self.model, - max_tokens=self.max_tokens, - system=system_prompt, - messages=messages, - tools=TOOL_DEFINITIONS, - # Per-request timeout (same values as the SDK default): - # required for non-streaming calls with max_tokens > ~21k, - # which the SDK otherwise rejects. - timeout=httpx.Timeout(600.0, connect=5.0), + response = call_with_model_fallback( + lambda m: self.client.messages.create( + model=m, + max_tokens=self.max_tokens, + system=system_prompt, + messages=messages, + tools=TOOL_DEFINITIONS, + # Per-request timeout (same values as the SDK default): + # required for non-streaming calls with max_tokens > ~21k, + # which the SDK otherwise rejects. + timeout=httpx.Timeout(600.0, connect=5.0), + ), + preferred=self.model, ) if hasattr(response, "usage"): diff --git a/services/global_chat/planner.py b/services/global_chat/planner.py index 9e806a91..55a82c52 100644 --- a/services/global_chat/planner.py +++ b/services/global_chat/planner.py @@ -24,7 +24,11 @@ STATUS_PLANNING, ) from global_chat.config_loader import ConfigLoader -from models import preferred_chat_model +from models import ( + preferred_chat_model, + call_with_model_fallback, + stream_with_model_fallback, +) from global_chat.tools.tool_definitions import TOOL_DEFINITIONS from global_chat.yaml_utils import stitch_job_code, redact_job_bodies, find_job_in_yaml from tools.search_documentation.search_documentation import search_documentation_tool @@ -276,47 +280,56 @@ def _call_api(self, system_prompt, messages, stream): task-specific status messages sent before each tool execution. """ if stream: - buffered_text = [] - - with self.client.messages.stream( - model=self.model, - max_tokens=self.max_tokens, - system=system_prompt, - messages=messages, - tools=self.tools, - thinking={"type": "adaptive"}, - output_config={"effort": "medium"}, - ) as stream_obj: + def _consume(stream_obj, commit): + buffered_text = [] for event in stream_obj: if event.type == "content_block_delta": if event.delta.type == "text_delta": buffered_text.append(event.delta.text) + commit() return stream_obj.get_final_message(), buffered_text + + return stream_with_model_fallback( + lambda m: self.client.messages.stream( + model=m, + max_tokens=self.max_tokens, + system=system_prompt, + messages=messages, + tools=self.tools, + thinking={"type": "adaptive"}, + output_config={"effort": "medium"}, + ), + _consume, + preferred=self.model, + ) else: - response = self.client.beta.messages.create( - model=self.model, - max_tokens=self.max_tokens, - system=system_prompt, - messages=messages, - tools=self.tools, - thinking={"type": "adaptive"}, - output_config={"effort": "medium"}, - # Per-request timeout (same values as the SDK default): - # required for non-streaming calls with max_tokens > ~21k, - # which the SDK otherwise rejects. - timeout=httpx.Timeout(600.0, connect=5.0), - betas=["context-management-2025-06-27"], - context_management={ - "edits": [ - { - "type": "clear_tool_uses_20250919", - "trigger": {"type": "tool_uses", "value": 20}, - "keep": {"type": "tool_uses", "value": 10}, - "exclude_tools": ["search_documentation"], - "clear_tool_inputs": True, - } - ] - }, + response = call_with_model_fallback( + lambda m: self.client.beta.messages.create( + model=m, + max_tokens=self.max_tokens, + system=system_prompt, + messages=messages, + tools=self.tools, + thinking={"type": "adaptive"}, + output_config={"effort": "medium"}, + # Per-request timeout (same values as the SDK default): + # required for non-streaming calls with max_tokens > ~21k, + # which the SDK otherwise rejects. + timeout=httpx.Timeout(600.0, connect=5.0), + betas=["context-management-2025-06-27"], + context_management={ + "edits": [ + { + "type": "clear_tool_uses_20250919", + "trigger": {"type": "tool_uses", "value": 20}, + "keep": {"type": "tool_uses", "value": 10}, + "exclude_tools": ["search_documentation"], + "clear_tool_inputs": True, + } + ] + }, + ), + preferred=self.model, ) return response, [] diff --git a/services/job_chat/job_chat.py b/services/job_chat/job_chat.py index cc569590..407ead44 100644 --- a/services/job_chat/job_chat.py +++ b/services/job_chat/job_chat.py @@ -28,7 +28,11 @@ STATUS_WORKING, STATUS_WRITING_CODE, ) -from models import preferred_chat_model +from models import ( + preferred_chat_model, + call_with_model_fallback, + stream_with_model_fallback, +) _MODEL = preferred_chat_model("job_chat") @@ -231,26 +235,18 @@ def generate( with sentry_sdk.start_span(description="anthropic_api_call"): if stream: logger.info("Making streaming API call") - text_started = False - sent_length = 0 - accumulated_response = "" - self._stream_applied = False - self._stream_suggested_code = None - self._stream_diff = None - original_code = context.get("expression") if context and isinstance(context, dict) else None - stream_kwargs = dict( - max_tokens=self.config.max_tokens, - messages=prompt, - model=self.config.model, - system=system_message, - thinking={"type": "adaptive"}, - output_config=output_config, - **tool_kwargs - ) + def _consume(stream_obj, commit): + # Reset per attempt so a model fallback never reuses a + # prior (failed) stream's partial state. + text_started = False + sent_length = 0 + accumulated_response = "" + self._stream_applied = False + self._stream_suggested_code = None + self._stream_diff = None - with self.client.messages.stream(**stream_kwargs) as stream_obj: for event in stream_obj: if event.type == "message_start": stream_manager.send_thinking(STATUS_WORKING) @@ -268,20 +264,40 @@ def generate( original_code, content ) - message = stream_obj.get_final_message() + # Once user-facing text has streamed, we can't cleanly + # fall back to another model without re-sending it. + if text_started: + commit() + + msg = stream_obj.get_final_message() + + # Flush any remaining buffered text, stripping JSON closing chars + if suggest_code and text_started: + if sent_length < len(accumulated_response): + remaining = accumulated_response[sent_length:] + remaining = re.sub(r'"\s*}\s*$', '', remaining) + if remaining: + stream_manager.send_text(self._unescape_json_string(remaining)) + return msg - # Flush any remaining buffered text, stripping JSON closing chars - if suggest_code and text_started: - if sent_length < len(accumulated_response): - remaining = accumulated_response[sent_length:] - remaining = re.sub(r'"\s*}\s*$', '', remaining) - if remaining: - stream_manager.send_text(self._unescape_json_string(remaining)) + stream_kwargs = dict( + max_tokens=self.config.max_tokens, + messages=prompt, + system=system_message, + thinking={"type": "adaptive"}, + output_config=output_config, + **tool_kwargs + ) + message = stream_with_model_fallback( + lambda m: self.client.messages.stream(model=m, **stream_kwargs), + _consume, + preferred=self.config.model, + ) else: logger.info("Making non-streaming API call") create_kwargs = dict( - max_tokens=self.config.max_tokens, messages=prompt, model=self.config.model, system=system_message, + max_tokens=self.config.max_tokens, messages=prompt, system=system_message, thinking={"type": "adaptive"}, output_config=output_config, # Per-request timeout (same values as the SDK default): @@ -290,7 +306,10 @@ def generate( timeout=httpx.Timeout(600.0, connect=5.0), **tool_kwargs ) - message = self.client.messages.create(**create_kwargs) + message = call_with_model_fallback( + lambda m: self.client.messages.create(model=m, **create_kwargs), + preferred=self.config.model, + ) if hasattr(message, "usage"): if message.usage.cache_creation_input_tokens: @@ -537,13 +556,16 @@ def try_error_correction(self, content: str, error_message: str, old_code: str, # structured outputs removed here too (see note in generate); the # correction prompt already instructs the {explanation, corrected_*} # JSON shape and json.loads below is wrapped in try/except. - message = self.client.messages.create( - max_tokens=16384, - messages=prompt, - model=self.config.model, - system=system_message, - output_config={"effort": "medium"}, - thinking={"type": "adaptive"} + message = call_with_model_fallback( + lambda m: self.client.messages.create( + max_tokens=16384, + messages=prompt, + model=m, + system=system_message, + output_config={"effort": "medium"}, + thinking={"type": "adaptive"} + ), + preferred=self.config.model, ) response = "\n\n".join([block.text for block in message.content if block.type == "text"]) diff --git a/services/models.py b/services/models.py index e9777598..5fcd23a9 100644 --- a/services/models.py +++ b/services/models.py @@ -4,16 +4,22 @@ """ import os +from collections.abc import Callable +from typing import Any + +import anthropic CLAUDE_MODELS: dict[str, str] = { - "claude-opus": "claude-opus-4-8", - "claude-sonnet": "claude-sonnet-4-6", - "claude-haiku": "claude-haiku-4-5-20251001", + "claude-opus": "claude-opus-4-8", + "claude-opus-prev": "claude-opus-4-7", + "claude-sonnet": "claude-sonnet-4-6", + "claude-haiku": "claude-haiku-4-5-20251001", } -CLAUDE_OPUS: str = CLAUDE_MODELS["claude-opus"] -CLAUDE_SONNET: str = CLAUDE_MODELS["claude-sonnet"] -CLAUDE_HAIKU: str = CLAUDE_MODELS["claude-haiku"] +CLAUDE_OPUS: str = CLAUDE_MODELS["claude-opus"] +CLAUDE_OPUS_PREV: str = CLAUDE_MODELS["claude-opus-prev"] +CLAUDE_SONNET: str = CLAUDE_MODELS["claude-sonnet"] +CLAUDE_HAIKU: str = CLAUDE_MODELS["claude-haiku"] def resolve_model(alias: str) -> str: @@ -65,3 +71,116 @@ def preferred_chat_model(service: str | None = None) -> str: return resolve_model(override) return cfg.get("default", CHAT_MODEL_DEFAULT) + + +# --- Fallback when the preferred model is unavailable ------------------------ +# +# When a chat call's preferred model is unavailable, we try the next model in +# this chain rather than failing the request. The preferred model is tried +# first; this chain provides the ordered fallbacks after it. +CHAT_FALLBACK_CHAIN: list[str] = [CLAUDE_OPUS, CLAUDE_OPUS_PREV, CLAUDE_SONNET] + +# HTTP status codes that mean "the provider/model is down or overloaded right +# now" and we should try the next model: 500 (api_error), 502/503 (gateway/ +# unavailable), 529 (overloaded). The SDK already retries these with backoff +# before they surface here, so reaching this point means retries were exhausted. +# (The SDK version in use has no dedicated OverloadedError class — 529 arrives +# as a plain APIStatusError — so we classify by status code.) +_MODEL_DOWN_STATUS_CODES = frozenset({500, 502, 503, 529}) + + +def chat_model_chain(preferred: str | None = None) -> list[str]: + """Ordered models to try for a main-chat call: the preferred model first, + then the fallback chain, de-duplicated while preserving order.""" + first = preferred or preferred_chat_model() + return list(dict.fromkeys([first, *CHAT_FALLBACK_CHAIN])) + + +def _fallback_steps(preferred: str | None) -> list[tuple[str, str | None]]: + """The chain as (model, next_model) pairs; next_model is None for the last + model, which is what tells the fallback loops "no more models to try".""" + models = chat_model_chain(preferred) + return list(zip(models, [*models[1:], None], strict=True)) + + +def is_model_unavailable_error(exc: BaseException) -> bool: + """True if `exc` means the model itself is unavailable and we should fall + back to the next model rather than surfacing the error. + + Covers a removed/renamed model (404 not_found — permanent, not retried by + the SDK) and a down/overloaded provider (500/502/503/529 — transient, + surfaced only after the SDK's own retries are exhausted). + """ + if isinstance(exc, anthropic.NotFoundError): + return True + if isinstance(exc, anthropic.APIStatusError): + return getattr(exc, "status_code", None) in _MODEL_DOWN_STATUS_CODES + return False + + +def _alert_model_fallback(failed_model: str, next_model: str, exc: BaseException) -> None: + """Flag a fallback to Sentry. A fallback is meant to be temporary — it keeps + chat up, but someone should fix the preferred model (or the env override) soon.""" + msg = ( + f"Chat model {failed_model!r} unavailable ({type(exc).__name__}); " + f"falling back to {next_model!r}" + ) + try: + import sentry_sdk # noqa: PLC0415 + + sentry_sdk.capture_message(msg, level="warning") + except Exception: + pass + # Also surface in private logs (print, not the client-facing logger). + print(msg) # noqa: T201 + + +def call_with_model_fallback(attempt: Callable[[str], Any], *, preferred: str | None = None) -> Any: # noqa: ANN401 + """Run `attempt(model)` against the chat model chain, advancing to the next + model only on model-unavailable errors. Returns `attempt`'s result; re-raises + the original error for non-fallback errors or once the chain is exhausted. + + For non-streaming calls. For streaming, use `stream_with_model_fallback`. + """ + for model, next_model in _fallback_steps(preferred): + try: + return attempt(model) + except Exception as exc: + if next_model is None or not is_model_unavailable_error(exc): + raise + _alert_model_fallback(model, next_model, exc) + + +def stream_with_model_fallback( + open_stream: Callable[[str], Any], + consume: Callable[..., Any], + *, + preferred: str | None = None, +) -> Any: # noqa: ANN401 + """Streaming counterpart of `call_with_model_fallback`. + + Args: + open_stream: `open_stream(model)` -> the result of `client.messages.stream(...)` + (a context manager, not yet entered). + consume: `consume(stream_obj, commit)` -> processes events and returns a + result. Call `commit()` as soon as any user-facing content has been + sent; after that a failure re-raises instead of falling back, so we + never re-stream a partial answer to the user. + + Fallback only happens for failures at stream-open / before the first + committed content (the case where a removed or overloaded model surfaces). + """ + for model, next_model in _fallback_steps(preferred): + committed = False + + def commit() -> None: + nonlocal committed + committed = True + + try: + with open_stream(model) as stream_obj: + return consume(stream_obj, commit) + except Exception as exc: + if committed or next_model is None or not is_model_unavailable_error(exc): + raise + _alert_model_fallback(model, next_model, exc) diff --git a/services/tests/unit/test_models.py b/services/tests/unit/test_models.py index b7639a22..ec438592 100644 --- a/services/tests/unit/test_models.py +++ b/services/tests/unit/test_models.py @@ -1,15 +1,57 @@ -"""Unit tests for the central chat-model selection in `services/models.py`. +"""Unit tests for the central chat-model selection and fallback logic in +`services/models.py`. -No real model calls, pure resolution logic. The repo-root conftest marks -everything under a `unit/` dir as `unit` and blocks real client construction. +No real model calls: errors are fabricated anthropic exception instances and the +work is done by fake `attempt` / stream callables. The repo-root conftest marks +everything under a `unit/` dir as `unit` and blocks real client construction, so +these stay fast and offline. """ +import anthropic import models as m import pytest _WORKFLOW_ENV = m.CHAT_SERVICE_MODELS["workflow_chat"]["env"] +# --- helpers ---------------------------------------------------------------- + +def _not_found() -> anthropic.NotFoundError: + """A 404 NotFoundError without going through the HTTP-bound __init__.""" + return anthropic.NotFoundError.__new__(anthropic.NotFoundError) + + +def _status_error(code: int) -> anthropic.APIStatusError: + """An APIStatusError carrying `code`, without a real HTTP response.""" + exc = anthropic.APIStatusError.__new__(anthropic.APIStatusError) + exc.status_code = code + return exc + + +class FakeStreamCM: + """Stands in for the context manager `client.messages.stream(...)` returns. + + Raises `open_error` on __enter__ (mimicking a model-unavailable error at + stream open) and records whether it was entered/exited so tests can assert + the `with` block was cleaned up. + """ + + def __init__(self, *, open_error: BaseException | None = None) -> None: + self.open_error = open_error + self.entered = False + self.exited = False + + def __enter__(self): + self.entered = True + if self.open_error is not None: + raise self.open_error + return self + + def __exit__(self, *_exc: object) -> bool: + self.exited = True + return False + + @pytest.fixture(autouse=True) def _clear_env(monkeypatch): """Clear all per-service overrides so the real environment can't skew tests.""" @@ -17,6 +59,8 @@ def _clear_env(monkeypatch): monkeypatch.delenv(cfg["env"], raising=False) +# --- preferred_chat_model: defaults + precedence ---------------------------- + def test_unlisted_service_uses_default(): # A service with no entry (e.g. doc_agent_chat, or none at all) uses the default. assert m.preferred_chat_model() == m.CHAT_MODEL_DEFAULT @@ -40,3 +84,208 @@ def test_env_var_is_scoped_to_one_service(monkeypatch): monkeypatch.setenv(_WORKFLOW_ENV, "claude-haiku") assert m.preferred_chat_model("workflow_chat") == m.CLAUDE_HAIKU assert m.preferred_chat_model("job_chat") == m.CLAUDE_OPUS # unaffected + + +# --- chat_model_chain: order + dedup ---------------------------------------- + +def test_chain_default_order(): + assert m.chat_model_chain() == [m.CLAUDE_OPUS, m.CLAUDE_OPUS_PREV, m.CLAUDE_SONNET] + + +def test_chain_dedupes_when_preferred_already_in_chain(): + # sonnet is already in the fallback chain; it should appear once, first. + assert m.chat_model_chain(m.CLAUDE_SONNET) == [ + m.CLAUDE_SONNET, + m.CLAUDE_OPUS, + m.CLAUDE_OPUS_PREV, + ] + + +def test_chain_prepends_a_novel_preferred_model(): + chain = m.chat_model_chain("claude-haiku-4-5-20251001") + assert chain[0] == "claude-haiku-4-5-20251001" + assert chain[1:] == m.CHAT_FALLBACK_CHAIN + + +# --- is_model_unavailable_error: which errors trigger fallback -------------- + +def test_not_found_is_unavailable(): + # 404: model removed/renamed (e.g. a model that was taken down). + assert m.is_model_unavailable_error(_not_found()) is True + + +@pytest.mark.parametrize("code", [500, 502, 503, 529]) +def test_provider_down_codes_are_unavailable(code): + assert m.is_model_unavailable_error(_status_error(code)) is True + + +# 400 = generic client error; 429 = rate limit, the important "do NOT fall back" +# case (a busy account isn't a reason to switch models). +@pytest.mark.parametrize("code", [400, 429]) +def test_other_status_codes_are_not_unavailable(code): + assert m.is_model_unavailable_error(_status_error(code)) is False + + +def test_connection_and_generic_errors_are_not_unavailable(): + # A network blip is not model-specific; falling back to another model won't help. + conn = anthropic.APIConnectionError.__new__(anthropic.APIConnectionError) + assert m.is_model_unavailable_error(conn) is False + assert m.is_model_unavailable_error(ValueError("nope")) is False + + +# --- call_with_model_fallback (non-streaming) ------------------------------- + +def test_call_returns_first_model_result_without_fallback(monkeypatch): + alerts = [] + monkeypatch.setattr(m, "_alert_model_fallback", lambda *a: alerts.append(a)) + tried = [] + + def attempt(model): + tried.append(model) + return f"ok:{model}" + + assert m.call_with_model_fallback(attempt) == f"ok:{m.CLAUDE_OPUS}" + assert tried == [m.CLAUDE_OPUS] + assert alerts == [] # no fallback => no alert + + +def test_call_falls_back_on_404_and_alerts(monkeypatch): + alerts = [] + monkeypatch.setattr(m, "_alert_model_fallback", lambda *a: alerts.append(a)) + tried = [] + + def attempt(model): + tried.append(model) + if model == m.CLAUDE_OPUS: + raise _not_found() + return f"ok:{model}" + + assert m.call_with_model_fallback(attempt) == f"ok:{m.CLAUDE_OPUS_PREV}" + assert tried == [m.CLAUDE_OPUS, m.CLAUDE_OPUS_PREV] + # one alert, naming the failed and next model + assert len(alerts) == 1 + failed, nxt, _exc = alerts[0] + assert (failed, nxt) == (m.CLAUDE_OPUS, m.CLAUDE_OPUS_PREV) + + +def test_call_does_not_fall_back_on_non_model_error(monkeypatch): + monkeypatch.setattr(m, "_alert_model_fallback", lambda *a: None) + tried = [] + + def attempt(model): + tried.append(model) + raise _status_error(400) + + with pytest.raises(anthropic.APIStatusError): + m.call_with_model_fallback(attempt) + assert tried == [m.CLAUDE_OPUS] # propagated immediately, no fallback + + +def test_call_raises_last_error_when_whole_chain_down(monkeypatch): + monkeypatch.setattr(m, "_alert_model_fallback", lambda *a: None) + tried = [] + + def attempt(model): + tried.append(model) + raise _status_error(529) + + with pytest.raises(anthropic.APIStatusError) as excinfo: + m.call_with_model_fallback(attempt) + assert excinfo.value.status_code == 529 + assert tried == m.chat_model_chain() # every model attempted + + +def test_call_respects_preferred_model(monkeypatch): + monkeypatch.setattr(m, "_alert_model_fallback", lambda *a: None) + tried = [] + + def attempt(model): + tried.append(model) + return model + + assert m.call_with_model_fallback(attempt, preferred=m.CLAUDE_SONNET) == m.CLAUDE_SONNET + assert tried == [m.CLAUDE_SONNET] + + +# --- stream_with_model_fallback --------------------------------------------- + +def test_stream_falls_back_when_open_fails(monkeypatch): + monkeypatch.setattr(m, "_alert_model_fallback", lambda *a: None) + opened = [] + cms = {} + + def open_stream(model): + opened.append(model) + cm = FakeStreamCM(open_error=_not_found() if model == m.CLAUDE_OPUS else None) + cms[model] = cm + return cm + + def consume(stream, commit): + commit() + return f"streamed:{stream}" + + result = m.stream_with_model_fallback(open_stream, consume, preferred=m.CLAUDE_OPUS) + assert opened == [m.CLAUDE_OPUS, m.CLAUDE_OPUS_PREV] + assert result.startswith("streamed:") + # __enter__ raising means the `with` never calls __exit__ on the failed CM + # (standard context-manager semantics); the model that opened successfully + # is the one that gets exited cleanly. + assert cms[m.CLAUDE_OPUS].entered is True + assert cms[m.CLAUDE_OPUS].exited is False + assert cms[m.CLAUDE_OPUS_PREV].exited is True + + +def test_stream_does_not_fall_back_after_commit(monkeypatch): + # Once user-facing content is sent, a later model-unavailable error must NOT + # trigger a fallback (we won't re-stream a partial answer). + monkeypatch.setattr(m, "_alert_model_fallback", lambda *a: None) + opened = [] + + def open_stream(model): + opened.append(model) + return FakeStreamCM() + + def consume(stream, commit): + commit() # we've shown the user something + raise _status_error(529) # ...then the model dies mid-stream + + with pytest.raises(anthropic.APIStatusError): + m.stream_with_model_fallback(open_stream, consume, preferred=m.CLAUDE_OPUS) + assert opened == [m.CLAUDE_OPUS] # no second model tried + + +def test_stream_falls_back_on_pre_commit_consume_failure(monkeypatch): + # Failure during consume but before any content was committed is still safe + # to fall back on. + monkeypatch.setattr(m, "_alert_model_fallback", lambda *a: None) + opened = [] + + def open_stream(model): + opened.append(model) + return FakeStreamCM() + + def consume(stream, commit): + if opened[-1] == m.CLAUDE_OPUS: + raise _status_error(503) # before commit() + commit() + return "ok" + + assert m.stream_with_model_fallback(open_stream, consume, preferred=m.CLAUDE_OPUS) == "ok" + assert opened == [m.CLAUDE_OPUS, m.CLAUDE_OPUS_PREV] + + +def test_stream_does_not_fall_back_on_non_model_error(monkeypatch): + monkeypatch.setattr(m, "_alert_model_fallback", lambda *a: None) + opened = [] + + def open_stream(model): + opened.append(model) + return FakeStreamCM(open_error=_status_error(400)) + + def consume(stream, commit): # pragma: no cover - never reached + commit() + return "ok" + + with pytest.raises(anthropic.APIStatusError): + m.stream_with_model_fallback(open_stream, consume) + assert opened == [m.CLAUDE_OPUS] diff --git a/services/workflow_chat/workflow_chat.py b/services/workflow_chat/workflow_chat.py index 281a2af0..294d8505 100644 --- a/services/workflow_chat/workflow_chat.py +++ b/services/workflow_chat/workflow_chat.py @@ -6,7 +6,11 @@ from typing import List, Optional, Dict, Any import yaml from dataclasses import dataclass -from models import preferred_chat_model +from models import ( + preferred_chat_model, + call_with_model_fallback, + stream_with_model_fallback, +) _MODEL = preferred_chat_model("workflow_chat") @@ -217,18 +221,10 @@ def generate( else: stream_manager.send_thinking(STATUS_NEW_WORKFLOW + STATUS_DESIGNING_WORKFLOW) - text_started = False - sent_length = 0 - accumulated_response = "" - - with self.client.messages.stream( - max_tokens=self.config.max_tokens, - messages=prompt, - model=self.config.model, - system=system_message, - output_config=output_config, - thinking={"type": "adaptive"} - ) as stream_obj: + def _consume(stream_obj, commit): + text_started = False + sent_length = 0 + accumulated_response = "" for event in stream_obj: accumulated_response, text_started, sent_length = self.process_stream_event( event, @@ -238,26 +234,48 @@ def generate( stream_manager, preserved_values ) - message = stream_obj.get_final_message() - - # Flush any remaining buffered text, stripping JSON closing chars - if text_started: - if sent_length < len(accumulated_response): - remaining = accumulated_response[sent_length:] - remaining = re.sub(r'"\s*}\s*$', '', remaining) - if remaining: - stream_manager.send_text(self._unescape_json_string(remaining)) + # Once user-facing text has streamed, we can't + # cleanly fall back without re-sending it. + if text_started: + commit() + + msg = stream_obj.get_final_message() + + # Flush any remaining buffered text, stripping JSON closing chars + if text_started: + if sent_length < len(accumulated_response): + remaining = accumulated_response[sent_length:] + remaining = re.sub(r'"\s*}\s*$', '', remaining) + if remaining: + stream_manager.send_text(self._unescape_json_string(remaining)) + return msg + + message = stream_with_model_fallback( + lambda m: self.client.messages.stream( + max_tokens=self.config.max_tokens, + messages=prompt, + model=m, + system=system_message, + output_config=output_config, + thinking={"type": "adaptive"} + ), + _consume, + preferred=self.config.model, + ) else: logger.info("Making non-streaming API call") - message = self.client.messages.create( - max_tokens=self.config.max_tokens, messages=prompt, model=self.config.model, system=system_message, - output_config=output_config, - thinking={"type": "adaptive"}, - # Per-request timeout (same values as the SDK default): - # required for non-streaming calls with max_tokens > ~21k, - # which the SDK otherwise rejects. - timeout=httpx.Timeout(600.0, connect=5.0), + message = call_with_model_fallback( + lambda m: self.client.messages.create( + max_tokens=self.config.max_tokens, messages=prompt, model=m, system=system_message, + output_config=output_config, + thinking={"type": "adaptive"}, + # Per-request timeout (same values as the SDK default): + # required for non-streaming calls with max_tokens > ~21k, + # which the SDK otherwise rejects. + timeout=httpx.Timeout(600.0, connect=5.0), + ), + preferred=self.config.model, ) # Track usage from this attempt