|
| 1 | +//! DM-7 — JWT extraction middleware + `ActorContext` population. |
| 2 | +//! |
| 3 | +//! **Phase 1 (this file):** Extract and decode JWT payload (base64), |
| 4 | +//! populate `ActorContext`. No signature verification — that requires |
| 5 | +//! a JWK endpoint or static key, which is deployment-specific. |
| 6 | +//! |
| 7 | +//! **Phase 2 (future):** Plug in real verification via a `JwkSetProvider` |
| 8 | +//! trait. The `JwtMiddleware::extract_actor` API won't change — only the |
| 9 | +//! internal verification step gets wired. |
| 10 | +//! |
| 11 | +//! # JWT Payload Shape (expected claims) |
| 12 | +//! |
| 13 | +//! ```json |
| 14 | +//! { |
| 15 | +//! "sub": "user@example.com", |
| 16 | +//! "tenant_id": 42, |
| 17 | +//! "roles": ["viewer", "editor"] |
| 18 | +//! } |
| 19 | +//! ``` |
| 20 | +//! |
| 21 | +//! - `sub` (required) — maps to `ActorContext.actor_id`. |
| 22 | +//! - `tenant_id` (optional in Phase 1) — maps to `ActorContext.tenant_id` (`TenantId = u64`). Defaults to `0` when absent. |
| 23 | +//! - `roles` (optional) — maps to `ActorContext.roles`. Defaults to `[]`. |
| 24 | +//! |
| 25 | +//! # Zero New Dependencies |
| 26 | +//! |
| 27 | +//! Uses `serde` + `serde_json` (already gated under `[auth]` feature). |
| 28 | +//! Base64 URL-safe decoding is implemented inline (~40 lines) — no |
| 29 | +//! `base64` crate, no `jsonwebtoken` crate. |
| 30 | +//! |
| 31 | +//! Plan: `.claude/plans/callcenter-membrane-v1.md` § DM-7 |
| 32 | +
|
| 33 | +use lance_graph_contract::auth::{ActorContext, AuthError}; |
| 34 | +use serde::Deserialize; |
| 35 | + |
/// JWT extraction middleware.
///
/// A stateless namespace for the extraction entry points
/// (`extract_actor`, `extract_from_header`).
///
/// Phase 1: base64-decode the payload section of a JWT and extract
/// `sub`, `tenant_id`, and `roles` into an `ActorContext`.
///
/// No signature verification in Phase 1 — the token is trusted as-is.
/// Phase 2 will add a `JwkSetProvider` trait for real verification.
#[derive(Debug, Clone, Copy, Default)]
pub struct JwtMiddleware;
| 44 | + |
| 45 | +/// Deserialization target for the JWT payload claims we care about. |
| 46 | +#[derive(Deserialize)] |
| 47 | +struct JwtClaims { |
| 48 | + /// JWT `sub` claim — canonical actor identity. |
| 49 | + sub: Option<String>, |
| 50 | + /// Custom claim: tenant identifier. |
| 51 | + tenant_id: Option<u64>, |
| 52 | + /// Custom claim: actor roles. Optional; defaults to empty. |
| 53 | + #[serde(default)] |
| 54 | + roles: Vec<String>, |
| 55 | +} |
| 56 | + |
| 57 | +impl JwtMiddleware { |
| 58 | + /// Extract `ActorContext` from a raw JWT token string. |
| 59 | + /// |
| 60 | + /// The token should be in the standard `header.payload.signature` |
| 61 | + /// format. Only the payload section is decoded and parsed. |
| 62 | + /// |
| 63 | + /// # Phase 1 Limitations |
| 64 | + /// |
| 65 | + /// - **No signature verification.** The signature section is ignored. |
| 66 | + /// Deploy behind a reverse proxy or API gateway that validates |
| 67 | + /// signatures before traffic reaches this layer. |
| 68 | + /// - **No expiry checking.** `exp` / `nbf` / `iat` are ignored. |
| 69 | + /// Phase 2 will enforce temporal validity. |
| 70 | + /// |
| 71 | + /// # Errors |
| 72 | + /// |
| 73 | + /// - `AuthError::MalformedToken` — token doesn't have 3 dot-separated parts. |
| 74 | + /// - `AuthError::InvalidBase64` — payload isn't valid base64url. |
| 75 | + /// - `AuthError::MissingSub` — payload JSON is missing the `sub` claim. |
| 76 | + /// - `AuthError::InvalidPayload` — payload JSON can't be parsed. |
| 77 | + pub fn extract_actor(token: &str) -> Result<ActorContext, AuthError> { |
| 78 | + // Split into header.payload.signature |
| 79 | + let parts: Vec<&str> = token.split('.').collect(); |
| 80 | + if parts.len() != 3 { |
| 81 | + return Err(AuthError::MalformedToken); |
| 82 | + } |
| 83 | + |
| 84 | + // Decode payload (middle part) |
| 85 | + let payload_bytes = base64url_decode(parts[1])?; |
| 86 | + |
| 87 | + // Parse JSON |
| 88 | + let claims: JwtClaims = serde_json::from_slice(&payload_bytes) |
| 89 | + .map_err(|e| AuthError::InvalidPayload(e.to_string()))?; |
| 90 | + |
| 91 | + // Extract required fields |
| 92 | + let actor_id = claims.sub.ok_or(AuthError::MissingSub)?; |
| 93 | + if actor_id.is_empty() { |
| 94 | + return Err(AuthError::MissingSub); |
| 95 | + } |
| 96 | + |
| 97 | + let tenant_id = claims.tenant_id.unwrap_or(0); |
| 98 | + |
| 99 | + Ok(ActorContext::new(actor_id, tenant_id, claims.roles)) |
| 100 | + } |
| 101 | + |
| 102 | + /// Extract `ActorContext` from an `Authorization: Bearer <token>` header value. |
| 103 | + /// |
| 104 | + /// Strips the `Bearer ` prefix if present, then delegates to `extract_actor`. |
| 105 | + pub fn extract_from_header(header_value: &str) -> Result<ActorContext, AuthError> { |
| 106 | + let token = header_value |
| 107 | + .strip_prefix("Bearer ") |
| 108 | + .or_else(|| header_value.strip_prefix("bearer ")) |
| 109 | + .unwrap_or(header_value); |
| 110 | + Self::extract_actor(token) |
| 111 | + } |
| 112 | +} |
| 113 | + |
| 114 | +// ═══════════════════════════════════════════════════════════════════════════ |
| 115 | +// MINIMAL BASE64URL DECODER |
| 116 | +// ═══════════════════════════════════════════════════════════════════════════ |
| 117 | + |
| 118 | +/// Decode a base64url-encoded string (RFC 4648 §5) without padding. |
| 119 | +/// |
| 120 | +/// JWT payloads use URL-safe base64 without padding characters (`=`). |
| 121 | +/// This decoder handles both padded and unpadded inputs. |
| 122 | +/// |
| 123 | +/// ~40 lines, no external crate. Handles the full base64url alphabet |
| 124 | +/// (A-Z, a-z, 0-9, `-`, `_`). |
| 125 | +fn base64url_decode(input: &str) -> Result<Vec<u8>, AuthError> { |
| 126 | + // Base64url alphabet → 6-bit value |
| 127 | + fn char_to_sextet(c: u8) -> Result<u8, AuthError> { |
| 128 | + match c { |
| 129 | + b'A'..=b'Z' => Ok(c - b'A'), |
| 130 | + b'a'..=b'z' => Ok(c - b'a' + 26), |
| 131 | + b'0'..=b'9' => Ok(c - b'0' + 52), |
| 132 | + b'-' => Ok(62), |
| 133 | + b'_' => Ok(63), |
| 134 | + b'=' => Ok(0), // padding — value ignored |
| 135 | + _ => Err(AuthError::InvalidBase64), |
| 136 | + } |
| 137 | + } |
| 138 | + |
| 139 | + // Strip padding for length calculation |
| 140 | + let stripped = input.trim_end_matches('='); |
| 141 | + let input_bytes = stripped.as_bytes(); |
| 142 | + let len = input_bytes.len(); |
| 143 | + |
| 144 | + if len == 0 { |
| 145 | + return Ok(Vec::new()); |
| 146 | + } |
| 147 | + |
| 148 | + // Validate: base64 produces 3 output bytes per 4 input chars. |
| 149 | + // Without padding: len%4 can be 0, 2, or 3 (never 1). |
| 150 | + if len % 4 == 1 { |
| 151 | + return Err(AuthError::InvalidBase64); |
| 152 | + } |
| 153 | + |
| 154 | + let out_len = len * 3 / 4; |
| 155 | + let mut out = Vec::with_capacity(out_len); |
| 156 | + |
| 157 | + // Process full 4-char groups |
| 158 | + let full_groups = len / 4; |
| 159 | + for i in 0..full_groups { |
| 160 | + let base = i * 4; |
| 161 | + let a = char_to_sextet(input_bytes[base])?; |
| 162 | + let b = char_to_sextet(input_bytes[base + 1])?; |
| 163 | + let c = char_to_sextet(input_bytes[base + 2])?; |
| 164 | + let d = char_to_sextet(input_bytes[base + 3])?; |
| 165 | + |
| 166 | + out.push((a << 2) | (b >> 4)); |
| 167 | + out.push((b << 4) | (c >> 2)); |
| 168 | + out.push((c << 6) | d); |
| 169 | + } |
| 170 | + |
| 171 | + // Handle remaining 2 or 3 chars |
| 172 | + let remainder = len % 4; |
| 173 | + if remainder >= 2 { |
| 174 | + let base = full_groups * 4; |
| 175 | + let a = char_to_sextet(input_bytes[base])?; |
| 176 | + let b = char_to_sextet(input_bytes[base + 1])?; |
| 177 | + out.push((a << 2) | (b >> 4)); |
| 178 | + |
| 179 | + if remainder == 3 { |
| 180 | + let c = char_to_sextet(input_bytes[base + 2])?; |
| 181 | + out.push((b << 4) | (c >> 2)); |
| 182 | + } |
| 183 | + } |
| 184 | + |
| 185 | + Ok(out) |
| 186 | +} |
| 187 | + |
/// Test helper: base64url-encode `input` with no `=` padding.
///
/// Each group of up to three bytes is packed into the top 24 bits of a
/// word and emitted as one more symbol than the group has bytes.
#[cfg(test)]
fn base64url_encode(input: &[u8]) -> String {
    const TABLE: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

    let mut encoded = String::with_capacity((input.len() + 2) / 3 * 4);

    for triple in input.chunks(3) {
        // Missing trailing bytes stay zero, matching zero-fill semantics.
        let mut word = 0u32;
        for (idx, &byte) in triple.iter().enumerate() {
            word |= u32::from(byte) << (16 - 8 * idx);
        }

        // n input bytes produce n + 1 output symbols.
        for pos in 0..=triple.len() {
            let sextet = (word >> (18 - 6 * pos)) & 0x3F;
            encoded.push(TABLE[sextet as usize] as char);
        }
    }

    encoded
}
| 214 | + |
/// Test helper: assemble an unsigned JWT around `payload_json`.
///
/// The result is `<header>.<payload>.` — a fixed `alg: none` header, the
/// encoded payload, and an empty signature section (Phase 1 does not
/// verify signatures).
#[cfg(test)]
fn make_test_jwt(payload_json: &str) -> String {
    let mut token = base64url_encode(br#"{"alg":"none","typ":"JWT"}"#);
    token.push('.');
    token.push_str(&base64url_encode(payload_json.as_bytes()));
    // Trailing dot leaves the signature segment empty.
    token.push('.');
    token
}
| 223 | + |
| 224 | +// ── Tests ───────────────────────────────────────────────────────────────────── |
| 225 | + |
#[cfg(test)]
mod tests {
    use super::*;

    /// Run `extract_actor` on a token built from `payload_json`.
    fn extract(payload_json: &str) -> Result<ActorContext, AuthError> {
        JwtMiddleware::extract_actor(&make_test_jwt(payload_json))
    }

    // ── Base64url decoder tests ──

    /// Encoding then decoding arbitrary bytes (including non-ASCII) is lossless.
    #[test]
    fn base64url_roundtrip() {
        let original = b"Hello, JWT world! \xF0\x9F\x94\x91";
        let roundtripped = base64url_decode(&base64url_encode(original));
        assert_eq!(roundtripped.unwrap(), original);
    }

    /// The empty string decodes to no bytes.
    #[test]
    fn base64url_empty() {
        let decoded = base64url_decode("").unwrap();
        assert_eq!(decoded, Vec::<u8>::new());
    }

    /// "Hello" decodes identically with and without trailing padding.
    #[test]
    fn base64url_padding_tolerance() {
        let expected = b"Hello";
        for encoded in ["SGVsbG8", "SGVsbG8="] {
            assert_eq!(base64url_decode(encoded).unwrap(), expected);
        }
    }

    /// Characters outside the alphabet are rejected.
    #[test]
    fn base64url_invalid_char() {
        let outcome = base64url_decode("!!!");
        assert_eq!(outcome, Err(AuthError::InvalidBase64));
    }

    /// A single dangling symbol (len % 4 == 1) is rejected.
    #[test]
    fn base64url_invalid_length() {
        let outcome = base64url_decode("A");
        assert_eq!(outcome, Err(AuthError::InvalidBase64));
    }

    // ── JWT extraction tests ──

    #[test]
    fn valid_jwt_full_claims() {
        let ctx = extract(
            r#"{"sub":"user@example.com","tenant_id":42,"roles":["admin","viewer"]}"#,
        )
        .unwrap();
        assert_eq!(ctx.actor_id, "user@example.com");
        assert_eq!(ctx.tenant_id, 42);
        assert_eq!(ctx.roles, vec!["admin", "viewer"]);
        assert!(ctx.is_admin());
    }

    #[test]
    fn valid_jwt_minimal_claims() {
        let ctx = extract(r#"{"sub":"bot-123","tenant_id":1}"#).unwrap();
        assert_eq!(ctx.actor_id, "bot-123");
        assert_eq!(ctx.tenant_id, 1);
        assert!(ctx.roles.is_empty());
        assert!(!ctx.is_admin());
    }

    #[test]
    fn valid_jwt_empty_roles() {
        let ctx = extract(r#"{"sub":"x","tenant_id":0,"roles":[]}"#).unwrap();
        assert!(ctx.roles.is_empty());
    }

    #[test]
    fn valid_jwt_missing_tenant_defaults_to_zero() {
        let ctx = extract(r#"{"sub":"x"}"#).unwrap();
        assert_eq!(ctx.tenant_id, 0);
    }

    #[test]
    fn missing_sub_error() {
        let outcome = extract(r#"{"tenant_id":1,"roles":["viewer"]}"#);
        assert_eq!(outcome, Err(AuthError::MissingSub));
    }

    #[test]
    fn empty_sub_error() {
        let outcome = extract(r#"{"sub":"","tenant_id":1}"#);
        assert_eq!(outcome, Err(AuthError::MissingSub));
    }

    #[test]
    fn malformed_token_no_dots() {
        let outcome = JwtMiddleware::extract_actor("not-a-jwt");
        assert_eq!(outcome, Err(AuthError::MalformedToken));
    }

    #[test]
    fn malformed_token_two_parts() {
        let outcome = JwtMiddleware::extract_actor("header.payload");
        assert_eq!(outcome, Err(AuthError::MalformedToken));
    }

    #[test]
    fn malformed_token_four_parts() {
        let outcome = JwtMiddleware::extract_actor("a.b.c.d");
        assert_eq!(outcome, Err(AuthError::MalformedToken));
    }

    #[test]
    fn invalid_base64_payload() {
        // Three segments, but the middle one is not base64url.
        let outcome = JwtMiddleware::extract_actor("header.!!!invalid.sig");
        assert!(matches!(outcome, Err(AuthError::InvalidBase64)));
    }

    #[test]
    fn invalid_json_payload() {
        let token = format!(
            "{}.{}.",
            base64url_encode(b"{}"),
            base64url_encode(b"not json at all {{{"),
        );
        let outcome = JwtMiddleware::extract_actor(&token);
        assert!(matches!(outcome, Err(AuthError::InvalidPayload(_))));
    }

    #[test]
    fn extract_from_bearer_header() {
        let header = format!(
            "Bearer {}",
            make_test_jwt(r#"{"sub":"user@test.com","tenant_id":7}"#),
        );
        let ctx = JwtMiddleware::extract_from_header(&header).unwrap();
        assert_eq!(ctx.actor_id, "user@test.com");
        assert_eq!(ctx.tenant_id, 7);
    }

    #[test]
    fn extract_from_header_lowercase_bearer() {
        let header = format!("bearer {}", make_test_jwt(r#"{"sub":"x","tenant_id":1}"#));
        let ctx = JwtMiddleware::extract_from_header(&header).unwrap();
        assert_eq!(ctx.actor_id, "x");
    }

    #[test]
    fn extract_from_header_no_prefix() {
        let bare_token = make_test_jwt(r#"{"sub":"x","tenant_id":1}"#);
        let ctx = JwtMiddleware::extract_from_header(&bare_token).unwrap();
        assert_eq!(ctx.actor_id, "x");
    }
}
0 commit comments