|
| 1 | +use super::errors::ProviderError; |
| 2 | +use crate::message::Message; |
| 3 | +use crate::model::ModelConfig; |
| 4 | +use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; |
| 5 | +use crate::providers::formats::openai::{create_request, get_usage, response_to_message}; |
| 6 | +use crate::providers::utils::get_model; |
| 7 | +use anyhow::Result; |
| 8 | +use async_trait::async_trait; |
| 9 | +use mcp_core::Tool; |
| 10 | +use reqwest::{Client, StatusCode}; |
| 11 | +use serde_json::Value; |
| 12 | +use std::time::Duration; |
| 13 | +use url::Url; |
| 14 | + |
/// Default base URL for xAI's OpenAI-compatible REST API.
pub const XAI_API_HOST: &str = "https://api.x.ai/v1";
/// Model used when the caller does not configure one explicitly.
pub const XAI_DEFAULT_MODEL: &str = "grok-3";
/// Models known to be served by xAI. Used for metadata/UI hints only;
/// the API accepts any model name it recognizes, so this list is not
/// enforced as a whitelist.
pub const XAI_KNOWN_MODELS: &[&str] = &[
    "grok-3",
    "grok-3-fast",
    "grok-3-mini",
    "grok-3-mini-fast",
    "grok-2-vision-1212",
    "grok-2-image-1212",
    "grok-2-1212",
    "grok-3-latest",
    "grok-3-fast-latest",
    "grok-3-mini-latest",
    "grok-3-mini-fast-latest",
    "grok-2-vision",
    "grok-2-vision-latest",
    "grok-2-image",
    "grok-2-image-latest",
    "grok-2",
    "grok-2-latest",
];

/// Documentation link surfaced in the provider metadata.
pub const XAI_DOC_URL: &str = "https://docs.x.ai/docs/overview";
| 38 | + |
/// Provider for xAI's Grok models, speaking the OpenAI-compatible
/// chat-completions protocol.
///
/// The struct is `Serialize` so its configuration can be exported;
/// the HTTP client is runtime-only state and is skipped.
#[derive(serde::Serialize)]
pub struct XaiProvider {
    #[serde(skip)]
    client: Client,
    /// Base URL of the xAI API (defaults to [`XAI_API_HOST`]).
    host: String,
    /// Secret API key, read from config as `XAI_API_KEY`.
    api_key: String,
    /// Model configuration used for every request.
    model: ModelConfig,
}
| 47 | + |
| 48 | +impl Default for XaiProvider { |
| 49 | + fn default() -> Self { |
| 50 | + let model = ModelConfig::new(XaiProvider::metadata().default_model); |
| 51 | + XaiProvider::from_env(model).expect("Failed to initialize xAI provider") |
| 52 | + } |
| 53 | +} |
| 54 | + |
| 55 | +impl XaiProvider { |
| 56 | + pub fn from_env(model: ModelConfig) -> Result<Self> { |
| 57 | + let config = crate::config::Config::global(); |
| 58 | + let api_key: String = config.get_secret("XAI_API_KEY")?; |
| 59 | + let host: String = config |
| 60 | + .get_param("XAI_HOST") |
| 61 | + .unwrap_or_else(|_| XAI_API_HOST.to_string()); |
| 62 | + |
| 63 | + let client = Client::builder() |
| 64 | + .timeout(Duration::from_secs(600)) |
| 65 | + .build()?; |
| 66 | + |
| 67 | + Ok(Self { |
| 68 | + client, |
| 69 | + host, |
| 70 | + api_key, |
| 71 | + model, |
| 72 | + }) |
| 73 | + } |
| 74 | + |
| 75 | + async fn post(&self, payload: Value) -> anyhow::Result<Value, ProviderError> { |
| 76 | + // Ensure the host ends with a slash for proper URL joining |
| 77 | + let host = if self.host.ends_with('/') { |
| 78 | + self.host.clone() |
| 79 | + } else { |
| 80 | + format!("{}/", self.host) |
| 81 | + }; |
| 82 | + let base_url = Url::parse(&host) |
| 83 | + .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; |
| 84 | + let url = base_url.join("chat/completions").map_err(|e| { |
| 85 | + ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) |
| 86 | + })?; |
| 87 | + |
| 88 | + tracing::debug!("xAI API URL: {}", url); |
| 89 | + tracing::debug!("xAI request model: {:?}", self.model.model_name); |
| 90 | + |
| 91 | + let response = self |
| 92 | + .client |
| 93 | + .post(url) |
| 94 | + .header("Authorization", format!("Bearer {}", self.api_key)) |
| 95 | + .json(&payload) |
| 96 | + .send() |
| 97 | + .await?; |
| 98 | + |
| 99 | + let status = response.status(); |
| 100 | + let payload: Option<Value> = response.json().await.ok(); |
| 101 | + |
| 102 | + match status { |
| 103 | + StatusCode::OK => payload.ok_or_else( || ProviderError::RequestFailed("Response body is not valid JSON".to_string()) ), |
| 104 | + StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { |
| 105 | + Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \ |
| 106 | + Status: {}. Response: {:?}", status, payload))) |
| 107 | + } |
| 108 | + StatusCode::PAYLOAD_TOO_LARGE => { |
| 109 | + Err(ProviderError::ContextLengthExceeded(format!("{:?}", payload))) |
| 110 | + } |
| 111 | + StatusCode::TOO_MANY_REQUESTS => { |
| 112 | + Err(ProviderError::RateLimitExceeded(format!("{:?}", payload))) |
| 113 | + } |
| 114 | + StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => { |
| 115 | + Err(ProviderError::ServerError(format!("{:?}", payload))) |
| 116 | + } |
| 117 | + _ => { |
| 118 | + tracing::debug!( |
| 119 | + "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload) |
| 120 | + ); |
| 121 | + Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status))) |
| 122 | + } |
| 123 | + } |
| 124 | + } |
| 125 | +} |
| 126 | + |
| 127 | +#[async_trait] |
| 128 | +impl Provider for XaiProvider { |
| 129 | + fn metadata() -> ProviderMetadata { |
| 130 | + ProviderMetadata::new( |
| 131 | + "xai", |
| 132 | + "xAI", |
| 133 | + "Grok models from xAI, including reasoning and multimodal capabilities", |
| 134 | + XAI_DEFAULT_MODEL, |
| 135 | + XAI_KNOWN_MODELS.to_vec(), |
| 136 | + XAI_DOC_URL, |
| 137 | + vec![ |
| 138 | + ConfigKey::new("XAI_API_KEY", true, true, None), |
| 139 | + ConfigKey::new("XAI_HOST", false, false, Some(XAI_API_HOST)), |
| 140 | + ], |
| 141 | + ) |
| 142 | + } |
| 143 | + |
| 144 | + fn get_model_config(&self) -> ModelConfig { |
| 145 | + self.model.clone() |
| 146 | + } |
| 147 | + |
| 148 | + #[tracing::instrument( |
| 149 | + skip(self, system, messages, tools), |
| 150 | + fields(model_config, input, output, input_tokens, output_tokens, total_tokens) |
| 151 | + )] |
| 152 | + async fn complete( |
| 153 | + &self, |
| 154 | + system: &str, |
| 155 | + messages: &[Message], |
| 156 | + tools: &[Tool], |
| 157 | + ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> { |
| 158 | + let payload = create_request( |
| 159 | + &self.model, |
| 160 | + system, |
| 161 | + messages, |
| 162 | + tools, |
| 163 | + &super::utils::ImageFormat::OpenAi, |
| 164 | + )?; |
| 165 | + |
| 166 | + let response = self.post(payload.clone()).await?; |
| 167 | + |
| 168 | + let message = response_to_message(response.clone())?; |
| 169 | + let usage = match get_usage(&response) { |
| 170 | + Ok(usage) => usage, |
| 171 | + Err(ProviderError::UsageError(e)) => { |
| 172 | + tracing::debug!("Failed to get usage data: {}", e); |
| 173 | + Usage::default() |
| 174 | + } |
| 175 | + Err(e) => return Err(e), |
| 176 | + }; |
| 177 | + let model = get_model(&response); |
| 178 | + super::utils::emit_debug_trace(&self.model, &payload, &response, &usage); |
| 179 | + Ok((message, ProviderUsage::new(model, usage))) |
| 180 | + } |
| 181 | +} |
0 commit comments