diff --git a/src/http.rs b/src/http.rs index 3287f4ef..4616ec8a 100644 --- a/src/http.rs +++ b/src/http.rs @@ -28,6 +28,7 @@ pub struct ApiClient { http: Client, base_url: String, api_key: String, + org_id: String, org_name: String, } @@ -60,6 +61,7 @@ impl ApiClient { http, base_url: ctx.api_url.trim_end_matches('/').to_string(), api_key: ctx.login.api_key().context("login state missing API key")?, + org_id: ctx.login.org_id().unwrap_or_default(), org_name: ctx.login.org_name().unwrap_or_default(), }) } @@ -77,6 +79,10 @@ impl ApiClient { &self.base_url } + pub fn org_id(&self) -> &str { + &self.org_id + } + pub fn org_name(&self) -> &str { &self.org_name } diff --git a/src/setup/mod.rs b/src/setup/mod.rs index 7e6a3ad6..747dd2f6 100644 --- a/src/setup/mod.rs +++ b/src/setup/mod.rs @@ -1772,6 +1772,13 @@ async fn maybe_create_api_key_for_oauth(base: &BaseArgs, client: &ApiClient) -> key: String, } + let org_id = client.org_id().trim(); + if org_id.is_empty() { + bail!( + "setup could not determine the current org_id for API key creation; rerun with a direct API key or re-authenticate so setup can resolve the selected organization" + ); + } + let existing: Vec = client .get::("/v1/api_key") .await @@ -1790,7 +1797,7 @@ async fn maybe_create_api_key_for_oauth(base: &BaseArgs, client: &ApiClient) -> .expect("name sequence is infinite") }; - let body = serde_json::json!({ "name": name, "org_name": client.org_name() }); + let body = serde_json::json!({ "name": name, "org_id": org_id }); let created: CreatedKey = client.post("/v1/api_key", &body).await?; let explicitly_quiet = base.quiet && base.quiet_source.is_some(); @@ -4916,9 +4923,12 @@ fn print_mcp_human_report( #[cfg(test)] mod tests { use super::*; + use crate::auth::LoginContext; + use actix_web::{web, App, HttpResponse, HttpServer}; use std::env; use std::ffi::OsString; - use std::sync::{Mutex, OnceLock}; + use std::net::TcpListener; + use std::sync::{Arc, Mutex, OnceLock}; use std::time::{SystemTime, UNIX_EPOCH}; fn cwd_test_lock() -> &'static Mutex<()> { @@ -4970,6 +4980,127 @@ mod tests { } } + fn make_login_context( + api_url: String, + app_url: String, + org_id: &str, + org_name: &str, + ) -> LoginContext { + let login = braintrust_sdk_rust::LoginState::new(); + let _ = login.set( + "test-api-key".to_string(), + org_id.to_string(), + org_name.to_string(), + api_url.clone(), + app_url.clone(), + ); + + LoginContext { + login, + api_url, + app_url, + } + } + + #[derive(Default)] + struct ApiKeyTestState { + create_request_body: Mutex>, + } + + #[tokio::test] + async fn maybe_create_api_key_for_oauth_uses_org_id_in_request_body() { + let state = Arc::new(ApiKeyTestState::default()); + let listener = TcpListener::bind(("127.0.0.1", 0)).expect("bind mock server"); + let addr = listener.local_addr().expect("mock server addr"); + let api_url = format!("http://{addr}"); + let data = web::Data::new(state.clone()); + + let server = HttpServer::new(move || { + App::new() + .app_data(data.clone()) + .route( + "/v1/api_key", + web::get().to(|| async { + HttpResponse::Ok().json(serde_json::json!({ "objects": [] })) + }), + ) + .route( + "/v1/api_key", + web::post().to( + |state: web::Data>, + body: web::Json| async move { + *state + .create_request_body + .lock() + .expect("lock create request body") = Some(body.into_inner()); + HttpResponse::Ok().json(serde_json::json!({ "key": "new-key" })) + }, + ), + ) + }) + .workers(1) + .listen(listener) + .expect("listen mock server") + .run(); + let handle = server.handle(); + tokio::spawn(server); + + let mut base = make_base_args(); + base.quiet = true; + base.quiet_source = Some(ArgValueSource::CommandLine); + + let ctx = make_login_context( + api_url, + "https://app.example.test".to_string(), + "org_123", + "Acme", + ); + let client = ApiClient::new(&ctx).expect("client"); + + let key = maybe_create_api_key_for_oauth(&base, &client) + .await + .expect("create api key"); + assert_eq!(key, "new-key"); + let request_body = state + .create_request_body + .lock() + .expect("lock create request body") + .clone() + .expect("captured create request body"); + assert_eq!( + request_body.get("org_id").and_then(|v| v.as_str()), + Some("org_123") + ); + assert!(request_body.get("org_name").is_none()); + assert!(request_body.get("name").and_then(|v| v.as_str()).is_some()); + + handle.stop(true).await; + } + + #[tokio::test] + async fn maybe_create_api_key_for_oauth_requires_org_id() { + let mut base = make_base_args(); + base.quiet = true; + base.quiet_source = Some(ArgValueSource::CommandLine); + + let ctx = make_login_context( + "https://api.example.test".to_string(), + "https://app.example.test".to_string(), + "", + "Acme", + ); + let client = ApiClient::new(&ctx).expect("client"); + + let err = maybe_create_api_key_for_oauth(&base, &client) + .await + .expect_err("missing org_id should fail"); + let err_text = format!("{err:#}"); + assert!( + err_text.contains("org_id") && err_text.contains("API key creation"), + "unexpected error: {err_text}" + ); + } + #[test] fn single_path_agent_is_selected_by_default() { let detected = vec![DetectionSignal { diff --git a/src/sync.rs b/src/sync.rs index 1550807c..70ae79d0 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -26,7 +26,6 @@ use crate::experiments::api::create_experiment; use crate::http::ApiClient; use crate::projects::api::{create_project, list_projects, Project}; use crate::ui::{animations_enabled, fuzzy_select, is_quiet}; -use crate::utils::parse_duration_to_seconds; const STATE_SCHEMA_VERSION: u32 = 1; const DEFAULT_PULL_LIMIT: usize = 100; diff --git a/src/traces.rs b/src/traces.rs index 4cd03e99..0e8fd69e 100644 --- a/src/traces.rs +++ b/src/traces.rs @@ -36,7 +36,6 @@ use crate::args::BaseArgs; use crate::auth::{self, login}; use crate::http::ApiClient; use crate::ui::{fuzzy_select, is_interactive, with_spinner}; -use crate::utils::parse_duration_to_seconds; const MAX_TRACE_SPANS: usize = 5000; const MAX_BTQL_PAGE_LIMIT: usize = 1000; diff --git a/src/utils/mod.rs b/src/utils/mod.rs index fcc4fcfd..e7de11df 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -5,7 +5,6 @@ mod ids; mod json_object; mod plurals; -pub use duration::parse_duration_to_seconds; pub use fs_atomic::write_text_atomic; pub use git::GitRepo; pub(crate) use ids::new_uuid_id;