|
| 1 | +//! Gateway error types. |
| 2 | +//! |
| 3 | +//! `GatewayError` is the unified error type for the gateway SDK layer |
| 4 | +//! (Layer 1-3). It covers validation errors, format bridging errors, |
| 5 | +//! provider HTTP errors, and stream errors. Each variant carries enough |
| 6 | +//! context for the proxy layer to produce an appropriate HTTP response. |
| 7 | +
|
| 8 | +use http::StatusCode; |
| 9 | +use serde_json::Value; |
| 10 | + |
| 11 | +/// Unified error type for the gateway SDK. |
| 12 | +#[derive(Debug, thiserror::Error)] |
| 13 | +pub enum GatewayError { |
| 14 | + // ── Client errors (not retryable) ── |
| 15 | + /// Request validation failed (e.g., missing required field). |
| 16 | + #[error("validation: {0}")] |
| 17 | + Validation(String), |
| 18 | + |
| 19 | + /// Format bridging failed (e.g., cannot map an Anthropic field to hub format). |
| 20 | + #[error("format bridge: {0}")] |
| 21 | + Bridge(String), |
| 22 | + |
| 23 | + /// Data transformation failed (e.g., JSON deserialization of provider response). |
| 24 | + #[error("data transform: {0}")] |
| 25 | + Transform(String), |
| 26 | + |
| 27 | + /// The requested format is not natively supported by the provider. |
| 28 | + #[error("format not natively supported by provider {provider}")] |
| 29 | + NativeNotSupported { provider: String }, |
| 30 | + |
| 31 | + // ── Provider errors (may be retryable) ── |
| 32 | + /// The upstream provider returned an error response. |
| 33 | + #[error("provider {provider} returned {status}: {body}")] |
| 34 | + Provider { |
| 35 | + status: StatusCode, |
| 36 | + body: Value, |
| 37 | + provider: String, |
| 38 | + retryable: bool, |
| 39 | + }, |
| 40 | + |
| 41 | + // ── Infrastructure errors (usually retryable) ── |
| 42 | + /// HTTP transport error (connection, timeout, etc.). |
| 43 | + #[error("HTTP: {0}")] |
| 44 | + Http(#[source] reqwest::Error), |
| 45 | + |
| 46 | + /// Error during stream processing. |
| 47 | + #[error("stream: {0}")] |
| 48 | + Stream(String), |
| 49 | +} |
| 50 | + |
| 51 | +impl GatewayError { |
| 52 | + /// Whether this error is safe to retry. |
| 53 | + pub fn is_retryable(&self) -> bool { |
| 54 | + match self { |
| 55 | + Self::Provider { retryable, .. } => *retryable, |
| 56 | + Self::Http(e) => e.is_timeout() || e.is_connect(), |
| 57 | + Self::Stream(_) => true, |
| 58 | + _ => false, |
| 59 | + } |
| 60 | + } |
| 61 | + |
| 62 | + /// Map to an HTTP status code for proxy-layer responses. |
| 63 | + pub fn status_code(&self) -> StatusCode { |
| 64 | + match self { |
| 65 | + Self::Validation(_) | Self::Bridge(_) => StatusCode::BAD_REQUEST, |
| 66 | + Self::Transform(_) => StatusCode::UNPROCESSABLE_ENTITY, |
| 67 | + Self::Provider { status, .. } => *status, |
| 68 | + Self::Http(_) | Self::Stream(_) => StatusCode::BAD_GATEWAY, |
| 69 | + Self::NativeNotSupported { .. } => StatusCode::NOT_IMPLEMENTED, |
| 70 | + } |
| 71 | + } |
| 72 | +} |
| 73 | + |
| 74 | +/// Convenience alias for gateway results. |
| 75 | +pub type Result<T> = std::result::Result<T, GatewayError>; |
| 76 | + |
| 77 | +#[cfg(test)] |
| 78 | +mod tests { |
| 79 | + use serde_json::json; |
| 80 | + |
| 81 | + use super::*; |
| 82 | + |
| 83 | + #[test] |
| 84 | + fn validation_not_retryable() { |
| 85 | + let e = GatewayError::Validation("missing field".into()); |
| 86 | + assert!(!e.is_retryable()); |
| 87 | + assert_eq!(e.status_code(), StatusCode::BAD_REQUEST); |
| 88 | + } |
| 89 | + |
| 90 | + #[test] |
| 91 | + fn bridge_not_retryable() { |
| 92 | + let e = GatewayError::Bridge("cannot map field X".into()); |
| 93 | + assert!(!e.is_retryable()); |
| 94 | + assert_eq!(e.status_code(), StatusCode::BAD_REQUEST); |
| 95 | + } |
| 96 | + |
| 97 | + #[test] |
| 98 | + fn transform_not_retryable() { |
| 99 | + let e = GatewayError::Transform("bad json".into()); |
| 100 | + assert!(!e.is_retryable()); |
| 101 | + assert_eq!(e.status_code(), StatusCode::UNPROCESSABLE_ENTITY); |
| 102 | + } |
| 103 | + |
| 104 | + #[test] |
| 105 | + fn native_not_supported() { |
| 106 | + let e = GatewayError::NativeNotSupported { |
| 107 | + provider: "gemini".into(), |
| 108 | + }; |
| 109 | + assert!(!e.is_retryable()); |
| 110 | + assert_eq!(e.status_code(), StatusCode::NOT_IMPLEMENTED); |
| 111 | + assert!(e.to_string().contains("gemini")); |
| 112 | + } |
| 113 | + |
| 114 | + #[test] |
| 115 | + fn provider_retryable_when_flagged() { |
| 116 | + let e = GatewayError::Provider { |
| 117 | + status: StatusCode::TOO_MANY_REQUESTS, |
| 118 | + body: json!({"error": "rate limited"}), |
| 119 | + provider: "openai".into(), |
| 120 | + retryable: true, |
| 121 | + }; |
| 122 | + assert!(e.is_retryable()); |
| 123 | + assert_eq!(e.status_code(), StatusCode::TOO_MANY_REQUESTS); |
| 124 | + } |
| 125 | + |
| 126 | + #[test] |
| 127 | + fn provider_not_retryable_when_not_flagged() { |
| 128 | + let e = GatewayError::Provider { |
| 129 | + status: StatusCode::BAD_REQUEST, |
| 130 | + body: json!({"error": "bad request"}), |
| 131 | + provider: "anthropic".into(), |
| 132 | + retryable: false, |
| 133 | + }; |
| 134 | + assert!(!e.is_retryable()); |
| 135 | + assert_eq!(e.status_code(), StatusCode::BAD_REQUEST); |
| 136 | + } |
| 137 | + |
| 138 | + #[test] |
| 139 | + fn stream_error_retryable() { |
| 140 | + let e = GatewayError::Stream("connection reset".into()); |
| 141 | + assert!(e.is_retryable()); |
| 142 | + assert_eq!(e.status_code(), StatusCode::BAD_GATEWAY); |
| 143 | + } |
| 144 | + |
| 145 | + #[test] |
| 146 | + fn display_messages() { |
| 147 | + assert_eq!( |
| 148 | + GatewayError::Validation("x".into()).to_string(), |
| 149 | + "validation: x" |
| 150 | + ); |
| 151 | + assert_eq!( |
| 152 | + GatewayError::Bridge("y".into()).to_string(), |
| 153 | + "format bridge: y" |
| 154 | + ); |
| 155 | + let provider_err = GatewayError::Provider { |
| 156 | + status: StatusCode::INTERNAL_SERVER_ERROR, |
| 157 | + body: json!("err"), |
| 158 | + provider: "openai".into(), |
| 159 | + retryable: false, |
| 160 | + }; |
| 161 | + assert!(provider_err.to_string().contains("openai")); |
| 162 | + assert!(provider_err.to_string().contains("500")); |
| 163 | + } |
| 164 | +} |
0 commit comments