From 6873833019bb080e148c2bef4eaa3f249a2b59a0 Mon Sep 17 00:00:00 2001 From: Hurshal Patel Date: Tue, 12 May 2026 11:32:13 -0700 Subject: [PATCH 1/2] use org_id on POST /v1/api_key --- src/http.rs | 6 ++ src/setup/mod.rs | 152 ++++++++++++++++++++++++++++++++++++++++++++++- src/sync.rs | 1 - src/traces.rs | 1 - src/utils/mod.rs | 1 - 5 files changed, 157 insertions(+), 4 deletions(-) diff --git a/src/http.rs b/src/http.rs index 3287f4ef..4616ec8a 100644 --- a/src/http.rs +++ b/src/http.rs @@ -28,6 +28,7 @@ pub struct ApiClient { http: Client, base_url: String, api_key: String, + org_id: String, org_name: String, } @@ -60,6 +61,7 @@ impl ApiClient { http, base_url: ctx.api_url.trim_end_matches('/').to_string(), api_key: ctx.login.api_key().context("login state missing API key")?, + org_id: ctx.login.org_id().unwrap_or_default(), org_name: ctx.login.org_name().unwrap_or_default(), }) } @@ -77,6 +79,10 @@ impl ApiClient { &self.base_url } + pub fn org_id(&self) -> &str { + &self.org_id + } + pub fn org_name(&self) -> &str { &self.org_name } diff --git a/src/setup/mod.rs b/src/setup/mod.rs index 7e6a3ad6..64d00efd 100644 --- a/src/setup/mod.rs +++ b/src/setup/mod.rs @@ -1772,6 +1772,13 @@ async fn maybe_create_api_key_for_oauth(base: &BaseArgs, client: &ApiClient) -> key: String, } + let org_id = client.org_id().trim(); + if org_id.is_empty() { + bail!( + "setup could not determine the current org_id for API key creation; rerun with a direct API key or re-authenticate so setup can resolve the selected organization" + ); + } + let existing: Vec = client .get::("/v1/api_key") .await @@ -1790,7 +1797,7 @@ async fn maybe_create_api_key_for_oauth(base: &BaseArgs, client: &ApiClient) -> .expect("name sequence is infinite") }; - let body = serde_json::json!({ "name": name, "org_name": client.org_name() }); + let body = serde_json::json!({ "name": name, "org_id": org_id }); let created: CreatedKey = client.post("/v1/api_key", &body).await?; let explicitly_quiet = base.quiet && base.quiet_source.is_some(); @@ -4916,8 +4923,11 @@ fn print_mcp_human_report( #[cfg(test)] mod tests { use super::*; + use crate::auth::LoginContext; use std::env; use std::ffi::OsString; + use std::io::{Read, Write}; + use std::net::TcpListener; use std::sync::{Mutex, OnceLock}; use std::time::{SystemTime, UNIX_EPOCH}; @@ -4970,6 +4980,146 @@ mod tests { } } + fn make_login_context( + api_url: String, + app_url: String, + org_id: &str, + org_name: &str, + ) -> LoginContext { + let login = braintrust_sdk_rust::LoginState::new(); + let _ = login.set( + "test-api-key".to_string(), + org_id.to_string(), + org_name.to_string(), + api_url.clone(), + app_url.clone(), + ); + + LoginContext { + login, + api_url, + app_url, + } + } + + #[tokio::test] + async fn maybe_create_api_key_for_oauth_uses_org_id_in_request_body() { + let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener"); + let addr = listener.local_addr().expect("listener addr"); + let server = std::thread::spawn(move || { + let (mut stream, _) = listener.accept().expect("accept list request"); + let mut buffer = [0u8; 4096]; + let read = stream.read(&mut buffer).expect("read list request"); + let request = String::from_utf8_lossy(&buffer[..read]); + assert!(request.starts_with("GET /v1/api_key HTTP/1.1")); + let response = concat!( + "HTTP/1.1 200 OK\r\n", + "Content-Type: application/json\r\n", + "Content-Length: 14\r\n", + "Connection: close\r\n", + "\r\n", + "{\"objects\":[]}" + ); + stream + .write_all(response.as_bytes()) + .expect("write list response"); + stream.flush().expect("flush list response"); + drop(stream); + + let (mut stream, _) = listener.accept().expect("accept create request"); + let mut header_buf = Vec::new(); + let mut temp = [0u8; 1024]; + let header_end; + loop { + let read = stream.read(&mut temp).expect("read create request"); + assert!(read > 0, "request closed before headers"); + header_buf.extend_from_slice(&temp[..read]); + if let Some(pos) = header_buf.windows(4).position(|w| w == b"\r\n\r\n") { + header_end = pos + 4; + break; + } + } + let headers = String::from_utf8_lossy(&header_buf[..header_end]); + assert!(headers.starts_with("POST /v1/api_key HTTP/1.1")); + let content_length = headers + .split("\r\n") + .find_map(|line| { + let (name, value) = line.split_once(':')?; + name.eq_ignore_ascii_case("content-length") + .then(|| value.trim().parse::().expect("content length")) + }) + .expect("content-length header"); + let mut body = header_buf[header_end..].to_vec(); + while body.len() < content_length { + let read = stream.read(&mut temp).expect("read request body"); + assert!(read > 0, "request closed before body completed"); + body.extend_from_slice(&temp[..read]); + } + let json: serde_json::Value = + serde_json::from_slice(&body[..content_length]).expect("parse request body"); + assert_eq!(json.get("org_id").and_then(|v| v.as_str()), Some("org_123")); + assert!(json.get("org_name").is_none()); + assert!(json.get("name").and_then(|v| v.as_str()).is_some()); + + let response = concat!( + "HTTP/1.1 200 OK\r\n", + "Content-Type: application/json\r\n", + "Content-Length: 17\r\n", + "Connection: close\r\n", + "\r\n", + "{\"key\":\"new-key\"}" + ); + stream + .write_all(response.as_bytes()) + .expect("write create response"); + stream.flush().expect("flush create response"); + }); + + let mut base = make_base_args(); + base.quiet = true; + base.quiet_source = Some(ArgValueSource::CommandLine); + + let api_url = format!("http://{addr}"); + let ctx = make_login_context( + api_url, + "https://app.example.test".to_string(), + "org_123", + "Acme", + ); + let client = ApiClient::new(&ctx).expect("client"); + + let key = maybe_create_api_key_for_oauth(&base, &client) + .await + .expect("create api key"); + assert_eq!(key, "new-key"); + + server.join().expect("server join"); + } + + #[tokio::test] + async fn maybe_create_api_key_for_oauth_requires_org_id() { + let mut base = make_base_args(); + base.quiet = true; + base.quiet_source = Some(ArgValueSource::CommandLine); + + let ctx = make_login_context( + "https://api.example.test".to_string(), + "https://app.example.test".to_string(), + "", + "Acme", + ); + let client = ApiClient::new(&ctx).expect("client"); + + let err = maybe_create_api_key_for_oauth(&base, &client) + .await + .expect_err("missing org_id should fail"); + let err_text = format!("{err:#}"); + assert!( + err_text.contains("org_id") && err_text.contains("API key creation"), + "unexpected error: {err_text}" + ); + } + #[test] fn single_path_agent_is_selected_by_default() { let detected = vec![DetectionSignal { diff --git a/src/sync.rs b/src/sync.rs index 1550807c..70ae79d0 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -26,7 +26,6 @@ use crate::experiments::api::create_experiment; use crate::http::ApiClient; use crate::projects::api::{create_project, list_projects, Project}; use crate::ui::{animations_enabled, fuzzy_select, is_quiet}; -use crate::utils::parse_duration_to_seconds; const STATE_SCHEMA_VERSION: u32 = 1; const DEFAULT_PULL_LIMIT: usize = 100; diff --git a/src/traces.rs b/src/traces.rs index 4cd03e99..0e8fd69e 100644 --- a/src/traces.rs +++ b/src/traces.rs @@ -36,7 +36,6 @@ use crate::args::BaseArgs; use crate::auth::{self, login}; use crate::http::ApiClient; use crate::ui::{fuzzy_select, is_interactive, with_spinner}; -use crate::utils::parse_duration_to_seconds; const MAX_TRACE_SPANS: usize = 5000; const MAX_BTQL_PAGE_LIMIT: usize = 1000; diff --git a/src/utils/mod.rs b/src/utils/mod.rs index fcc4fcfd..e7de11df 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -5,7 +5,6 @@ mod ids; mod json_object; mod plurals; -pub use duration::parse_duration_to_seconds; pub use fs_atomic::write_text_atomic; pub use git::GitRepo; pub(crate) use ids::new_uuid_id; From 10cacb53158bce0013bf940f721c63aa9ccccf06 Mon Sep 17 00:00:00 2001 From: Hurshal Patel Date: Tue, 12 May 2026 12:33:17 -0700 Subject: [PATCH 2/2] simplfy test --- src/setup/mod.rs | 129 ++++++++++++++++++++--------------------------- 1 file changed, 55 insertions(+), 74 deletions(-) diff --git a/src/setup/mod.rs b/src/setup/mod.rs index 64d00efd..747dd2f6 100644 --- a/src/setup/mod.rs +++ b/src/setup/mod.rs @@ -4924,11 +4924,11 @@ fn print_mcp_human_report( mod tests { use super::*; use crate::auth::LoginContext; + use actix_web::{web, App, HttpResponse, HttpServer}; use std::env; use std::ffi::OsString; - use std::io::{Read, Write}; use std::net::TcpListener; - use std::sync::{Mutex, OnceLock}; + use std::sync::{Arc, Mutex, OnceLock}; use std::time::{SystemTime, UNIX_EPOCH}; fn cwd_test_lock() -> &'static Mutex<()> { @@ -5002,84 +5002,53 @@ mod tests { } } + #[derive(Default)] + struct ApiKeyTestState { + create_request_body: Mutex>, + } + #[tokio::test] async fn maybe_create_api_key_for_oauth_uses_org_id_in_request_body() { - let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener"); - let addr = listener.local_addr().expect("listener addr"); - let server = std::thread::spawn(move || { - let (mut stream, _) = listener.accept().expect("accept list request"); - let mut buffer = [0u8; 4096]; - let read = stream.read(&mut buffer).expect("read list request"); - let request = String::from_utf8_lossy(&buffer[..read]); - assert!(request.starts_with("GET /v1/api_key HTTP/1.1")); - let response = concat!( - "HTTP/1.1 200 OK\r\n", - "Content-Type: application/json\r\n", - "Content-Length: 14\r\n", - "Connection: close\r\n", - "\r\n", - "{\"objects\":[]}" - ); - stream - .write_all(response.as_bytes()) - .expect("write list response"); - stream.flush().expect("flush list response"); - drop(stream); - - let (mut stream, _) = listener.accept().expect("accept create request"); - let mut header_buf = Vec::new(); - let mut temp = [0u8; 1024]; - let header_end; - loop { - let read = stream.read(&mut temp).expect("read create request"); - assert!(read > 0, "request closed before headers"); - header_buf.extend_from_slice(&temp[..read]); - if let Some(pos) = header_buf.windows(4).position(|w| w == b"\r\n\r\n") { - header_end = pos + 4; - break; - } - } - let headers = String::from_utf8_lossy(&header_buf[..header_end]); - assert!(headers.starts_with("POST /v1/api_key HTTP/1.1")); - let content_length = headers - .split("\r\n") - .find_map(|line| { - let (name, value) = line.split_once(':')?; - name.eq_ignore_ascii_case("content-length") - .then(|| value.trim().parse::().expect("content length")) - }) - .expect("content-length header"); - let mut body = header_buf[header_end..].to_vec(); - while body.len() < content_length { - let read = stream.read(&mut temp).expect("read request body"); - assert!(read > 0, "request closed before body completed"); - body.extend_from_slice(&temp[..read]); - } - let json: serde_json::Value = - serde_json::from_slice(&body[..content_length]).expect("parse request body"); - assert_eq!(json.get("org_id").and_then(|v| v.as_str()), Some("org_123")); - assert!(json.get("org_name").is_none()); - assert!(json.get("name").and_then(|v| v.as_str()).is_some()); - - let response = concat!( - "HTTP/1.1 200 OK\r\n", - "Content-Type: application/json\r\n", - "Content-Length: 17\r\n", - "Connection: close\r\n", - "\r\n", - "{\"key\":\"new-key\"}" - ); - stream - .write_all(response.as_bytes()) - .expect("write create response"); - stream.flush().expect("flush create response"); - }); + let state = Arc::new(ApiKeyTestState::default()); + let listener = TcpListener::bind(("127.0.0.1", 0)).expect("bind mock server"); + let addr = listener.local_addr().expect("mock server addr"); + let api_url = format!("http://{addr}"); + let data = web::Data::new(state.clone()); + + let server = HttpServer::new(move || { + App::new() + .app_data(data.clone()) + .route( + "/v1/api_key", + web::get().to(|| async { + HttpResponse::Ok().json(serde_json::json!({ "objects": [] })) + }), + ) + .route( + "/v1/api_key", + web::post().to( + |state: web::Data>, + body: web::Json| async move { + *state + .create_request_body + .lock() + .expect("lock create request body") = Some(body.into_inner()); + HttpResponse::Ok().json(serde_json::json!({ "key": "new-key" })) + }, + ), + ) + }) + .workers(1) + .listen(listener) + .expect("listen mock server") + .run(); + let handle = server.handle(); + tokio::spawn(server); let mut base = make_base_args(); base.quiet = true; base.quiet_source = Some(ArgValueSource::CommandLine); - let api_url = format!("http://{addr}"); let ctx = make_login_context( api_url, "https://app.example.test".to_string(), @@ -5092,8 +5061,20 @@ mod tests { .await .expect("create api key"); assert_eq!(key, "new-key"); + let request_body = state + .create_request_body + .lock() + .expect("lock create request body") + .clone() + .expect("captured create request body"); + assert_eq!( + request_body.get("org_id").and_then(|v| v.as_str()), + Some("org_123") + ); + assert!(request_body.get("org_name").is_none()); + assert!(request_body.get("name").and_then(|v| v.as_str()).is_some()); - server.join().expect("server join"); + handle.stop(true).await; } #[tokio::test]