diff --git a/crates/goose/src/providers/catalog.rs b/crates/goose/src/providers/catalog.rs index 2cfe698de5d5..9dda84e58444 100644 --- a/crates/goose/src/providers/catalog.rs +++ b/crates/goose/src/providers/catalog.rs @@ -593,6 +593,23 @@ const SETUP_METADATA: &[CuratedSetupMetadata] = &[ secret_field_default: Some(API_KEY_FIELD), field_overrides: &[], }, + CuratedSetupMetadata { + provider_id: "xai_oauth", + category: ProviderSetupCategory::Model, + setup_method: ProviderSetupMethod::OauthBrowser, + group: ProviderSetupGroup::Default, + display_name: Some("xAI (SuperGrok)"), + description: Some("Grok via SuperGrok subscription"), + docs_url: Some("https://x.ai/grok"), + aliases: &[], + native_connect_query: Some("xAI Grok"), + binary_name: None, + setup_capabilities: setup_capabilities(false, true, false), + show_only_when_installed: false, + synthetic: false, + secret_field_default: None, + field_overrides: &[], + }, CuratedSetupMetadata { provider_id: "groq", category: ProviderSetupCategory::Model, diff --git a/crates/goose/src/providers/init.rs b/crates/goose/src/providers/init.rs index 3340f438478f..0ae29d2265ad 100644 --- a/crates/goose/src/providers/init.rs +++ b/crates/goose/src/providers/init.rs @@ -38,6 +38,7 @@ use super::{ snowflake::SnowflakeProvider, tetrate::TetrateProvider, xai::XaiProvider, + xai_oauth::XaiOAuthProvider, }; use crate::config::ExtensionConfig; use crate::model::ModelConfig; @@ -87,6 +88,7 @@ async fn init_registry() -> RwLock { registry.register::(false); registry.register::(true); registry.register::(false); + registry.register::(true); }); // Register cleanup functions for providers with cached state registry.set_cleanup( @@ -109,6 +111,10 @@ async fn init_registry() -> RwLock { "chatgpt_codex", Arc::new(|| Box::pin(ChatGptCodexProvider::cleanup())), ); + registry.set_cleanup( + "xai_oauth", + Arc::new(|| Box::pin(XaiOAuthProvider::cleanup())), + ); if let Err(e) = load_custom_providers_into_registry(&mut registry) { tracing::warn!("Failed to load custom providers: {}", e); diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs index f5e4f952001d..b0fec1c9b4d5 100644 --- a/crates/goose/src/providers/mod.rs +++ b/crates/goose/src/providers/mod.rs @@ -58,6 +58,7 @@ pub mod usage_estimator; pub mod utils; pub mod xai; +pub mod xai_oauth; pub use init::{ cleanup_provider, create, create_with_default_model, create_with_named_model, diff --git a/crates/goose/src/providers/xai_oauth.rs b/crates/goose/src/providers/xai_oauth.rs new file mode 100644 index 000000000000..4d2f2ad8b011 --- /dev/null +++ b/crates/goose/src/providers/xai_oauth.rs @@ -0,0 +1,879 @@ +use super::api_client::{ApiClient, AuthMethod, AuthProvider}; +use super::base::{ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata}; +use super::errors::ProviderError; +use super::openai_compatible::OpenAiCompatibleProvider; +use super::xai::{XAI_API_HOST, XAI_DEFAULT_MODEL, XAI_KNOWN_MODELS}; +use crate::config::paths::Paths; +use crate::conversation::message::Message; +use crate::model::ModelConfig; +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use axum::{extract::Query, response::Html, routing::get, Router}; +use base64::Engine; +use chrono::{DateTime, Utc}; +use futures::future::BoxFuture; +use rmcp::model::Tool; +use serde::{Deserialize, Serialize}; +use sha2::Digest; +use std::io; +use std::net::SocketAddr; +use std::path::PathBuf; +use std::sync::{Arc, LazyLock}; +use tokio::sync::{oneshot, Mutex as TokioMutex}; + +// Public Grok-CLI OAuth client. xAI's auth server rejects loopback OAuth from +// non-allowlisted clients, so we reuse the Grok-CLI client_id that xAI ships +// for desktop OAuth flows. +const CLIENT_ID: &str = "b1a00492-073a-47ea-816f-4c329264a828"; + +const AUTHORIZE_URL: &str = "https://auth.x.ai/oauth2/authorize"; +const TOKEN_URL: &str = "https://auth.x.ai/oauth2/token"; +// RFC 8628 device authorization endpoint, surfaced via xAI's +// /.well-known/openid-configuration as `device_authorization_endpoint`. +const DEVICE_AUTHORIZATION_URL: &str = "https://auth.x.ai/oauth2/device/code"; +const DEVICE_CODE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:device_code"; + +const OAUTH_SCOPES: &[&str] = &[ + "openid", + "profile", + "email", + "offline_access", + "grok-cli:access", + "api:access", +]; + +// xAI rejects redirect_uris that don't match what was registered for the +// Grok-CLI client. The host:port pair is part of the registration, so we have +// to bind the loopback server to this exact port. +const OAUTH_HOST: [u8; 4] = [127, 0, 0, 1]; +const OAUTH_PORT: u16 = 56121; +const OAUTH_REDIRECT_PATH: &str = "/callback"; + +const OAUTH_TIMEOUT_SECS: u64 = 300; +const HTML_AUTO_CLOSE_TIMEOUT_MS: u64 = 2000; + +// Refresh skew: refresh tokens this long before stored expiry so a single +// long-running tool call doesn't have to recover from a mid-flight 401. +const ACCESS_TOKEN_REFRESH_SKEW_SECS: i64 = 120; + +// Device-code polling bounds. +const DEVICE_CODE_DEFAULT_INTERVAL_SECS: u64 = 5; +const DEVICE_CODE_MIN_INTERVAL_SECS: u64 = 1; +const DEVICE_CODE_SLOW_DOWN_INCREMENT_SECS: u64 = 5; +const DEVICE_CODE_DEFAULT_EXPIRES_SECS: u64 = 5 * 60; + +const XAI_OAUTH_PROVIDER_NAME: &str = "xai_oauth"; +const XAI_OAUTH_DOC_URL: &str = "https://x.ai/grok"; + +#[derive(Debug)] +struct XaiAuthState { + oauth_mutex: TokioMutex<()>, + refresh_mutex: TokioMutex<()>, +} + +impl XaiAuthState { + fn new() -> Self { + Self { + oauth_mutex: TokioMutex::new(()), + refresh_mutex: TokioMutex::new(()), + } + } + + fn instance() -> Arc { + Arc::clone(&XAI_AUTH_STATE) + } +} + +static XAI_AUTH_STATE: LazyLock> = + LazyLock::new(|| Arc::new(XaiAuthState::new())); + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct TokenData { + access_token: String, + refresh_token: String, + #[serde(default)] + id_token: Option, + expires_at: DateTime, +} + +#[derive(Debug, Clone)] +struct TokenCache { + cache_path: PathBuf, +} + +fn get_cache_path() -> PathBuf { + Paths::in_config_dir("xai_oauth/tokens.json") +} + +impl TokenCache { + fn new() -> Self { + let cache_path = get_cache_path(); + if let Some(parent) = cache_path.parent() { + let _ = std::fs::create_dir_all(parent); + } + Self { cache_path } + } + + fn load(&self) -> Option { + let contents = std::fs::read_to_string(&self.cache_path).ok()?; + serde_json::from_str(&contents).ok() + } + + fn save(&self, token_data: &TokenData) -> Result<()> { + if let Some(parent) = self.cache_path.parent() { + std::fs::create_dir_all(parent)?; + } + let contents = serde_json::to_string(token_data)?; + std::fs::write(&self.cache_path, contents)?; + Ok(()) + } + + fn clear(&self) { + let _ = std::fs::remove_file(&self.cache_path); + } +} + +struct PkceChallenge { + verifier: String, + challenge: String, +} + +fn generate_pkce() -> PkceChallenge { + let verifier = nanoid::nanoid!(64); + let digest = sha2::Sha256::digest(verifier.as_bytes()); + let challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest); + PkceChallenge { + verifier, + challenge, + } +} + +fn generate_state() -> String { + nanoid::nanoid!(32) +} + +fn redirect_uri() -> String { + format!( + "http://{}.{}.{}.{}:{}{}", + OAUTH_HOST[0], OAUTH_HOST[1], OAUTH_HOST[2], OAUTH_HOST[3], OAUTH_PORT, OAUTH_REDIRECT_PATH + ) +} + +fn build_authorize_url(pkce: &PkceChallenge, state: &str, nonce: &str) -> Result { + let scopes = OAUTH_SCOPES.join(" "); + let redirect = redirect_uri(); + // `plan=generic` opts the consent screen into xAI's generic OAuth plan + // tier; without it, accounts.x.ai rejects loopback OAuth from + // non-allowlisted clients. `referrer=goose` lets xAI attribute + // goose-originated logins. + let params = [ + ("response_type", "code"), + ("client_id", CLIENT_ID), + ("redirect_uri", redirect.as_str()), + ("scope", scopes.as_str()), + ("code_challenge", pkce.challenge.as_str()), + ("code_challenge_method", "S256"), + ("state", state), + ("nonce", nonce), + ("plan", "generic"), + ("referrer", "goose"), + ]; + let query = serde_urlencoded::to_string(params)?; + Ok(format!("{}?{}", AUTHORIZE_URL, query)) +} + +#[derive(Debug, Deserialize)] +struct TokenResponse { + access_token: String, + refresh_token: String, + #[serde(default)] + id_token: Option, + #[serde(default)] + expires_in: Option, +} + +async fn exchange_code_for_tokens(code: &str, pkce: &PkceChallenge) -> Result { + let client = reqwest::Client::new(); + let redirect = redirect_uri(); + let params = [ + ("grant_type", "authorization_code"), + ("code", code), + ("redirect_uri", redirect.as_str()), + ("client_id", CLIENT_ID), + ("code_verifier", pkce.verifier.as_str()), + ]; + + let resp = client + .post(TOKEN_URL) + .header("Content-Type", "application/x-www-form-urlencoded") + .header("Accept", "application/json") + .form(¶ms) + .send() + .await?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + return Err(anyhow!("xAI token exchange failed ({}): {}", status, text)); + } + + Ok(resp.json().await?) +} + +async fn refresh_access_token(refresh_token: &str) -> Result { + let client = reqwest::Client::new(); + let params = [ + ("grant_type", "refresh_token"), + ("refresh_token", refresh_token), + ("client_id", CLIENT_ID), + ]; + + let resp = client + .post(TOKEN_URL) + .header("Content-Type", "application/x-www-form-urlencoded") + .header("Accept", "application/json") + .form(¶ms) + .send() + .await?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + return Err(anyhow!("xAI token refresh failed ({}): {}", status, text)); + } + + Ok(resp.json().await?) +} + +#[derive(Debug, Deserialize)] +struct DeviceCodeResponse { + device_code: String, + user_code: String, + verification_uri: String, + #[serde(default)] + verification_uri_complete: Option, + #[serde(default)] + expires_in: Option, + #[serde(default)] + interval: Option, +} + +#[derive(Debug, Deserialize, Default)] +struct DeviceTokenErrorBody { + #[serde(default)] + error: Option, + #[serde(default)] + error_description: Option, +} + +async fn request_device_code() -> Result { + let client = reqwest::Client::new(); + let scopes = OAUTH_SCOPES.join(" "); + let params = [("client_id", CLIENT_ID), ("scope", scopes.as_str())]; + let resp = client + .post(DEVICE_AUTHORIZATION_URL) + .header("Content-Type", "application/x-www-form-urlencoded") + .header("Accept", "application/json") + .form(¶ms) + .send() + .await?; + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + return Err(anyhow!( + "xAI device code request failed ({}): {}", + status, + text + )); + } + Ok(resp.json().await?) +} + +async fn poll_device_code_token(device: &DeviceCodeResponse) -> Result { + let expires_secs = device + .expires_in + .filter(|v| *v > 0) + .unwrap_or(DEVICE_CODE_DEFAULT_EXPIRES_SECS); + let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(expires_secs); + let mut interval_secs = device + .interval + .filter(|v| *v > 0) + .unwrap_or(DEVICE_CODE_DEFAULT_INTERVAL_SECS) + .max(DEVICE_CODE_MIN_INTERVAL_SECS); + + let client = reqwest::Client::new(); + loop { + if tokio::time::Instant::now() >= deadline { + return Err(anyhow!("xAI device authorization timed out")); + } + let params = [ + ("grant_type", DEVICE_CODE_GRANT_TYPE), + ("client_id", CLIENT_ID), + ("device_code", device.device_code.as_str()), + ]; + let resp = client + .post(TOKEN_URL) + .header("Content-Type", "application/x-www-form-urlencoded") + .header("Accept", "application/json") + .form(¶ms) + .send() + .await?; + if resp.status().is_success() { + return Ok(resp.json().await?); + } + let status = resp.status(); + let body: DeviceTokenErrorBody = resp.json().await.unwrap_or_default(); + match body.error.as_deref() { + Some("authorization_pending") => { + tokio::time::sleep(std::time::Duration::from_secs(interval_secs)).await; + } + Some("slow_down") => { + interval_secs += DEVICE_CODE_SLOW_DOWN_INCREMENT_SECS; + tokio::time::sleep(std::time::Duration::from_secs(interval_secs)).await; + } + Some("access_denied") | Some("authorization_denied") => { + return Err(anyhow!("xAI device authorization was denied")); + } + Some("expired_token") => { + return Err(anyhow!( + "xAI device code expired - please re-run goose configure" + )); + } + other => { + let detail = body + .error_description + .or_else(|| other.map(String::from)) + .unwrap_or_default(); + return Err(anyhow!( + "xAI device token exchange failed ({}): {}", + status, + detail + )); + } + } + } +} + +const HTML_SUCCESS_TEMPLATE: &str = r#" + + + goose - xAI Authorization Successful + + + +
+

Authorization Successful

+

You can close this window and return to goose.

+
+ + +"#; + +fn html_success() -> String { + HTML_SUCCESS_TEMPLATE.replace( + "__AUTO_CLOSE_TIMEOUT_MS__", + &HTML_AUTO_CLOSE_TIMEOUT_MS.to_string(), + ) +} + +fn html_error(error: &str) -> String { + let safe_error = v_htmlescape::escape_fmt(error); + format!( + r#" + + + goose - xAI Authorization Failed + + + +
+

Authorization Failed

+

An error occurred during authorization.

+
{}
+
+ +"#, + safe_error + ) +} + +#[derive(Deserialize)] +struct CallbackParams { + code: Option, + state: Option, + error: Option, + error_description: Option, +} + +fn oauth_callback_router( + expected_state: String, + tx: Arc>>>>, +) -> Router { + Router::new().route( + OAUTH_REDIRECT_PATH, + get(move |Query(params): Query| { + let tx = tx.clone(); + let expected = expected_state.clone(); + async move { + if let Some(error) = params.error { + let msg = params.error_description.unwrap_or(error); + if let Some(sender) = tx.lock().await.take() { + let _ = sender.send(Err(anyhow!("{}", msg))); + } + return Html(html_error(&msg)); + } + + let code = match params.code { + Some(c) => c, + None => { + let msg = "Missing authorization code"; + if let Some(sender) = tx.lock().await.take() { + let _ = sender.send(Err(anyhow!("{}", msg))); + } + return Html(html_error(msg)); + } + }; + + if params.state.as_deref() != Some(&expected) { + let msg = "Invalid state - potential CSRF attack"; + if let Some(sender) = tx.lock().await.take() { + let _ = sender.send(Err(anyhow!("{}", msg))); + } + return Html(html_error(msg)); + } + + if let Some(sender) = tx.lock().await.take() { + let _ = sender.send(Ok(code)); + } + Html(html_success()) + } + }), + ) +} + +async fn spawn_oauth_server(app: Router) -> Result> { + let addr = SocketAddr::from((OAUTH_HOST, OAUTH_PORT)); + let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| { + if e.kind() == io::ErrorKind::AddrInUse { + anyhow!( + "xAI OAuth callback server failed to bind to {}: port {} is already in use \ + (likely another OAuth client like Grok-CLI is bound to the same port). \ + Stop the process using this port and try again, or use the device-code flow.", + addr, + OAUTH_PORT + ) + } else { + anyhow!( + "xAI OAuth callback server failed to bind to {}: {}", + addr, + e + ) + } + })?; + Ok(tokio::spawn(async move { + let server = axum::serve(listener, app); + let _ = server.await; + })) +} + +struct ServerHandleGuard(Option>); + +impl ServerHandleGuard { + fn new(handle: tokio::task::JoinHandle<()>) -> Self { + Self(Some(handle)) + } + + fn abort(&mut self) { + if let Some(handle) = self.0.take() { + handle.abort(); + } + } +} + +impl Drop for ServerHandleGuard { + fn drop(&mut self) { + self.abort(); + } +} + +async fn wait_for_oauth_code(rx: oneshot::Receiver>) -> Result { + let code_result = + tokio::time::timeout(std::time::Duration::from_secs(OAUTH_TIMEOUT_SECS), rx).await; + code_result + .map_err(|_| anyhow!("xAI OAuth flow timed out"))?? + .map_err(|e| anyhow!("xAI OAuth callback error: {}", e)) +} + +async fn perform_loopback_oauth_flow(auth_state: &XaiAuthState) -> Result { + let _guard = auth_state.oauth_mutex.try_lock().map_err(|_| { + anyhow!("Another xAI OAuth flow is already in progress; please try again later") + })?; + + let pkce = generate_pkce(); + let csrf_state = generate_state(); + let nonce = generate_state(); + let auth_url = build_authorize_url(&pkce, &csrf_state, &nonce)?; + + let (tx, rx) = oneshot::channel::>(); + let tx = Arc::new(TokioMutex::new(Some(tx))); + let app = oauth_callback_router(csrf_state.clone(), tx); + let server_handle = spawn_oauth_server(app).await?; + let mut server_guard = ServerHandleGuard::new(server_handle); + + if webbrowser::open(&auth_url).is_err() { + tracing::info!( + "Please open this URL in your browser to authorize goose with xAI:\n{}", + auth_url + ); + } + + let code_result = wait_for_oauth_code(rx).await; + server_guard.abort(); + let code = code_result?; + + let tokens = exchange_code_for_tokens(&code, &pkce).await?; + Ok(token_data_from_response(tokens)) +} + +async fn perform_device_code_flow() -> Result { + let device = request_device_code().await?; + let url = device + .verification_uri_complete + .clone() + .unwrap_or_else(|| device.verification_uri.clone()); + tracing::info!( + "xAI device authorization: open {} and enter code {}", + device.verification_uri, + device.user_code + ); + eprintln!( + "\nTo authorize goose with xAI, open this URL in any browser:\n {}\nand enter code: {}\n", + url, device.user_code + ); + let tokens = poll_device_code_token(&device).await?; + Ok(token_data_from_response(tokens)) +} + +fn token_data_from_response(tokens: TokenResponse) -> TokenData { + let expires_at = Utc::now() + chrono::Duration::seconds(tokens.expires_in.unwrap_or(3600)); + TokenData { + access_token: tokens.access_token, + refresh_token: tokens.refresh_token, + id_token: tokens.id_token, + expires_at, + } +} + +#[derive(Debug)] +struct XaiOAuthAuthProvider { + cache: TokenCache, + state: Arc, +} + +impl XaiOAuthAuthProvider { + fn new(state: Arc) -> Self { + Self { + cache: TokenCache::new(), + state, + } + } + + async fn get_valid_token(&self) -> Result { + if let Some(mut token_data) = self.cache.load() { + if token_data.expires_at + > Utc::now() + chrono::Duration::seconds(ACCESS_TOKEN_REFRESH_SKEW_SECS) + { + return Ok(token_data); + } + + // Single-flight refresh: collapse concurrent fetches onto one + // HTTP call so we don't replay a rotating refresh_token. + let _refresh_guard = self.state.refresh_mutex.lock().await; + if let Some(reloaded) = self.cache.load() { + if reloaded.expires_at + > Utc::now() + chrono::Duration::seconds(ACCESS_TOKEN_REFRESH_SKEW_SECS) + { + return Ok(reloaded); + } + token_data = reloaded; + } + + tracing::debug!("xAI access token expiring, attempting refresh"); + match refresh_access_token(&token_data.refresh_token).await { + Ok(new_tokens) => { + token_data.access_token = new_tokens.access_token; + if !new_tokens.refresh_token.is_empty() { + token_data.refresh_token = new_tokens.refresh_token; + } + if new_tokens.id_token.is_some() { + token_data.id_token = new_tokens.id_token; + } + token_data.expires_at = Utc::now() + + chrono::Duration::seconds(new_tokens.expires_in.unwrap_or(3600)); + self.cache.save(&token_data)?; + tracing::info!("xAI access token refreshed"); + return Ok(token_data); + } + Err(e) => { + tracing::warn!("xAI token refresh failed, will re-authenticate: {}", e); + self.cache.clear(); + } + } + } + + tracing::info!("Starting xAI OAuth flow (SuperGrok subscription)"); + let token_data = match perform_loopback_oauth_flow(self.state.as_ref()).await { + Ok(td) => td, + Err(e) => { + tracing::warn!( + "xAI loopback OAuth failed ({}); falling back to device-code flow", + e + ); + perform_device_code_flow().await? + } + }; + self.cache.save(&token_data)?; + Ok(token_data) + } +} + +#[async_trait] +impl AuthProvider for XaiOAuthAuthProvider { + async fn get_auth_header(&self) -> Result<(String, String)> { + let token_data = self.get_valid_token().await?; + Ok(( + "Authorization".to_string(), + format!("Bearer {}", token_data.access_token), + )) + } +} + +/// Delegating Provider that forwards chat/stream/etc. to an inner +/// `OpenAiCompatibleProvider` pointed at `https://api.x.ai/v1`, but overrides +/// `configure_oauth` so the desktop "Sign in" button (and any other caller of +/// `Provider::configure_oauth`) drives the loopback / device-code flow. +#[derive(serde::Serialize)] +pub struct XaiOAuthProvider { + #[serde(skip)] + inner: OpenAiCompatibleProvider, + #[serde(skip)] + auth_provider: Arc, +} + +impl XaiOAuthProvider { + pub async fn cleanup() -> Result<()> { + TokenCache::new().clear(); + Ok(()) + } +} + +#[async_trait] +impl Provider for XaiOAuthProvider { + fn get_name(&self) -> &str { + self.inner.get_name() + } + + fn get_model_config(&self) -> ModelConfig { + self.inner.get_model_config() + } + + async fn stream( + &self, + model_config: &ModelConfig, + session_id: &str, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> Result { + self.inner + .stream(model_config, session_id, system, messages, tools) + .await + } + + async fn fetch_supported_models(&self) -> Result, ProviderError> { + self.inner.fetch_supported_models().await + } + + async fn configure_oauth(&self) -> Result<(), ProviderError> { + // Preserve the previous token so a partially-completed sign-in + // attempt (e.g. user closes the browser) doesn't sign them out. + let previous_token = self.auth_provider.cache.load(); + self.auth_provider.cache.clear(); + + let flow_result = match perform_loopback_oauth_flow(self.auth_provider.state.as_ref()).await + { + Ok(td) => Ok(td), + Err(e) => { + tracing::warn!( + "xAI loopback OAuth failed ({}); falling back to device-code flow", + e + ); + perform_device_code_flow().await + } + }; + + let save_result = + flow_result.and_then(|token_data| self.auth_provider.cache.save(&token_data)); + + if let Err(e) = save_result { + if let Some(previous_token) = previous_token.as_ref() { + if self.auth_provider.cache.load().is_none() { + let _ = self.auth_provider.cache.save(previous_token); + } + } + return Err(ProviderError::Authentication(format!( + "xAI OAuth flow failed: {}", + e + ))); + } + Ok(()) + } +} + +impl ProviderDef for XaiOAuthProvider { + type Provider = Self; + + fn metadata() -> ProviderMetadata { + ProviderMetadata::new( + XAI_OAUTH_PROVIDER_NAME, + "xAI (SuperGrok Subscription)", + "Use your xAI SuperGrok subscription via OAuth instead of an API key. Falls back to a device-code flow on headless / remote machines.", + XAI_DEFAULT_MODEL, + XAI_KNOWN_MODELS.to_vec(), + XAI_OAUTH_DOC_URL, + vec![ + ConfigKey::new_oauth("XAI_OAUTH_TOKEN", true, true, None, false), + ConfigKey::new("XAI_HOST", false, false, Some(XAI_API_HOST), false), + ], + ) + } + + fn from_env( + model: ModelConfig, + _extensions: Vec, + ) -> BoxFuture<'static, Result> { + Box::pin(async move { + let config = crate::config::Config::global(); + let host: String = config + .get_param("XAI_HOST") + .unwrap_or_else(|_| XAI_API_HOST.to_string()); + + let auth_provider = Arc::new(XaiOAuthAuthProvider::new(XaiAuthState::instance())); + let auth_for_client = Arc::clone(&auth_provider); + let api_client = ApiClient::new( + host, + AuthMethod::Custom(Box::new(SharedAuthProvider(auth_for_client))), + )?; + + let inner = OpenAiCompatibleProvider::new( + XAI_OAUTH_PROVIDER_NAME.to_string(), + api_client, + model, + String::new(), + ); + + Ok(Self { + inner, + auth_provider, + }) + }) + } + + fn inventory_configured() -> bool { + TokenCache::new().load().is_some() + } +} + +/// Adapter so the same `XaiOAuthAuthProvider` can be both owned by the +/// wrapper (for `configure_oauth`) and embedded as an `AuthMethod::Custom` +/// boxed `AuthProvider` in the inner `ApiClient`. +struct SharedAuthProvider(Arc); + +#[async_trait] +impl AuthProvider for SharedAuthProvider { + async fn get_auth_header(&self) -> Result<(String, String)> { + self.0.get_auth_header().await + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn pkce_challenge_is_url_safe_base64_of_sha256_of_verifier() { + let pkce = generate_pkce(); + assert_eq!(pkce.verifier.len(), 64); + // S256 of a 64-char ASCII verifier => 32-byte digest => 43 base64url chars (no padding). + assert_eq!(pkce.challenge.len(), 43); + assert!(!pkce.challenge.contains('=')); + assert!(!pkce.challenge.contains('+')); + assert!(!pkce.challenge.contains('/')); + } + + #[test] + fn authorize_url_contains_required_oauth_params() { + let pkce = PkceChallenge { + verifier: "v".repeat(64), + challenge: "challenge-fixture".to_string(), + }; + let url = build_authorize_url(&pkce, "state-fixture", "nonce-fixture").unwrap(); + assert!(url.starts_with(AUTHORIZE_URL)); + assert!(url.contains(&format!("client_id={}", CLIENT_ID))); + assert!(url.contains("code_challenge=challenge-fixture")); + assert!(url.contains("code_challenge_method=S256")); + assert!(url.contains("state=state-fixture")); + assert!(url.contains("nonce=nonce-fixture")); + assert!(url.contains("plan=generic")); + assert!(url.contains("referrer=goose")); + assert!(url.contains("scope=openid")); + assert!(url.contains("offline_access")); + assert!(url.contains("grok-cli%3Aaccess")); + } + + #[test] + fn redirect_uri_matches_registered_grok_cli_value() { + // xAI rejects mismatched redirect_uris for the Grok-CLI client_id. + // This pins the loopback host/port that pairs with that client. + assert_eq!(redirect_uri(), "http://127.0.0.1:56121/callback"); + } + + #[test] + fn token_cache_path_lives_under_goose_config_dir() { + let path = get_cache_path(); + let s = path.to_string_lossy().into_owned(); + assert!( + s.contains("xai_oauth"), + "expected token path under xai_oauth/, got {}", + s + ); + assert!(s.ends_with("tokens.json")); + } +}