diff --git a/templates/rust/examples/basic_usage.rs.twig b/templates/rust/examples/basic_usage.rs.twig index 16b6cff627..6afb9b9e77 100644 --- a/templates/rust/examples/basic_usage.rs.twig +++ b/templates/rust/examples/basic_usage.rs.twig @@ -11,9 +11,9 @@ use {{ sdk.cratePackage | default('appwrite') | rustCrateName }}::{ async fn main() -> Result<(), Box> { // Initialize the client let client = Client::new() - .set_endpoint("{{ spec.endpoint }}") // Your API Endpoint - .set_project("5df5acd0d48c2") // Your project ID - .set_key("919c2d18fb5d4...a2ae413da83346ad2"); // Your secret API key + .set_endpoint("{{ spec.endpoint }}")? // Your API Endpoint + .set_project("5df5acd0d48c2")? // Your project ID + .set_key("919c2d18fb5d4...a2ae413da83346ad2")?; // Your secret API key println!("🚀 {{ spec.title }} Rust SDK Example"); println!("Connected to: {}", client.endpoint()); @@ -129,9 +129,9 @@ async fn main() -> Result<(), Box> { // This will likely fail with invalid credentials let invalid_client = Client::new() - .set_endpoint("{{ spec.endpoint }}") - .set_project("invalid-project") - .set_key("invalid-key"); + .set_endpoint("{{ spec.endpoint }}")? + .set_project("invalid-project")? + .set_key("invalid-key")?; let invalid_users = Users::new(&invalid_client); diff --git a/templates/rust/src/client.rs.twig b/templates/rust/src/client.rs.twig index 0c2c68016d..4fb7f8d439 100644 --- a/templates/rust/src/client.rs.twig +++ b/templates/rust/src/client.rs.twig @@ -88,9 +88,12 @@ impl Client { /// Create a new {{ spec.title }} client pub fn new() -> Self { let mut headers = HeaderMap::new(); + // SAFETY: Spec-defined header values are compile-time constants and always valid ASCII. {% for key, header in spec.global.defaultHeaders %} headers.insert("{{ key }}", "{{ header }}".parse().unwrap()); {% endfor %} + // SAFETY: SDK metadata values are also compile-time constants; OS/ARCH are + // guaranteed-valid ASCII from std::env::consts. headers.insert("user-agent", format!("{{ spec.title }}RustSDK/{{ sdk.version }} ({}; {})", std::env::consts::OS, std::env::consts::ARCH).parse().unwrap()); headers.insert("x-sdk-name", "{{ sdk.name }}".parse().unwrap()); headers.insert("x-sdk-platform", "{{ sdk.platform }}".parse().unwrap()); @@ -138,72 +141,93 @@ impl Client { } /// Set the API endpoint - pub fn set_endpoint>(&self, endpoint: S) -> Self { + pub fn set_endpoint>(&self, endpoint: S) -> Result { let endpoint = endpoint.into(); - if !endpoint.starts_with("http://") && !endpoint.starts_with("https://") { - panic!("Invalid endpoint URL: {}. Endpoint must start with http:// or https://", endpoint); + let parsed = url::Url::parse(&endpoint).map_err(|e| {{ spec.title | caseUcfirst }}Error::new( + 400, + format!("Invalid endpoint URL: {}", e), + None, + String::new(), + ))?; + if parsed.scheme() != "http" && parsed.scheme() != "https" { + return Err({{ spec.title | caseUcfirst }}Error::new( + 400, + format!("Invalid endpoint URL scheme '{}': must be http or https", parsed.scheme()), + None, + String::new(), + )); } self.state.rcu(|state| { let mut next = (**state).clone(); next.config.endpoint = endpoint.clone(); Arc::new(next) }); - self.clone() + Ok(self.clone()) } /// Set the project ID - pub fn set_project>(&self, project: S) -> Self { + pub fn set_project>(&self, project: S) -> Result { let project = project.into(); + let value: reqwest::header::HeaderValue = project.parse() + .map_err(|e| {{ spec.title | caseUcfirst }}Error::new(400, format!("Invalid header value for x-appwrite-project: {}", e), None, String::new()))?; self.state.rcu(|state| { let mut next = (**state).clone(); - next.config.headers.insert("x-appwrite-project", project.clone().parse().unwrap()); + next.config.headers.insert("x-appwrite-project", value.clone()); Arc::new(next) }); - self.clone() + Ok(self.clone()) } /// Set the API key - pub fn set_key>(&self, key: S) -> Self { + pub fn set_key>(&self, key: S) -> Result { let key = key.into(); + let value: reqwest::header::HeaderValue = key.parse() + .map_err(|e| {{ spec.title | caseUcfirst }}Error::new(400, format!("Invalid header value for x-appwrite-key: {}", e), None, String::new()))?; self.state.rcu(|state| { let mut next = (**state).clone(); - next.config.headers.insert("x-appwrite-key", key.clone().parse().unwrap()); + next.config.headers.insert("x-appwrite-key", value.clone()); Arc::new(next) }); - self.clone() + Ok(self.clone()) } /// Set the JWT token - pub fn set_jwt>(&self, jwt: S) -> Self { + pub fn set_jwt>(&self, jwt: S) -> Result { let jwt = jwt.into(); + let value: reqwest::header::HeaderValue = jwt.parse() + .map_err(|e| {{ spec.title | caseUcfirst }}Error::new(400, format!("Invalid header value for x-appwrite-jwt: {}", e), None, String::new()))?; self.state.rcu(|state| { let mut next = (**state).clone(); - next.config.headers.insert("x-appwrite-jwt", jwt.clone().parse().unwrap()); + next.config.headers.insert("x-appwrite-jwt", value.clone()); Arc::new(next) }); - self.clone() + Ok(self.clone()) } /// Set the locale - pub fn set_locale>(&self, locale: S) -> Self { + pub fn set_locale>(&self, locale: S) -> Result { let locale = locale.into(); + let value: reqwest::header::HeaderValue = locale.parse() + .map_err(|e| {{ spec.title | caseUcfirst }}Error::new(400, format!("Invalid header value for x-appwrite-locale: {}", e), None, String::new()))?; self.state.rcu(|state| { let mut next = (**state).clone(); - next.config.headers.insert("x-appwrite-locale", locale.clone().parse().unwrap()); + next.config.headers.insert("x-appwrite-locale", value.clone()); Arc::new(next) }); - self.clone() + Ok(self.clone()) } /// Set the session - pub fn set_session>(&self, session: S) -> Self { + pub fn set_session>(&self, session: S) -> Result { let session = session.into(); + let value: reqwest::header::HeaderValue = session.parse() + .map_err(|e| {{ spec.title | caseUcfirst }}Error::new(400, format!("Invalid header value for x-appwrite-session: {}", e), None, String::new()))?; self.state.rcu(|state| { let mut next = (**state).clone(); - next.config.headers.insert("x-appwrite-session", session.clone().parse().unwrap()); + next.config.headers.insert("x-appwrite-session", value.clone()); Arc::new(next) }); - self.clone() + Ok(self.clone()) } /// Enable or disable self-signed certificates @@ -245,23 +269,21 @@ impl Client { } /// Add a custom header - pub fn add_header, V: AsRef>(&self, key: K, value: V) -> Self { + pub fn add_header, V: AsRef>(&self, key: K, value: V) -> Result { use reqwest::header::{HeaderName, HeaderValue}; let key = key.as_ref().to_string(); let value = value.as_ref().to_string(); + let header_name: HeaderName = key.parse()?; + let header_value: HeaderValue = value.parse()?; + self.state.rcu(|state| { let mut next = (**state).clone(); - if let (Ok(header_name), Ok(header_value)) = ( - key.parse::(), - value.parse::(), - ) { - next.config.headers.insert(header_name, header_value); - } + next.config.headers.insert(header_name.clone(), header_value.clone()); Arc::new(next) }); - self.clone() + Ok(self.clone()) } /// Get a copy of the current request headers @@ -915,10 +937,24 @@ mod tests { #[test] fn test_client_builder_pattern() { let client = Client::new() - .set_endpoint("https://custom.example.com/v1") - .set_project("test-project") - .set_key("test-key"); + .set_endpoint("https://custom.example.com/v1").unwrap() + .set_project("test-project").unwrap() + .set_key("test-key").unwrap(); assert_eq!(client.endpoint(), "https://custom.example.com/v1"); } + + #[test] + fn test_invalid_endpoint() { + let client = Client::new(); + let err = client.set_endpoint("htp://cloud.appwrite.io/v1").unwrap_err(); + assert_eq!(err.code, 400); + assert!(err.message.contains("Invalid endpoint URL")); + } + + #[test] + fn test_invalid_header_value() { + let client = Client::new(); + assert!(client.set_key("my\nkey").is_err()); + } } diff --git a/templates/rust/src/error.rs.twig b/templates/rust/src/error.rs.twig index 26ee0be65f..4e5c37f77e 100644 --- a/templates/rust/src/error.rs.twig +++ b/templates/rust/src/error.rs.twig @@ -69,6 +69,18 @@ impl From for {{ spec.title | caseUcfirst }}Error { } } +impl From for {{ spec.title | caseUcfirst }}Error { + fn from(err: reqwest::header::InvalidHeaderValue) -> Self { + Self::new(400, format!("Invalid header value: {}", err), None, String::new()) + } +} + +impl From for {{ spec.title | caseUcfirst }}Error { + fn from(err: reqwest::header::InvalidHeaderName) -> Self { + Self::new(400, format!("Invalid header name: {}", err), None, String::new()) + } +} + /// {{ spec.title }} specific error response structure #[derive(Debug, serde::Deserialize)] pub struct ErrorResponse { diff --git a/templates/rust/src/lib.rs.twig b/templates/rust/src/lib.rs.twig index 3616be4a57..cb443cba1c 100644 --- a/templates/rust/src/lib.rs.twig +++ b/templates/rust/src/lib.rs.twig @@ -19,9 +19,9 @@ //! #[tokio::main] //! async fn main() -> Result<(), Box> { //! let client = Client::new() -//! .set_endpoint("{{ spec.endpoint }}") -//! .set_project("your-project-id") -//! .set_key("your-api-key"); +//! .set_endpoint("{{ spec.endpoint }}")? +//! .set_project("your-project-id")? +//! .set_key("your-api-key")?; //! //! // Use the client to make API calls //! Ok(()) diff --git a/templates/rust/tests/tests.rs b/templates/rust/tests/tests.rs index a10a0e1fad..11317ebb1c 100644 --- a/templates/rust/tests/tests.rs +++ b/templates/rust/tests/tests.rs @@ -8,10 +8,10 @@ async fn main() -> Result<(), Box> { let string_in_array = vec!["string in array".to_string()]; let client = Client::new() - .set_endpoint("http://mockapi/v1") - .set_project("appwrite") - .set_key("apikey") - .add_header("Origin", "http://localhost"); + .set_endpoint("http://mockapi/v1")? + .set_project("appwrite")? + .set_key("apikey")? + .add_header("Origin", "http://localhost")?; println!("\n\nTest Started"); let sdk_headers = client.get_headers(); @@ -134,7 +134,16 @@ async fn test_general_service(client: &Client, string_in_array: &[String]) -> Re }, } - println!("Invalid endpoint URL: htp://cloud.appwrite.io/v1"); + match Client::new().set_endpoint("htp://cloud.appwrite.io/v1") { + Ok(_) => { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Expected validation failure for invalid endpoint but got Ok", + ) + .into()) + }, + Err(e) => println!("{}", e.message), + } let _ = general.empty().await;