From 9e0a174c39d1106acbcbc0e25879444093291fd5 Mon Sep 17 00:00:00 2001 From: Fernando Correa Neto Date: Wed, 6 May 2026 18:01:32 -0300 Subject: [PATCH] Add Vertex AI provider support for Claude models Route Claude requests through Google Cloud Vertex AI when ANTHROPIC_VERTEX_PROJECT_ID and CLOUD_ML_REGION (or ANTHROPIC_VERTEX_REGION) are set. Authentication via Google Application Default Credentials: - GCE/Cloud Run metadata server (automatic on GCP) - authorized_user credentials (gcloud auth application-default login) - service_account credentials (JSON key file) Service account JWT signing uses ring (already a transitive dep via rustls) - pure Rust, no subprocess or openssl dependency required. Vertex-specific API differences handled: - model stripped from request body (encoded in the URL instead) - anthropic_version sent in body as 'vertex-2023-10-16' (not a header) - anthropic-beta header omitted; prompt caching works natively via cache_control blocks in the request body --- Cargo.toml | 1 + src/provider/anthropic.rs | 553 ++++++++++++++++++++++++++---- src/provider/mod.rs | 30 +- src/provider/route_builders.rs | 11 + src/provider/startup.rs | 11 +- src/tui/app/inline_interactive.rs | 53 ++- 6 files changed, 575 insertions(+), 84 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c965ae69e..f1b1babe4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -140,6 +140,7 @@ jcode-import-core = { path = "crates/jcode-import-core" } # OAuth base64 = "0.22" sha2 = "0.10" +ring = "0.17" rand = "0.9.3" hex = "0.4" url = "2" diff --git a/src/provider/anthropic.rs b/src/provider/anthropic.rs index 78a877cf4..6aba8c476 100644 --- a/src/provider/anthropic.rs +++ b/src/provider/anthropic.rs @@ -370,6 +370,8 @@ const DEFAULT_MODEL: &str = "claude-opus-4-6"; /// API version header const API_VERSION: &str = "2023-06-01"; +/// API version required by Vertex AI in the request body +const VERTEX_API_VERSION: &str = "vertex-2023-10-16"; /// Claude Agent SDK identity block observed in the official Claude Code client. const CLAUDE_CODE_IDENTITY: &str = "You are a Claude agent, built on Anthropic's Claude Agent SDK."; @@ -405,12 +407,322 @@ struct CachedCredentials { expires_at: i64, } +/// Cached Google ADC access token for Vertex AI mode +#[derive(Clone)] +struct GoogleToken { + access_token: String, + expires_at: i64, +} + +/// How the request is authenticated and which endpoint to use. +#[derive(Clone)] +enum AuthMode { + /// Direct API key via x-api-key header + ApiKey, + /// Claude.ai OAuth Bearer token + OAuth, + /// Google ADC Bearer token routed through Vertex AI + Vertex { url: String }, +} + +// --- Vertex AI support --- + +const VERTEX_PROJECT_ENV: &str = "ANTHROPIC_VERTEX_PROJECT_ID"; +const VERTEX_REGION_ENV: &str = "CLOUD_ML_REGION"; +const VERTEX_REGION_ENV_ALT: &str = "ANTHROPIC_VERTEX_REGION"; + +struct VertexConfig { + project_id: String, + region: String, +} + +fn vertex_config() -> Option { + let project_id = std::env::var(VERTEX_PROJECT_ENV).ok()?; + let project_id = project_id.trim().to_string(); + if project_id.is_empty() { + return None; + } + + let region = std::env::var(VERTEX_REGION_ENV) + .or_else(|_| std::env::var(VERTEX_REGION_ENV_ALT)) + .ok()?; + let region = region.trim().to_string(); + if region.is_empty() { + return None; + } + + Some(VertexConfig { project_id, region }) +} + +fn vertex_endpoint(config: &VertexConfig, model: &str) -> String { + let VertexConfig { project_id, region } = config; + let base = if region == "global" { + "https://aiplatform.googleapis.com".to_string() + } else { + format!("https://{region}-aiplatform.googleapis.com") + }; + format!( + "{base}/v1/projects/{project_id}/locations/{region}/publishers/anthropic/models/{model}:streamRawPredict" + ) +} + +// --- Google Application Default Credentials --- + +#[derive(Deserialize)] +struct AdcFile { + #[serde(rename = "type")] + cred_type: String, + // authorized_user fields + client_id: Option, + client_secret: Option, + refresh_token: Option, + // service_account fields + client_email: Option, + private_key: Option, + token_uri: Option, +} + +#[derive(Deserialize)] +struct GoogleTokenResponse { + access_token: String, + expires_in: u64, +} + +async fn fetch_google_adc_token(client: &Client) -> Result { + // 1. GCE/Cloud Run metadata server + let meta_url = "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token"; + if let Ok(resp) = client + .get(meta_url) + .header("Metadata-Flavor", "Google") + .timeout(std::time::Duration::from_secs(1)) + .send() + .await + { + if resp.status().is_success() { + if let Ok(token) = resp.json::().await { + return Ok(token); + } + } + } + + // 2. Credential file (GOOGLE_APPLICATION_CREDENTIALS or well-known ADC path) + let cred_path = std::env::var("GOOGLE_APPLICATION_CREDENTIALS") + .ok() + .or_else(|| { + let home = std::env::var("HOME").ok()?; + let path = format!("{home}/.config/gcloud/application_default_credentials.json"); + std::path::Path::new(&path).exists().then_some(path) + }); + + let path = cred_path.ok_or_else(|| { + anyhow::anyhow!( + "No Google credentials found. Run `gcloud auth application-default login` or set GOOGLE_APPLICATION_CREDENTIALS." + ) + })?; + + let raw = tokio::fs::read_to_string(&path) + .await + .with_context(|| format!("Failed to read credentials file: {path}"))?; + let adc: AdcFile = serde_json::from_str(&raw) + .with_context(|| format!("Failed to parse credentials file: {path}"))?; + + match adc.cred_type.as_str() { + "authorized_user" => { + let client_id = adc.client_id.context("missing client_id in ADC file")?; + let client_secret = adc + .client_secret + .context("missing client_secret in ADC file")?; + let refresh_token = adc + .refresh_token + .context("missing refresh_token in ADC file")?; + let resp = client + .post("https://oauth2.googleapis.com/token") + .form(&[ + ("client_id", client_id.as_str()), + ("client_secret", client_secret.as_str()), + ("refresh_token", refresh_token.as_str()), + ("grant_type", "refresh_token"), + ]) + .send() + .await + .context("Failed to refresh Google ADC token")?; + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + anyhow::bail!("Google token refresh failed ({}): {}", status, body); + } + resp.json::() + .await + .context("Failed to parse Google token response") + } + "service_account" => { + let email = adc + .client_email + .context("missing client_email in service account")?; + let private_key = adc + .private_key + .context("missing private_key in service account")?; + let token_uri = adc + .token_uri + .unwrap_or_else(|| "https://oauth2.googleapis.com/token".to_string()); + fetch_service_account_token(client, &email, &private_key, &token_uri).await + } + other => anyhow::bail!("Unsupported ADC credential type: {other}"), + } +} + +/// Mint a short-lived access token from a service account key using a signed JWT. +async fn fetch_service_account_token( + client: &Client, + email: &str, + private_key_pem: &str, + token_uri: &str, +) -> Result { + use std::time::{SystemTime, UNIX_EPOCH}; + + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + let header = base64_url_encode(br#"{"alg":"RS256","typ":"JWT"}"#); + let claim = serde_json::json!({ + "iss": email, + "scope": "https://www.googleapis.com/auth/cloud-platform", + "aud": token_uri, + "iat": now, + "exp": now + 3600, + }); + let claim_bytes = + serde_json::to_vec(&claim).context("Failed to serialize service account JWT claims")?; + let claim_b64 = base64_url_encode(&claim_bytes); + + let signing_input = format!("{header}.{claim_b64}"); + let signature = rs256_sign(private_key_pem, signing_input.as_bytes()) + .context("Failed to sign service account JWT (ensure the private_key is valid RSA PEM)")?; + let sig_b64 = base64_url_encode(&signature); + + let jwt = format!("{signing_input}.{sig_b64}"); + + let resp = client + .post(token_uri) + .form(&[ + ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"), + ("assertion", jwt.as_str()), + ]) + .send() + .await + .context("Failed to exchange service account JWT for access token")?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + anyhow::bail!( + "Service account token exchange failed ({}): {}", + status, + body + ); + } + resp.json::() + .await + .context("Failed to parse service account token response") +} + +fn base64_url_encode(input: &[u8]) -> String { + let b64 = encode_base64_standard(input); + b64.replace('+', "-").replace('/', "_").replace('=', "") +} + +fn encode_base64_standard(input: &[u8]) -> String { + const TABLE: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + let mut out = String::with_capacity((input.len() + 2) / 3 * 4); + for chunk in input.chunks(3) { + let b0 = chunk[0] as u32; + let b1 = chunk.get(1).copied().unwrap_or(0) as u32; + let b2 = chunk.get(2).copied().unwrap_or(0) as u32; + let n = (b0 << 16) | (b1 << 8) | b2; + out.push(TABLE[((n >> 18) & 0x3f) as usize] as char); + out.push(TABLE[((n >> 12) & 0x3f) as usize] as char); + if chunk.len() > 1 { + out.push(TABLE[((n >> 6) & 0x3f) as usize] as char); + } else { + out.push('='); + } + if chunk.len() > 2 { + out.push(TABLE[(n & 0x3f) as usize] as char); + } else { + out.push('='); + } + } + out +} + +/// RS256 signing for service account JWTs using `ring` (pure Rust, no subprocess needed). +/// Accepts PKCS#8 DER keys (produced by gcloud) with a PKCS#1 fallback. +fn rs256_sign(pem: &str, message: &[u8]) -> Result> { + // Strip PEM header/footer and base64-decode to get the DER bytes. + let b64 = pem + .lines() + .filter(|l| !l.starts_with("-----")) + .collect::(); + let der = base64_decode_standard(&b64).context("Failed to decode PEM private key")?; + + // ring accepts PKCS#8 DER (the format produced by gcloud / most GCP service account keys). + // Fall back to the raw PKCS#1 DER format just in case. + let key_pair = ring::signature::RsaKeyPair::from_pkcs8(&der) + .or_else(|_| ring::signature::RsaKeyPair::from_der(&der)) + .map_err(|e| anyhow::anyhow!("Invalid RSA private key: {e:?}"))?; + + let rng = ring::rand::SystemRandom::new(); + let mut sig = vec![0u8; key_pair.public().modulus_len()]; + key_pair + .sign(&ring::signature::RSA_PKCS1_SHA256, &rng, message, &mut sig) + .map_err(|e| anyhow::anyhow!("RSA signing failed: {e:?}"))?; + Ok(sig) +} + +fn base64_decode_standard(input: &str) -> Result> { + const TABLE: [u8; 256] = { + let mut t = [255u8; 256]; + let chars = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + let mut i = 0usize; + while i < chars.len() { + t[chars[i] as usize] = i as u8; + i += 1; + } + t + }; + let input = input.trim().as_bytes(); + let mut out = Vec::with_capacity(input.len() / 4 * 3); + let mut i = 0; + while i + 3 < input.len() { + let a = TABLE[input[i] as usize]; + let b = TABLE[input[i + 1] as usize]; + let c = TABLE[input[i + 2] as usize]; + let d = TABLE[input[i + 3] as usize]; + if a == 255 || b == 255 { + break; + } + out.push((a << 2) | (b >> 4)); + if c != 255 { + out.push((b << 4) | (c >> 2)); + } + if d != 255 { + out.push((c << 6) | d); + } + i += 4; + } + Ok(out) +} + /// Direct Anthropic API provider pub struct AnthropicProvider { client: Client, model: Arc>, /// Cached OAuth credentials (None if using API key) credentials: Arc>>, + /// Cached Google ADC token for Vertex AI mode + google_token: Arc>>, max_tokens: u32, oauth_session_id: String, oauth_preflight_done: Arc, @@ -447,12 +759,36 @@ impl AnthropicProvider { client: crate::provider::shared_http_client(), model: Arc::new(std::sync::RwLock::new(model)), credentials: Arc::new(RwLock::new(None)), + google_token: Arc::new(RwLock::new(None)), max_tokens, oauth_session_id: Uuid::new_v4().to_string(), oauth_preflight_done: Arc::new(AtomicBool::new(false)), } } + /// Get a Google ADC access token, with in-memory caching. + async fn get_google_token(&self) -> Result { + { + let cached = self.google_token.read().await; + if let Some(ref t) = *cached { + let now = chrono::Utc::now().timestamp_millis(); + if t.expires_at > now + 60_000 { + return Ok(t.access_token.clone()); + } + } + } + + let token = fetch_google_adc_token(&self.client).await?; + let expires_at = chrono::Utc::now().timestamp_millis() + token.expires_in as i64 * 1000; + let access_token = token.access_token.clone(); + let mut cached = self.google_token.write().await; + *cached = Some(GoogleToken { + access_token: access_token.clone(), + expires_at, + }); + Ok(access_token) + } + /// Get the access token from credentials /// Supports both OAuth tokens and direct API keys /// Automatically refreshes OAuth tokens when expired @@ -872,16 +1208,38 @@ impl Provider for AnthropicProvider { system: &str, _resume_session_id: Option<&str>, ) -> Result { - let (token, is_oauth) = self.get_access_token().await?; - if is_oauth { - ensure_oauth_preflight( - &self.client, - &token, - &self.oauth_session_id, - &self.oauth_preflight_done, - ) - .await?; - } + // Vertex AI mode: detected via env vars, same as Claude Code SDK. + let (token, auth_mode) = if let Some(vcfg) = vertex_config() { + let google_token = self.get_google_token().await?; + let model = self + .model + .read() + .unwrap_or_else(|poisoned| poisoned.into_inner()) + .clone(); + let api_model = strip_1m_suffix(&model).to_string(); + let url = vertex_endpoint(&vcfg, &api_model); + (google_token, AuthMode::Vertex { url }) + } else { + let (token, is_oauth) = self.get_access_token().await?; + if is_oauth { + ensure_oauth_preflight( + &self.client, + &token, + &self.oauth_session_id, + &self.oauth_preflight_done, + ) + .await?; + } + let mode = if is_oauth { + AuthMode::OAuth + } else { + AuthMode::ApiKey + }; + (token, mode) + }; + + let is_oauth = matches!(auth_mode, AuthMode::OAuth); + let model = self .model .read() @@ -893,6 +1251,7 @@ impl Provider for AnthropicProvider { let api_messages = self.format_messages(messages, is_oauth); let api_tools = self.format_tools(tools, is_oauth); + let is_vertex = matches!(auth_mode, AuthMode::Vertex { .. }); let request = ApiRequest { model: api_model, max_tokens: self.max_tokens, @@ -910,11 +1269,21 @@ impl Provider for AnthropicProvider { }, temperature: if is_oauth { Some(1.0) } else { None }, stream: true, + anthropic_version: if is_vertex { + Some(VERTEX_API_VERSION) + } else { + None + }, }; + let transport_label = match &auth_mode { + AuthMode::Vertex { .. } => "vertex/sse", + AuthMode::OAuth => "oauth/sse", + AuthMode::ApiKey => "apikey/sse", + }; crate::logging::info(&format!( - "Anthropic transport: HTTPS SSE stream (oauth={})", - is_oauth + "Anthropic transport: HTTPS SSE stream ({})", + transport_label )); // Create channel for streaming events @@ -940,7 +1309,7 @@ impl Provider for AnthropicProvider { run_stream_with_retries( client, token, - is_oauth, + auth_mode, request, tx, credentials, @@ -1035,6 +1404,7 @@ impl Provider for AnthropicProvider { .clone(), )), credentials: Arc::new(RwLock::new(None)), + google_token: Arc::new(RwLock::new(None)), max_tokens: self.max_tokens, oauth_session_id: self.oauth_session_id.clone(), oauth_preflight_done: Arc::new(AtomicBool::new( @@ -1062,16 +1432,37 @@ impl Provider for AnthropicProvider { system_dynamic: &str, _resume_session_id: Option<&str>, ) -> Result { - let (token, is_oauth) = self.get_access_token().await?; - if is_oauth { - ensure_oauth_preflight( - &self.client, - &token, - &self.oauth_session_id, - &self.oauth_preflight_done, - ) - .await?; - } + let (token, auth_mode) = if let Some(vcfg) = vertex_config() { + let google_token = self.get_google_token().await?; + let model = self + .model + .read() + .unwrap_or_else(|poisoned| poisoned.into_inner()) + .clone(); + let api_model = strip_1m_suffix(&model).to_string(); + let url = vertex_endpoint(&vcfg, &api_model); + (google_token, AuthMode::Vertex { url }) + } else { + let (token, is_oauth) = self.get_access_token().await?; + if is_oauth { + ensure_oauth_preflight( + &self.client, + &token, + &self.oauth_session_id, + &self.oauth_preflight_done, + ) + .await?; + } + let mode = if is_oauth { + AuthMode::OAuth + } else { + AuthMode::ApiKey + }; + (token, mode) + }; + + let is_oauth = matches!(auth_mode, AuthMode::OAuth); + let model = self .model .read() @@ -1083,6 +1474,7 @@ impl Provider for AnthropicProvider { let api_messages = self.format_messages(messages, is_oauth); let api_tools = self.format_tools(tools, is_oauth); + let is_vertex = matches!(auth_mode, AuthMode::Vertex { .. }); let request = ApiRequest { model: api_model, max_tokens: self.max_tokens, @@ -1100,11 +1492,21 @@ impl Provider for AnthropicProvider { }, temperature: if is_oauth { Some(1.0) } else { None }, stream: true, + anthropic_version: if is_vertex { + Some(VERTEX_API_VERSION) + } else { + None + }, }; + let transport_label = match &auth_mode { + AuthMode::Vertex { .. } => "vertex/sse", + AuthMode::OAuth => "oauth/sse", + AuthMode::ApiKey => "apikey/sse", + }; crate::logging::info(&format!( - "Anthropic transport: HTTPS SSE split stream (oauth={})", - is_oauth + "Anthropic transport: HTTPS SSE split stream ({})", + transport_label )); // Create channel for streaming events @@ -1129,7 +1531,7 @@ impl Provider for AnthropicProvider { run_stream_with_retries( client, token, - is_oauth, + auth_mode, request, tx, credentials, @@ -1150,7 +1552,7 @@ impl Provider for AnthropicProvider { async fn run_stream_with_retries( client: Client, initial_token: String, - is_oauth: bool, + auth_mode: AuthMode, request: ApiRequest, tx: mpsc::Sender>, credentials: Arc>>, @@ -1160,6 +1562,7 @@ async fn run_stream_with_retries( let mut token = initial_token; let mut last_error = None; let mut attempted_forced_refresh = false; + let is_oauth = matches!(auth_mode, AuthMode::OAuth); for attempt in 0..MAX_RETRIES { if attempt > 0 { @@ -1184,7 +1587,7 @@ async fn run_stream_with_retries( match stream_response( client.clone(), token.clone(), - is_oauth, + auth_mode.clone(), request.clone(), tx.clone(), &model_name, @@ -1312,7 +1715,7 @@ async fn force_refresh_oauth_token( async fn stream_response( client: Client, token: String, - is_oauth: bool, + auth_mode: AuthMode, request: ApiRequest, tx: mpsc::Sender>, model_name: &str, @@ -1334,12 +1737,15 @@ async fn stream_response( .await; let connect_start = std::time::Instant::now(); - // Build request with appropriate auth headers - let url = if is_oauth { API_URL_OAUTH } else { API_URL }; + + let (url, is_oauth) = match &auth_mode { + AuthMode::ApiKey => (API_URL.to_string(), false), + AuthMode::OAuth => (API_URL_OAUTH.to_string(), true), + AuthMode::Vertex { url } => (url.clone(), false), + }; let mut req = client - .post(url) - .header("anthropic-version", API_VERSION) + .post(&url) .header("content-type", "application/json") .header( "accept", @@ -1350,36 +1756,58 @@ async fn stream_response( }, ); - if is_oauth { - // OAuth tokens require: - // 1. Bearer auth (NOT x-api-key) - // 2. User-Agent matching Claude CLI - // 3. Multiple beta headers - // 4. ?beta=true query param (in URL above) - req = apply_oauth_attribution_headers( - req.header("Authorization", format!("Bearer {}", token)) - .header("User-Agent", CLAUDE_CLI_USER_AGENT) - .header("anthropic-beta", oauth_beta_headers(model_name)), - oauth_session_id, - ); - } else { - // Direct API keys use x-api-key - // Include prompt-caching beta header - req = req.header("x-api-key", &token).header( - "anthropic-beta", - if is_1m_model(model_name) { - "prompt-caching-2024-07-31,context-1m-2025-08-07" - } else { - "prompt-caching-2024-07-31" - }, - ); + match &auth_mode { + AuthMode::OAuth => { + // OAuth tokens require: + // 1. Bearer auth (NOT x-api-key) + // 2. User-Agent matching Claude CLI + // 3. Multiple beta headers + // 4. ?beta=true query param (in URL above) + req = apply_oauth_attribution_headers( + req.header("anthropic-version", API_VERSION) + .header("Authorization", format!("Bearer {}", token)) + .header("User-Agent", CLAUDE_CLI_USER_AGENT) + .header("anthropic-beta", oauth_beta_headers(model_name)), + oauth_session_id, + ); + } + AuthMode::Vertex { .. } => { + // Vertex AI: Google Bearer token only — version goes in request body + req = req.header("Authorization", format!("Bearer {}", token)); + } + AuthMode::ApiKey => { + // Direct API keys use x-api-key + req = req + .header("anthropic-version", API_VERSION) + .header("x-api-key", &token) + .header( + "anthropic-beta", + if is_1m_model(model_name) { + "prompt-caching-2024-07-31,context-1m-2025-08-07" + } else { + "prompt-caching-2024-07-31" + }, + ); + } } - let response = req - .json(&request) - .send() - .await - .context("Failed to send request to Anthropic API")?; + let response = if matches!(auth_mode, AuthMode::Vertex { .. }) { + // Vertex AI: model is in the URL, strip it from the body + let mut body = + serde_json::to_value(&request).context("Failed to serialize Vertex AI request")?; + if let Some(obj) = body.as_object_mut() { + obj.remove("model"); + } + req.json(&body) + .send() + .await + .context("Failed to send request to Vertex AI")? + } else { + req.json(&request) + .send() + .await + .context("Failed to send request to Anthropic API")? + }; let connect_ms = connect_start.elapsed().as_millis(); crate::logging::info(&format!( @@ -1653,6 +2081,9 @@ struct ApiRequest { #[serde(skip_serializing_if = "Option::is_none")] temperature: Option, stream: bool, + /// Required by Vertex AI in the request body (not a header) + #[serde(skip_serializing_if = "Option::is_none")] + anthropic_version: Option<&'static str>, } #[derive(Serialize, Clone)] diff --git a/src/provider/mod.rs b/src/provider/mod.rs index da5e99c86..e968b94e0 100644 --- a/src/provider/mod.rs +++ b/src/provider/mod.rs @@ -44,10 +44,10 @@ pub use jcode_provider_core::{ }; pub(crate) use jcode_provider_core::{ProviderFailoverPrompt, parse_failover_prompt_message}; pub use route_builders::{ - build_anthropic_oauth_route, build_copilot_route, build_openai_api_key_route, - build_openai_oauth_route, build_openrouter_auto_route, build_openrouter_endpoint_route, - build_openrouter_fallback_provider_route, is_listable_model_name, - listable_model_names_from_routes, openrouter_catalog_model_id, + build_anthropic_oauth_route, build_anthropic_vertex_route, build_copilot_route, + build_openai_api_key_route, build_openai_oauth_route, build_openrouter_auto_route, + build_openrouter_endpoint_route, build_openrouter_fallback_provider_route, + is_listable_model_name, listable_model_names_from_routes, openrouter_catalog_model_id, }; pub(crate) use routing::{ anthropic_api_key_route_availability, anthropic_oauth_route_availability, @@ -794,6 +794,15 @@ impl Provider for MultiProvider { let mut openrouter_scheduled_endpoint_refreshes = 0usize; let has_oauth = self.has_claude_runtime(); let has_api_key = std::env::var("ANTHROPIC_API_KEY").is_ok(); + let has_vertex = std::env::var("ANTHROPIC_VERTEX_PROJECT_ID") + .ok() + .map(|v| !v.trim().is_empty()) + .unwrap_or(false) + && std::env::var("CLOUD_ML_REGION") + .or_else(|_| std::env::var("ANTHROPIC_VERTEX_REGION")) + .ok() + .map(|v| !v.trim().is_empty()) + .unwrap_or(false); let anthropic_models = if let Some(anthropic) = self.anthropic_provider() { anthropic.available_models_for_switching() } else if let Some(claude) = self.claude_provider() { @@ -807,7 +816,7 @@ impl Provider for MultiProvider { known_openai_model_ids() }; - // Anthropic models (oauth and/or api-key) + // Anthropic models (oauth and/or api-key and/or vertex) for model in anthropic_models { let (available, detail) = if has_oauth && !has_api_key { anthropic_oauth_route_availability(&model) @@ -815,6 +824,9 @@ impl Provider for MultiProvider { (true, String::new()) }; + if has_vertex { + routes.push(build_anthropic_vertex_route(&model)); + } if has_oauth { routes.push(build_anthropic_oauth_route( &model, @@ -833,7 +845,7 @@ impl Provider for MultiProvider { cheapness: cheapness_for_route(&model, "Anthropic", "api-key"), }); } - if !has_oauth && !has_api_key { + if !has_vertex && !has_oauth && !has_api_key { routes.push(ModelRoute { model: model.to_string(), provider: "Anthropic".to_string(), @@ -1173,7 +1185,11 @@ impl Provider for MultiProvider { Some(Arc::new(claude::ClaudeProvider::new())); } } else if self.anthropic_provider().is_none() - && crate::auth::claude::load_credentials().is_ok() + && (crate::auth::claude::load_credentials().is_ok() + || std::env::var("ANTHROPIC_VERTEX_PROJECT_ID") + .ok() + .map(|v| !v.trim().is_empty()) + .unwrap_or(false)) { crate::logging::info("Hot-initialized Anthropic provider after auth change"); *self diff --git a/src/provider/route_builders.rs b/src/provider/route_builders.rs index bce975302..edb31eba1 100644 --- a/src/provider/route_builders.rs +++ b/src/provider/route_builders.rs @@ -82,6 +82,17 @@ fn build_openai_route( } } +pub fn build_anthropic_vertex_route(model: &str) -> ModelRoute { + ModelRoute { + model: model.to_string(), + provider: "Vertex AI".to_string(), + api_method: "vertex".to_string(), + available: true, + detail: String::new(), + cheapness: cheapness_for_route(model, "Anthropic", "vertex"), + } +} + pub fn build_copilot_route(model: &str, available: bool, detail: impl Into) -> ModelRoute { ModelRoute { model: model.to_string(), diff --git a/src/provider/startup.rs b/src/provider/startup.rs index f49106ea1..7cc46c8f7 100644 --- a/src/provider/startup.rs +++ b/src/provider/startup.rs @@ -77,6 +77,15 @@ impl MultiProvider { } let has_claude_creds = auth::claude::load_credentials().is_ok(); + let has_vertex = std::env::var("ANTHROPIC_VERTEX_PROJECT_ID") + .ok() + .map(|v| !v.trim().is_empty()) + .unwrap_or(false) + && std::env::var("CLOUD_ML_REGION") + .or_else(|_| std::env::var("ANTHROPIC_VERTEX_REGION")) + .ok() + .map(|v| !v.trim().is_empty()) + .unwrap_or(false); let has_openai_creds = auth::codex::load_credentials().is_ok(); let has_copilot_api = auth_status.copilot_has_api_token; let has_antigravity_creds = auth::antigravity::load_tokens().is_ok(); @@ -102,7 +111,7 @@ impl MultiProvider { None }; - let anthropic = if has_claude_creds && !use_claude_cli { + let anthropic = if (has_claude_creds || has_vertex) && !use_claude_cli { Some(Arc::new(anthropic::AnthropicProvider::new())) } else { None diff --git a/src/tui/app/inline_interactive.rs b/src/tui/app/inline_interactive.rs index 820bcb269..94e133c41 100644 --- a/src/tui/app/inline_interactive.rs +++ b/src/tui/app/inline_interactive.rs @@ -178,12 +178,27 @@ impl App { ) } else { match crate::provider::provider_for_model(&model) { - Some("claude") => ( - "Anthropic".to_string(), - "claude-oauth".to_string(), - auth.anthropic.has_oauth || auth.anthropic.has_api_key, - String::new(), - ), + Some("claude") => { + let has_vertex = std::env::var("ANTHROPIC_VERTEX_PROJECT_ID") + .ok() + .map(|v| !v.trim().is_empty()) + .unwrap_or(false); + if has_vertex { + ( + "Vertex AI".to_string(), + "vertex".to_string(), + true, + String::new(), + ) + } else { + ( + "Anthropic".to_string(), + "claude-oauth".to_string(), + auth.anthropic.has_oauth || auth.anthropic.has_api_key, + String::new(), + ) + } + } Some("openai") => unreachable!("OpenAI models are handled above"), Some("gemini") => ( "Gemini".to_string(), @@ -909,15 +924,23 @@ impl App { let mut added_any = false; - if crate::provider::provider_for_model(model) == Some("claude") - && auth.anthropic.has_oauth - { - let (available, detail) = - crate::provider::anthropic_oauth_route_availability(model); - routes.push(crate::provider::build_anthropic_oauth_route( - model, available, detail, - )); - added_any = true; + if crate::provider::provider_for_model(model) == Some("claude") { + let has_vertex = std::env::var("ANTHROPIC_VERTEX_PROJECT_ID") + .ok() + .map(|v| !v.trim().is_empty()) + .unwrap_or(false); + if has_vertex { + routes.push(crate::provider::build_anthropic_vertex_route(model)); + added_any = true; + } + if auth.anthropic.has_oauth { + let (available, detail) = + crate::provider::anthropic_oauth_route_availability(model); + routes.push(crate::provider::build_anthropic_oauth_route( + model, available, detail, + )); + added_any = true; + } } if crate::provider::ALL_OPENAI_MODELS.contains(&model.as_str()) {