diff --git a/codex-rs/app-server-protocol/src/protocol/v2/thread.rs b/codex-rs/app-server-protocol/src/protocol/v2/thread.rs index 0be7776279b0..e712ab0ad3e1 100644 --- a/codex-rs/app-server-protocol/src/protocol/v2/thread.rs +++ b/codex-rs/app-server-protocol/src/protocol/v2/thread.rs @@ -57,6 +57,11 @@ pub struct ThreadStartParams { pub model: Option, #[ts(optional = nullable)] pub model_provider: Option, + /// Allow a provider with an authoritative static model catalog to replace an unavailable + /// requested model with its default. + #[experimental("thread/start.allowProviderModelFallback")] + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + pub allow_provider_model_fallback: bool, #[serde( default, deserialize_with = "crate::protocol::serde_helpers::deserialize_double_option", diff --git a/codex-rs/app-server/README.md b/codex-rs/app-server/README.md index 4c88d4b69d32..1a3684a024dc 100644 --- a/codex-rs/app-server/README.md +++ b/codex-rs/app-server/README.md @@ -137,7 +137,7 @@ Example with notification opt-out: ## API Overview -- `thread/start` — create a new thread; emits `thread/started` (including the current `thread.status`) and auto-subscribes you to turn/item events for that thread. When the request includes a `cwd` and the resolved sandbox is `workspace-write` or full access, app-server also marks that project as trusted in the user `config.toml`. Pass `sessionStartSource: "clear"` when starting a replacement thread after clearing the current session so `SessionStart` hooks receive `source: "clear"` instead of the default `"startup"`. Experimental `runtimeWorkspaceRoots` replaces the thread-scoped runtime workspace roots used to materialize `:workspace_roots`; paths must be absolute. For permissions, prefer experimental `permissions` profile selection by id; the legacy `sandbox` shorthand is still accepted but cannot be combined with `permissions`. 
Deprecated experimental `multiAgentMode` is ignored; use Ultra reasoning effort for proactive multi-agent behavior. Experimental `environments` selects the sticky execution environments for turns on the thread; omit it to use the server default, pass `[]` to disable environments, or pass explicit environment ids with per-environment `cwd`. Experimental `selectedCapabilityRoots` selects environment-owned plugin or standalone-skill roots using environment-native absolute paths. Skills found below those roots are listed and read through the owning environment. Stdio MCP servers declared by selected plugins are started in that environment, and HTTP MCP connections use that environment's HTTP client. +- `thread/start` — create a new thread; emits `thread/started` (including the current `thread.status`) and auto-subscribes you to turn/item events for that thread. When the request includes a `cwd` and the resolved sandbox is `workspace-write` or full access, app-server also marks that project as trusted in the user `config.toml`. Pass `sessionStartSource: "clear"` when starting a replacement thread after clearing the current session so `SessionStart` hooks receive `source: "clear"` instead of the default `"startup"`. Experimental `allowProviderModelFallback` lets providers backed by an authoritative static model catalog replace an unavailable requested `model` with the catalog default; dynamic or cached catalogs preserve the requested model. Experimental `runtimeWorkspaceRoots` replaces the thread-scoped runtime workspace roots used to materialize `:workspace_roots`; paths must be absolute. For permissions, prefer experimental `permissions` profile selection by id; the legacy `sandbox` shorthand is still accepted but cannot be combined with `permissions`. Deprecated experimental `multiAgentMode` is ignored; use Ultra reasoning effort for proactive multi-agent behavior. 
Experimental `environments` selects the sticky execution environments for turns on the thread; omit it to use the server default, pass `[]` to disable environments, or pass explicit environment ids with per-environment `cwd`. Experimental `selectedCapabilityRoots` selects environment-owned plugin or standalone-skill roots using environment-native absolute paths. Skills found below those roots are listed and read through the owning environment. Stdio MCP servers declared by selected plugins are started in that environment, and HTTP MCP connections use that environment's HTTP client. - `thread/resume` — reopen an existing thread by id so subsequent `turn/start` calls append to it. Accepts the same permission override rules as `thread/start`. - `thread/fork` — fork an existing thread into a new thread id by copying the stored history; if the source thread is currently mid-turn, the fork records the same interruption marker as `turn/interrupt` instead of inheriting an unmarked partial turn suffix. The returned `thread.forkedFromId` points at the source thread when known. Accepts `ephemeral: true` for an in-memory temporary fork, emits `thread/started` (including the current `thread.status`), and auto-subscribes you to turn/item events for the new thread. Experimental clients can pass `excludeTurns: true` when they plan to page fork history via `thread/turns/list` instead of receiving the full turn array immediately. Accepts the same permission override rules as `thread/start`. - `thread/start`, `thread/resume`, and `thread/fork` responses include the legacy `sandbox` compatibility projection. `instructionSources` lists loaded instruction files using each source environment's native absolute path syntax, including files loaded from remote environments. Experimental clients can read `runtimeWorkspaceRoots` for the thread-scoped runtime roots and `activePermissionProfile` for the named or implicit built-in profile identity/provenance when known. 
Their deprecated experimental `multiAgentMode` field, and the corresponding thread setting, always report `explicitRequestOnly`; Ultra reasoning effort is the source of proactive multi-agent behavior. diff --git a/codex-rs/app-server/src/request_processors/external_agent_session_import.rs b/codex-rs/app-server/src/request_processors/external_agent_session_import.rs index bc0a64ef7298..15b6e3a43ac8 100644 --- a/codex-rs/app-server/src/request_processors/external_agent_session_import.rs +++ b/codex-rs/app-server/src/request_processors/external_agent_session_import.rs @@ -184,7 +184,11 @@ impl ExternalAgentSessionImporter { .map_err(|err| format!("failed to load imported session config: {err}"))?; let models_manager = self.thread_manager.get_models_manager(); let model = models_manager - .get_default_model(&config.model, RefreshStrategy::Offline) + .get_default_model( + &config.model, + /*allow_provider_model_fallback*/ false, + RefreshStrategy::Offline, + ) .await; let model_info = models_manager .get_model_info(model.as_str(), &config.to_models_manager_config()) diff --git a/codex-rs/app-server/src/request_processors/thread_processor.rs b/codex-rs/app-server/src/request_processors/thread_processor.rs index ab4e24186795..7c6b0aad741f 100644 --- a/codex-rs/app-server/src/request_processors/thread_processor.rs +++ b/codex-rs/app-server/src/request_processors/thread_processor.rs @@ -903,6 +903,7 @@ impl ThreadRequestProcessor { let ThreadStartParams { model, model_provider, + allow_provider_model_fallback, service_tier, cwd, runtime_workspace_roots, @@ -979,6 +980,7 @@ impl ThreadRequestProcessor { thread_source.map(Into::into), environment_selections, service_name, + allow_provider_model_fallback, experimental_raw_events, request_trace, ) @@ -1053,6 +1055,7 @@ impl ThreadRequestProcessor { thread_source: Option, environments: Option>, service_name: Option, + allow_provider_model_fallback: bool, experimental_raw_events: bool, request_trace: Option, ) -> Result<(), 
JSONRPCErrorError> { @@ -1160,6 +1163,7 @@ impl ThreadRequestProcessor { .thread_manager .start_thread_with_options(StartThreadOptions { config, + allow_provider_model_fallback, initial_history: match session_start_source .unwrap_or(codex_app_server_protocol::ThreadStartSource::Startup) { diff --git a/codex-rs/app-server/tests/suite/v2/skills_list.rs b/codex-rs/app-server/tests/suite/v2/skills_list.rs index 075cc007ac8b..64b7fe05048e 100644 --- a/codex-rs/app-server/tests/suite/v2/skills_list.rs +++ b/codex-rs/app-server/tests/suite/v2/skills_list.rs @@ -882,6 +882,7 @@ async fn skills_changed_notification_is_emitted_after_skill_change() -> Result<( .send_thread_start_request(ThreadStartParams { model: None, model_provider: None, + allow_provider_model_fallback: false, service_tier: None, cwd: None, runtime_workspace_roots: None, diff --git a/codex-rs/app-server/tests/suite/v2/thread_start.rs b/codex-rs/app-server/tests/suite/v2/thread_start.rs index 39f8935a93bf..93dcb47ede6b 100644 --- a/codex-rs/app-server/tests/suite/v2/thread_start.rs +++ b/codex-rs/app-server/tests/suite/v2/thread_start.rs @@ -54,6 +54,115 @@ use super::analytics::wait_for_analytics_payload; const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); const INVALID_REQUEST_ERROR_CODE: i64 = -32600; +async fn start_thread_with_model( + mcp: &mut TestAppServer, + model: &str, + allow_provider_model_fallback: bool, +) -> Result { + let request_id = mcp + .send_thread_start_request_with_auto_env(ThreadStartParams { + model: Some(model.to_string()), + allow_provider_model_fallback, + ..Default::default() + }) + .await?; + let response: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(request_id)), + ) + .await??; + to_response(response) +} + +#[tokio::test] +async fn thread_start_provider_model_fallback_applies_to_configured_model() -> Result<()> { + let codex_home = TempDir::new()?; + std::fs::write( + 
codex_home.path().join("config.toml"), + r#"model_provider = "amazon-bedrock" +model = "gpt-5.4-mini" +"#, + )?; + let mut mcp = TestAppServer::new_with_auto_env(codex_home.path()).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??; + + let request_id = mcp + .send_thread_start_request_with_auto_env(ThreadStartParams { + allow_provider_model_fallback: true, + ..Default::default() + }) + .await?; + let response: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(request_id)), + ) + .await??; + let response: ThreadStartResponse = to_response(response)?; + + assert_eq!(response.model, "openai.gpt-5.5"); + Ok(()) +} + +#[tokio::test] +async fn thread_start_provider_model_fallback_uses_bedrock_static_catalog() -> Result<()> { + let codex_home = TempDir::new()?; + std::fs::write( + codex_home.path().join("config.toml"), + r#"model_provider = "amazon-bedrock" +"#, + )?; + let mut mcp = TestAppServer::new_with_auto_env(codex_home.path()).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??; + + let unsupported_with_fallback = start_thread_with_model( + &mut mcp, + "gpt-5.4-mini", + /*allow_provider_model_fallback*/ true, + ) + .await?; + let supported_with_fallback = start_thread_with_model( + &mut mcp, + "openai.gpt-5.4", + /*allow_provider_model_fallback*/ true, + ) + .await?; + let unsupported_without_fallback = start_thread_with_model( + &mut mcp, + "gpt-5.4-mini", + /*allow_provider_model_fallback*/ false, + ) + .await?; + + assert_eq!( + vec![ + unsupported_with_fallback.model, + supported_with_fallback.model, + unsupported_without_fallback.model, + ], + vec!["openai.gpt-5.5", "openai.gpt-5.4", "gpt-5.4-mini"] + ); + Ok(()) +} + +#[tokio::test] +async fn thread_start_provider_model_fallback_ignores_dynamic_catalog() -> Result<()> { + let server = create_mock_responses_server_repeating_assistant("Done").await; + let codex_home = TempDir::new()?; + 
create_config_toml_without_approval_policy(codex_home.path(), &server.uri())?; + let mut mcp = TestAppServer::new_with_auto_env(codex_home.path()).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??; + + let response = start_thread_with_model( + &mut mcp, + "unlisted-dynamic-model", + /*allow_provider_model_fallback*/ true, + ) + .await?; + + assert_eq!(response.model, "unlisted-dynamic-model"); + Ok(()) +} + #[tokio::test] async fn thread_start_creates_thread_and_emits_started() -> Result<()> { // Provide a mock server and config so model wiring is valid. diff --git a/codex-rs/core/src/agent/control_tests.rs b/codex-rs/core/src/agent/control_tests.rs index de8754ad920f..acb7af377991 100644 --- a/codex-rs/core/src/agent/control_tests.rs +++ b/codex-rs/core/src/agent/control_tests.rs @@ -1462,6 +1462,7 @@ async fn spawn_agent_fork_last_n_turns_drops_parent_startup_prefix_when_under_li .manager .start_thread_with_options(StartThreadOptions { config: harness.config.clone(), + allow_provider_model_fallback: false, initial_history: InitialHistory::New, session_source: None, thread_source: None, @@ -2250,6 +2251,7 @@ async fn spawn_thread_subagents_persist_parent_originator_across_new_and_truncat .manager .start_thread_with_options(StartThreadOptions { config: harness.config.clone(), + allow_provider_model_fallback: false, initial_history: InitialHistory::New, session_source: None, thread_source: None, diff --git a/codex-rs/core/src/codex_delegate.rs b/codex-rs/core/src/codex_delegate.rs index 6fbd195fce93..21e8de83ca50 100644 --- a/codex-rs/core/src/codex_delegate.rs +++ b/codex-rs/core/src/codex_delegate.rs @@ -86,6 +86,7 @@ pub(crate) async fn run_codex_thread_interactive( }; let CodexSpawnOk { codex, .. 
} = Box::pin(Codex::spawn(CodexSpawnArgs { config, + allow_provider_model_fallback: false, user_instructions, installation_id: parent_session.installation_id.clone(), auth_manager, diff --git a/codex-rs/core/src/session/mod.rs b/codex-rs/core/src/session/mod.rs index 94bef646aa31..442722003bee 100644 --- a/codex-rs/core/src/session/mod.rs +++ b/codex-rs/core/src/session/mod.rs @@ -414,6 +414,7 @@ pub struct CodexSpawnOk { pub(crate) struct CodexSpawnArgs { pub(crate) config: Config, + pub(crate) allow_provider_model_fallback: bool, pub(crate) user_instructions: LoadedUserInstructions, pub(crate) installation_id: String, pub(crate) auth_manager: Arc, @@ -503,6 +504,7 @@ impl Codex { async fn spawn_internal(args: CodexSpawnArgs) -> CodexResult { let CodexSpawnArgs { mut config, + allow_provider_model_fallback, user_instructions, installation_id, auth_manager, @@ -576,8 +578,23 @@ impl Codex { let _ = models_manager.list_models(refresh_strategy).await; } let model = models_manager - .get_default_model(&config.model, refresh_strategy) + .get_default_model( + &config.model, + allow_provider_model_fallback, + refresh_strategy, + ) .await; + if allow_provider_model_fallback + && let Some(requested_model) = config.model.as_ref() + && model != *requested_model + { + info!( + model_provider = %config.model_provider_id, + requested_model, + fallback_model = %model, + "replaced unavailable requested model with provider default" + ); + } // Resolve base instructions for the session. Priority order: // 1. config.base_instructions override diff --git a/codex-rs/core/src/session/tests/guardian_tests.rs b/codex-rs/core/src/session/tests/guardian_tests.rs index 64c4ba29f2bd..4f88c9ed6679 100644 --- a/codex-rs/core/src/session/tests/guardian_tests.rs +++ b/codex-rs/core/src/session/tests/guardian_tests.rs @@ -718,6 +718,7 @@ async fn guardian_subagent_does_not_inherit_parent_exec_policy_rules() { let CodexSpawnOk { codex, .. 
} = Codex::spawn(CodexSpawnArgs { config, + allow_provider_model_fallback: false, user_instructions: Default::default(), installation_id: "11111111-1111-4111-8111-111111111111".to_string(), auth_manager, diff --git a/codex-rs/core/src/thread_manager.rs b/codex-rs/core/src/thread_manager.rs index b2add06024fc..358eae352773 100644 --- a/codex-rs/core/src/thread_manager.rs +++ b/codex-rs/core/src/thread_manager.rs @@ -182,6 +182,7 @@ pub struct ThreadManager { pub struct StartThreadOptions { pub config: Config, + pub allow_provider_model_fallback: bool, pub initial_history: InitialHistory, pub session_source: Option, pub thread_source: Option, @@ -632,6 +633,7 @@ impl ThreadManager { ); Box::pin(self.start_thread_with_options(StartThreadOptions { config, + allow_provider_model_fallback: false, initial_history: InitialHistory::New, session_source: None, thread_source: None, @@ -668,6 +670,7 @@ impl ThreadManager { Box::pin(self.state.spawn_thread_with_source( options.config, options.initial_history, + options.allow_provider_model_fallback, Arc::clone(&self.state.auth_manager), agent_control, session_source, @@ -763,6 +766,7 @@ impl ThreadManager { Box::pin(self.state.spawn_thread_with_source( config, initial_history, + /*allow_provider_model_fallback*/ false, auth_manager, agent_control, session_source, @@ -832,6 +836,7 @@ impl ThreadManager { Box::pin(self.state.spawn_thread_with_source( config, initial_history, + /*allow_provider_model_fallback*/ false, auth_manager, agent_control, session_source, @@ -1332,6 +1337,7 @@ impl ThreadManagerState { Box::pin(self.spawn_thread_with_source( config, InitialHistory::New, + /*allow_provider_model_fallback*/ false, Arc::clone(&self.auth_manager), agent_control, session_source, @@ -1370,6 +1376,7 @@ impl ThreadManagerState { Box::pin(self.spawn_thread_with_source( config, initial_history, + /*allow_provider_model_fallback*/ false, Arc::clone(&self.auth_manager), agent_control, session_source, @@ -1410,6 +1417,7 @@ impl 
ThreadManagerState { Box::pin(self.spawn_thread_with_source( config, initial_history, + /*allow_provider_model_fallback*/ false, Arc::clone(&self.auth_manager), agent_control, session_source, @@ -1451,6 +1459,7 @@ impl ThreadManagerState { Box::pin(self.spawn_thread_with_source( config, initial_history, + /*allow_provider_model_fallback*/ false, auth_manager, agent_control, self.session_source.clone(), @@ -1475,6 +1484,7 @@ impl ThreadManagerState { &self, config: Config, initial_history: InitialHistory, + allow_provider_model_fallback: bool, auth_manager: Arc, agent_control: AgentControl, session_source: SessionSource, @@ -1541,6 +1551,7 @@ impl ThreadManagerState { codex, thread_id, .. } = Box::pin(Codex::spawn(CodexSpawnArgs { config, + allow_provider_model_fallback, user_instructions, installation_id: self.installation_id.clone(), auth_manager, diff --git a/codex-rs/core/src/thread_manager_tests.rs b/codex-rs/core/src/thread_manager_tests.rs index 0d407e19c5d7..34df66cb6767 100644 --- a/codex-rs/core/src/thread_manager_tests.rs +++ b/codex-rs/core/src/thread_manager_tests.rs @@ -400,6 +400,7 @@ async fn start_thread_keeps_internal_threads_hidden_from_normal_lookups() { let thread = manager .start_thread_with_options(StartThreadOptions { config, + allow_provider_model_fallback: false, initial_history: InitialHistory::New, session_source: Some(SessionSource::Internal( InternalSessionSource::MemoryConsolidation, @@ -542,6 +543,7 @@ async fn start_thread_seeds_extension_data_for_mcp_and_lifecycle_contributors() let first_thread = manager .start_thread_with_options(StartThreadOptions { config: config.clone(), + allow_provider_model_fallback: false, initial_history: InitialHistory::New, session_source: None, thread_source: None, @@ -557,6 +559,7 @@ async fn start_thread_seeds_extension_data_for_mcp_and_lifecycle_contributors() let second_thread = manager .start_thread_with_options(StartThreadOptions { config: config.clone(), + allow_provider_model_fallback: false, 
initial_history: InitialHistory::New, session_source: None, thread_source: None, @@ -636,6 +639,7 @@ async fn selected_capability_roots_round_trip_through_fork() { let inherited = manager .start_thread_with_options(StartThreadOptions { config, + allow_provider_model_fallback: false, initial_history: InitialHistory::Forked(vec![RolloutItem::SessionMeta( SessionMetaLine { meta: SessionMeta { @@ -714,6 +718,7 @@ async fn resume_and_fork_do_not_restore_thread_environments_from_rollout() { let source = manager .start_thread_with_options(StartThreadOptions { config: source_config, + allow_provider_model_fallback: false, initial_history: InitialHistory::New, session_source: None, thread_source: None, @@ -997,6 +1002,7 @@ async fn resume_stopped_thread_from_rollout_preserves_thread_source() { let source = manager .start_thread_with_options(StartThreadOptions { config: config.clone(), + allow_provider_model_fallback: false, initial_history: InitialHistory::New, session_source: None, thread_source: Some(ThreadSource::User), diff --git a/codex-rs/core/tests/common/test_codex.rs b/codex-rs/core/tests/common/test_codex.rs index eb4fa3b39f39..6b1126777490 100644 --- a/codex-rs/core/tests/common/test_codex.rs +++ b/codex-rs/core/tests/common/test_codex.rs @@ -658,6 +658,7 @@ impl TestCodexBuilder { Box::pin( thread_manager.start_thread_with_options(StartThreadOptions { config: config.clone(), + allow_provider_model_fallback: false, initial_history: InitialHistory::New, session_source: None, thread_source: None, diff --git a/codex-rs/core/tests/suite/agents_md.rs b/codex-rs/core/tests/suite/agents_md.rs index 605e0ada7dfb..4e0ca794e310 100644 --- a/codex-rs/core/tests/suite/agents_md.rs +++ b/codex-rs/core/tests/suite/agents_md.rs @@ -486,6 +486,7 @@ async fn loads_user_instructions_without_a_primary_environment() -> Result<()> { .thread_manager .start_thread_with_options(StartThreadOptions { config: test.config.clone(), + allow_provider_model_fallback: false, initial_history: 
InitialHistory::New, session_source: None, thread_source: None, @@ -691,6 +692,7 @@ async fn multi_environment_thread_loads_every_project_and_keeps_creation_snapsho .thread_manager .start_thread_with_options(StartThreadOptions { config: test.config.clone(), + allow_provider_model_fallback: false, initial_history: InitialHistory::New, session_source: None, thread_source: None, diff --git a/codex-rs/core/tests/suite/remote_models.rs b/codex-rs/core/tests/suite/remote_models.rs index b63aff5400c8..63d891bada8a 100644 --- a/codex-rs/core/tests/suite/remote_models.rs +++ b/codex-rs/core/tests/suite/remote_models.rs @@ -1059,7 +1059,11 @@ async fn remote_models_request_times_out_after_5s() -> Result<()> { let start = Instant::now(); let model = timeout( Duration::from_secs(7), - manager.get_default_model(&None, RefreshStrategy::OnlineIfUncached), + manager.get_default_model( + &None, + /*allow_provider_model_fallback*/ false, + RefreshStrategy::OnlineIfUncached, + ), ) .await; let elapsed = start.elapsed(); @@ -1127,7 +1131,11 @@ async fn remote_models_hide_picker_only_models() -> Result<()> { ); let selected = manager - .get_default_model(&None, RefreshStrategy::OnlineIfUncached) + .get_default_model( + &None, + /*allow_provider_model_fallback*/ false, + RefreshStrategy::OnlineIfUncached, + ) .await; assert_eq!(selected, bundled_default_model_slug()); diff --git a/codex-rs/core/tests/suite/subagent_notifications.rs b/codex-rs/core/tests/suite/subagent_notifications.rs index aad8fc494bb9..62a367a229d6 100644 --- a/codex-rs/core/tests/suite/subagent_notifications.rs +++ b/codex-rs/core/tests/suite/subagent_notifications.rs @@ -761,6 +761,7 @@ async fn subagent_stop_replaces_stop_and_skips_internal_subagents() -> Result<() .thread_manager .start_thread_with_options(StartThreadOptions { config: test.config.clone(), + allow_provider_model_fallback: false, initial_history: InitialHistory::New, session_source: Some(SessionSource::SubAgent(SubAgentSource::Review)), 
thread_source: None, diff --git a/codex-rs/memories/write/src/runtime.rs b/codex-rs/memories/write/src/runtime.rs index 0ee0313b5845..73744e0a89d1 100644 --- a/codex-rs/memories/write/src/runtime.rs +++ b/codex-rs/memories/write/src/runtime.rs @@ -329,6 +329,7 @@ impl MemoryStartupContext { .thread_manager .start_thread_with_options(StartThreadOptions { config, + allow_provider_model_fallback: false, initial_history: InitialHistory::New, session_source: Some(SessionSource::Internal( InternalSessionSource::MemoryConsolidation, diff --git a/codex-rs/models-manager/src/manager.rs b/codex-rs/models-manager/src/manager.rs index 7a68c254fee7..67b939051b6a 100644 --- a/codex-rs/models-manager/src/manager.rs +++ b/codex-rs/models-manager/src/manager.rs @@ -144,11 +144,13 @@ pub trait ModelsManager: fmt::Debug + Send + Sync { // todo(aibrahim): should be visible to core only and sent on session_configured event /// Get the model identifier to use, refreshing according to the specified strategy. /// - /// If `model` is provided, returns it directly. Otherwise selects the default based on - /// auth mode and available models. + /// Returns a provided `model` as-is, unless `allow_provider_model_fallback` is set and the + /// implementation supports fallback and the model is missing from its catalog. Without a + /// usable `model`, selects the default based on auth mode and available models.
fn get_default_model<'a>( &'a self, model: &'a Option, + allow_provider_model_fallback: bool, refresh_strategy: RefreshStrategy, ) -> ModelsManagerFuture<'a, String> { Box::pin( @@ -161,6 +163,7 @@ pub trait ModelsManager: fmt::Debug + Send + Sync { .instrument(tracing::info_span!( "get_default_model", model.provided = model.is_some(), + allow_provider_model_fallback, refresh_strategy = %refresh_strategy )), ) @@ -408,6 +411,39 @@ impl OpenAiModelsManager { } impl ModelsManager for StaticModelsManager { + fn get_default_model<'a>( + &'a self, + model: &'a Option, + allow_provider_model_fallback: bool, + refresh_strategy: RefreshStrategy, + ) -> ModelsManagerFuture<'a, String> { + Box::pin( + async move { + let available_models = self.list_models(refresh_strategy).await; + let requested_model = model.as_deref(); + + if allow_provider_model_fallback { + if requested_model_is_available(requested_model, &available_models) + && let Some(requested_model) = requested_model + { + return requested_model.to_string(); + } + return default_model_from_available(available_models); + } + + model + .clone() + .unwrap_or_else(|| default_model_from_available(available_models)) + } + .instrument(tracing::info_span!( + "get_default_model", + model.provided = model.is_some(), + allow_provider_model_fallback, + refresh_strategy = %refresh_strategy + )), + ) + } + fn raw_model_catalog( &self, _refresh_strategy: RefreshStrategy, @@ -453,6 +489,17 @@ fn default_model_from_available(available: Vec) -> String { .unwrap_or_default() } +fn requested_model_is_available( + requested_model: Option<&str>, + available_models: &[ModelPreset], +) -> bool { + requested_model.is_some_and(|requested_model| { + available_models + .iter() + .any(|available_model| available_model.model == requested_model) + }) +} + fn find_model_by_longest_prefix(model: &str, candidates: &[ModelInfo]) -> Option { let mut best: Option = None; for candidate in candidates { diff --git 
a/codex-rs/models-manager/src/manager_tests.rs b/codex-rs/models-manager/src/manager_tests.rs index bd692baad116..e7d1691a3a20 100644 --- a/codex-rs/models-manager/src/manager_tests.rs +++ b/codex-rs/models-manager/src/manager_tests.rs @@ -239,6 +239,105 @@ c2ln", .expect("auth should be present") } +#[tokio::test] +async fn static_manager_preserves_supported_requested_model_when_fallback_is_allowed() { + let manager = static_manager_for_tests(ModelsResponse { + models: vec![ + remote_model("provider-default", "Default", /*priority*/ 0), + remote_model("provider-supported", "Supported", /*priority*/ 1), + ], + }); + let requested_model = Some("provider-supported".to_string()); + + let model = manager + .get_default_model( + &requested_model, + /*allow_provider_model_fallback*/ true, + RefreshStrategy::Offline, + ) + .await; + + assert_eq!(model, "provider-supported"); +} + +#[tokio::test] +async fn static_manager_falls_back_from_unsupported_requested_model_when_allowed() { + let manager = static_manager_for_tests(ModelsResponse { + models: vec![ + remote_model("provider-default", "Default", /*priority*/ 0), + remote_model("provider-supported", "Supported", /*priority*/ 1), + ], + }); + let requested_model = Some("unsupported".to_string()); + + let model = manager + .get_default_model( + &requested_model, + /*allow_provider_model_fallback*/ true, + RefreshStrategy::Offline, + ) + .await; + + assert_eq!(model, "provider-default"); +} + +#[tokio::test] +async fn static_manager_preserves_unsupported_requested_model_when_fallback_is_disabled() { + let manager = static_manager_for_tests(ModelsResponse { + models: vec![remote_model( + "provider-default", + "Default", + /*priority*/ 0, + )], + }); + let requested_model = Some("unsupported".to_string()); + + let model = manager + .get_default_model( + &requested_model, + /*allow_provider_model_fallback*/ false, + RefreshStrategy::Offline, + ) + .await; + + assert_eq!(model, "unsupported"); +} + +#[tokio::test] +async fn 
static_manager_uses_empty_default_when_fallback_is_allowed_and_catalog_is_empty() { + let manager = static_manager_for_tests(ModelsResponse { models: Vec::new() }); + let requested_model = Some("unsupported".to_string()); + + let model = manager + .get_default_model( + &requested_model, + /*allow_provider_model_fallback*/ true, + RefreshStrategy::Offline, + ) + .await; + + assert_eq!(model, ""); +} + +#[tokio::test] +async fn dynamic_manager_preserves_requested_model_when_fallback_is_allowed() { + let codex_home = tempdir().expect("temp dir"); + let endpoint = TestModelsEndpoint::new(Vec::new()); + let manager = openai_manager_for_tests(codex_home.path().to_path_buf(), endpoint.clone()); + let requested_model = Some("unsupported".to_string()); + + let model = manager + .get_default_model( + &requested_model, + /*allow_provider_model_fallback*/ true, + RefreshStrategy::Online, + ) + .await; + + assert_eq!(model, "unsupported"); + assert_eq!(endpoint.fetch_count(), 0); +} + #[tokio::test] async fn get_model_info_tracks_fallback_usage() { let codex_home = tempdir().expect("temp dir");