Skip to content

Commit 0adb6be

Browse files
jackjackbitsjack
authored and committed
Add xAI Provider Support for Grok Models (aaif-goose#2976)
Co-authored-by: jack <jack@deck.local>
1 parent e22a9eb commit 0adb6be

10 files changed

Lines changed: 227 additions & 2 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ tmp/
1212
# will have compiled files and executables
1313
debug/
1414
target/
15+
.goose/
1516

1617
# These are backup files generated by rustfmt
1718
**/*.rs.bk

crates/goose/src/model.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ static MODEL_SPECIFIC_LIMITS: Lazy<HashMap<&'static str, usize>> = Lazy::new(||
3030
// Meta Llama models, https://github.com/meta-llama/llama-models/tree/main?tab=readme-ov-file#llama-models-1
3131
map.insert("llama3.2", 128_000);
3232
map.insert("llama3.3", 128_000);
33+
34+
// x.ai Grok models, https://docs.x.ai/docs/overview
35+
map.insert("grok", 131_072);
3336
map
3437
});
3538

crates/goose/src/providers/factory.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use super::{
1717
sagemaker_tgi::SageMakerTgiProvider,
1818
snowflake::SnowflakeProvider,
1919
venice::VeniceProvider,
20+
xai::XaiProvider,
2021
};
2122
use crate::model::ModelConfig;
2223
use anyhow::Result;
@@ -52,6 +53,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
5253
SageMakerTgiProvider::metadata(),
5354
VeniceProvider::metadata(),
5455
SnowflakeProvider::metadata(),
56+
XaiProvider::metadata(),
5557
]
5658
}
5759

@@ -128,6 +130,7 @@ fn create_provider(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>>
128130
"venice" => Ok(Arc::new(VeniceProvider::from_env(model)?)),
129131
"snowflake" => Ok(Arc::new(SnowflakeProvider::from_env(model)?)),
130132
"github_copilot" => Ok(Arc::new(GithubCopilotProvider::from_env(model)?)),
133+
"xai" => Ok(Arc::new(XaiProvider::from_env(model)?)),
131134
_ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
132135
}
133136
}
@@ -259,7 +262,7 @@ mod tests {
259262
}
260263

261264
// Set only the required lead model
262-
env::set_var("GOOSE_LEAD_MODEL", "gpt-4o");
265+
env::set_var("GOOSE_LEAD_MODEL", "grok-3");
263266

264267
// This should use defaults for all other values
265268
let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));

crates/goose/src/providers/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,6 @@ pub mod toolshim;
2424
pub mod utils;
2525
pub mod utils_universal_openai_stream;
2626
pub mod venice;
27+
pub mod xai;
2728

2829
pub use factory::{create, providers};

crates/goose/src/providers/xai.rs

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
use super::errors::ProviderError;
2+
use crate::message::Message;
3+
use crate::model::ModelConfig;
4+
use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
5+
use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
6+
use crate::providers::utils::get_model;
7+
use anyhow::Result;
8+
use async_trait::async_trait;
9+
use mcp_core::Tool;
10+
use reqwest::{Client, StatusCode};
11+
use serde_json::Value;
12+
use std::time::Duration;
13+
use url::Url;
14+
15+
pub const XAI_API_HOST: &str = "https://api.x.ai/v1";
16+
pub const XAI_DEFAULT_MODEL: &str = "grok-3";
17+
pub const XAI_KNOWN_MODELS: &[&str] = &[
18+
"grok-3",
19+
"grok-3-fast",
20+
"grok-3-mini",
21+
"grok-3-mini-fast",
22+
"grok-2-vision-1212",
23+
"grok-2-image-1212",
24+
"grok-2-1212",
25+
"grok-3-latest",
26+
"grok-3-fast-latest",
27+
"grok-3-mini-latest",
28+
"grok-3-mini-fast-latest",
29+
"grok-2-vision",
30+
"grok-2-vision-latest",
31+
"grok-2-image",
32+
"grok-2-image-latest",
33+
"grok-2",
34+
"grok-2-latest",
35+
];
36+
37+
pub const XAI_DOC_URL: &str = "https://docs.x.ai/docs/overview";
38+
39+
#[derive(serde::Serialize)]
40+
pub struct XaiProvider {
41+
#[serde(skip)]
42+
client: Client,
43+
host: String,
44+
api_key: String,
45+
model: ModelConfig,
46+
}
47+
48+
impl Default for XaiProvider {
49+
fn default() -> Self {
50+
let model = ModelConfig::new(XaiProvider::metadata().default_model);
51+
XaiProvider::from_env(model).expect("Failed to initialize xAI provider")
52+
}
53+
}
54+
55+
impl XaiProvider {
56+
pub fn from_env(model: ModelConfig) -> Result<Self> {
57+
let config = crate::config::Config::global();
58+
let api_key: String = config.get_secret("XAI_API_KEY")?;
59+
let host: String = config
60+
.get_param("XAI_HOST")
61+
.unwrap_or_else(|_| XAI_API_HOST.to_string());
62+
63+
let client = Client::builder()
64+
.timeout(Duration::from_secs(600))
65+
.build()?;
66+
67+
Ok(Self {
68+
client,
69+
host,
70+
api_key,
71+
model,
72+
})
73+
}
74+
75+
async fn post(&self, payload: Value) -> anyhow::Result<Value, ProviderError> {
76+
// Ensure the host ends with a slash for proper URL joining
77+
let host = if self.host.ends_with('/') {
78+
self.host.clone()
79+
} else {
80+
format!("{}/", self.host)
81+
};
82+
let base_url = Url::parse(&host)
83+
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
84+
let url = base_url.join("chat/completions").map_err(|e| {
85+
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
86+
})?;
87+
88+
tracing::debug!("xAI API URL: {}", url);
89+
tracing::debug!("xAI request model: {:?}", self.model.model_name);
90+
91+
let response = self
92+
.client
93+
.post(url)
94+
.header("Authorization", format!("Bearer {}", self.api_key))
95+
.json(&payload)
96+
.send()
97+
.await?;
98+
99+
let status = response.status();
100+
let payload: Option<Value> = response.json().await.ok();
101+
102+
match status {
103+
StatusCode::OK => payload.ok_or_else( || ProviderError::RequestFailed("Response body is not valid JSON".to_string()) ),
104+
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
105+
Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \
106+
Status: {}. Response: {:?}", status, payload)))
107+
}
108+
StatusCode::PAYLOAD_TOO_LARGE => {
109+
Err(ProviderError::ContextLengthExceeded(format!("{:?}", payload)))
110+
}
111+
StatusCode::TOO_MANY_REQUESTS => {
112+
Err(ProviderError::RateLimitExceeded(format!("{:?}", payload)))
113+
}
114+
StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
115+
Err(ProviderError::ServerError(format!("{:?}", payload)))
116+
}
117+
_ => {
118+
tracing::debug!(
119+
"{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload)
120+
);
121+
Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status)))
122+
}
123+
}
124+
}
125+
}
126+
127+
#[async_trait]
128+
impl Provider for XaiProvider {
129+
fn metadata() -> ProviderMetadata {
130+
ProviderMetadata::new(
131+
"xai",
132+
"xAI",
133+
"Grok models from xAI, including reasoning and multimodal capabilities",
134+
XAI_DEFAULT_MODEL,
135+
XAI_KNOWN_MODELS.to_vec(),
136+
XAI_DOC_URL,
137+
vec![
138+
ConfigKey::new("XAI_API_KEY", true, true, None),
139+
ConfigKey::new("XAI_HOST", false, false, Some(XAI_API_HOST)),
140+
],
141+
)
142+
}
143+
144+
fn get_model_config(&self) -> ModelConfig {
145+
self.model.clone()
146+
}
147+
148+
#[tracing::instrument(
149+
skip(self, system, messages, tools),
150+
fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
151+
)]
152+
async fn complete(
153+
&self,
154+
system: &str,
155+
messages: &[Message],
156+
tools: &[Tool],
157+
) -> anyhow::Result<(Message, ProviderUsage), ProviderError> {
158+
let payload = create_request(
159+
&self.model,
160+
system,
161+
messages,
162+
tools,
163+
&super::utils::ImageFormat::OpenAi,
164+
)?;
165+
166+
let response = self.post(payload.clone()).await?;
167+
168+
let message = response_to_message(response.clone())?;
169+
let usage = match get_usage(&response) {
170+
Ok(usage) => usage,
171+
Err(ProviderError::UsageError(e)) => {
172+
tracing::debug!("Failed to get usage data: {}", e);
173+
Usage::default()
174+
}
175+
Err(e) => return Err(e),
176+
};
177+
let model = get_model(&response);
178+
super::utils::emit_debug_trace(&self.model, &payload, &response, &usage);
179+
Ok((message, ProviderUsage::new(model, usage)))
180+
}
181+
}

documentation/docs/getting-started/providers.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ Goose relies heavily on tool calling capabilities and currently works best with
3333
| [OpenRouter](https://openrouter.ai/) | API gateway for unified access to various models with features like rate-limiting management. | `OPENROUTER_API_KEY` |
3434
| [Snowflake](https://docs.snowflake.com/user-guide/snowflake-cortex/aisql#choosing-a-model) | Access the latest models using Snowflake Cortex services, including Claude models. **Requires a Snowflake account and programmatic access token (PAT)**. | `SNOWFLAKE_HOST`, `SNOWFLAKE_TOKEN` |
3535
| [Venice AI](https://venice.ai/home) | Provides access to open source models like Llama, Mistral, and Qwen while prioritizing user privacy. **Requires an account and an [API key](https://docs.venice.ai/overview/guides/generating-api-key)**. | `VENICE_API_KEY`, `VENICE_HOST` (optional), `VENICE_BASE_PATH` (optional), `VENICE_MODELS_PATH` (optional) |
36+
| [xAI](https://x.ai/) | Access to xAI's Grok models including grok-3, grok-3-mini, and grok-3-fast with 131,072 token context window. | `XAI_API_KEY`, `XAI_HOST` (optional) |
3637

3738

3839
## Configure Provider

ui/desktop/src/components/settings/providers/ProviderRegistry.tsx

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,25 @@ export const PROVIDER_REGISTRY: ProviderRegistry[] = [
9797
],
9898
},
9999
},
100+
{
101+
name: 'xAI',
102+
details: {
103+
id: 'xai',
104+
name: 'xAI',
105+
description: 'Access Grok models from xAI, including reasoning and multimodal capabilities',
106+
parameters: [
107+
{
108+
name: 'XAI_API_KEY',
109+
is_secret: true,
110+
},
111+
{
112+
name: 'XAI_HOST',
113+
is_secret: false,
114+
default: 'https://api.x.ai/v1',
115+
},
116+
],
117+
},
118+
},
100119
{
101120
name: 'Google',
102121
details: {

ui/desktop/src/components/settings/providers/modal/subcomponents/ProviderLogo.tsx

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import OllamaLogo from './icons/ollama@3x.png';
66
import DatabricksLogo from './icons/databricks@3x.png';
77
import OpenRouterLogo from './icons/openrouter@3x.png';
88
import SnowflakeLogo from './icons/snowflake@3x.png';
9+
import XaiLogo from './icons/xai@3x.png';
910
import DefaultLogo from './icons/default@3x.png';
1011

1112
// Map provider names to their logos
@@ -18,6 +19,7 @@ const providerLogos: Record<string, string> = {
1819
databricks: DatabricksLogo,
1920
openrouter: OpenRouterLogo,
2021
snowflake: SnowflakeLogo,
22+
xai: XaiLogo,
2123
default: DefaultLogo,
2224
};
2325

@@ -30,10 +32,24 @@ export default function ProviderLogo({ providerName }: ProviderLogoProps) {
3032
const logoKey = providerName.toLowerCase();
3133
const logo = providerLogos[logoKey] || DefaultLogo;
3234

35+
// Special handling for xAI logo
36+
const isXai = logoKey === 'xai';
37+
const imageStyle = isXai ? { filter: 'invert(1)', opacity: 0.9 } : {};
38+
39+
// Use smaller size for xAI logo to fit better in circle
40+
const imageClassName = isXai
41+
? 'w-8 h-8 object-contain' // Smaller size for xAI
42+
: 'w-16 h-16 object-contain'; // Default size for others
43+
3344
return (
3445
<div className="flex justify-center mb-2">
3546
<div className="w-12 h-12 bg-black rounded-full overflow-hidden flex items-center justify-center">
36-
<img src={logo} alt={`${providerName} logo`} className="w-16 h-16 object-contain" />
47+
<img
48+
src={logo}
49+
alt={`${providerName} logo`}
50+
className={imageClassName}
51+
style={imageStyle}
52+
/>
3753
</div>
3854
</div>
3955
);
3.69 KB
Loading
3.69 KB
Loading

0 commit comments

Comments
 (0)