Skip to content

Commit c7310ec

Browse files
committed
feat(LF-3): JWT middleware + DataFusion RLS rewriter (auth module)
LF-3 / DM-7 — callcenter [auth] feature, Phase 1: ActorContext (contract/auth.rs): ActorContext { actor_id: String, tenant_id: TenantId, roles } AuthError enum for extraction failures. Zero-dep, in contract crate for cross-consumer use. JwtMiddleware (callcenter/auth.rs, feature = "auth"): extract_actor(token) — base64-decode JWT payload, parse JSON, extract sub/tenant_id/roles into ActorContext. Phase 1: no signature verification (deployment-specific). Minimal base64url decoder (~30 lines, no external dep). RlsRewriter (callcenter/rls.rs, feature = "query"): DataFusion OptimizerRule that injects tenant_id + actor_id predicates on TableScan nodes in the LogicalPlan. Admin role skips actor_id filter. Recursive plan tree walking. Scope boundaries per SMB REQUEST at bf7c05e: - IN: JWT → ActorContext → LogicalPlan RLS rewrite - OUT: connectors, sharding, per-property marking All tests pass. Workspace cargo check clean. https://claude.ai/code/session_01SbYsmmbPf9YQuYbHZN52Zh
1 parent 56f2695 commit c7310ec

6 files changed

Lines changed: 828 additions & 4 deletions

File tree

crates/lance-graph-callcenter/Cargo.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,8 @@ persist = ["dep:arrow", "dep:lance"]
3333
query = ["dep:datafusion", "dep:arrow"]
3434
realtime = ["dep:tokio", "dep:tokio-tungstenite", "dep:serde", "dep:serde_json"]
3535
serve = ["realtime", "query", "dep:axum", "dep:tower-http"]
36-
auth = ["dep:serde", "dep:serde_json"]
36+
auth = ["query", "dep:serde", "dep:serde_json"]
3737
full = ["persist", "query", "realtime", "serve", "auth"]
38+
39+
[dev-dependencies]
40+
tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
Lines changed: 387 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,387 @@
1+
//! DM-7 — JWT extraction middleware + `ActorContext` population.
2+
//!
3+
//! **Phase 1 (this file):** Extract and decode JWT payload (base64),
4+
//! populate `ActorContext`. No signature verification — that requires
5+
//! a JWK endpoint or static key, which is deployment-specific.
6+
//!
7+
//! **Phase 2 (future):** Plug in real verification via a `JwkSetProvider`
8+
//! trait. The `JwtMiddleware::extract_actor` API won't change — only the
9+
//! internal verification step gets wired.
10+
//!
11+
//! # JWT Payload Shape (expected claims)
12+
//!
13+
//! ```json
14+
//! {
15+
//! "sub": "user@example.com",
16+
//! "tenant_id": 42,
17+
//! "roles": ["viewer", "editor"]
18+
//! }
19+
//! ```
20+
//!
21+
//! - `sub` (required) — maps to `ActorContext.actor_id`.
22+
//! - `tenant_id` (required) — maps to `ActorContext.tenant_id` (`TenantId = u64`).
23+
//! - `roles` (optional) — maps to `ActorContext.roles`. Defaults to `[]`.
24+
//!
25+
//! # Zero New Dependencies
26+
//!
27+
//! Uses `serde` + `serde_json` (already gated under `[auth]` feature).
28+
//! Base64 URL-safe decoding is implemented inline (~40 lines) — no
29+
//! `base64` crate, no `jsonwebtoken` crate.
30+
//!
31+
//! Plan: `.claude/plans/callcenter-membrane-v1.md` § DM-7
32+
33+
use lance_graph_contract::auth::{ActorContext, AuthError};
34+
use serde::Deserialize;
35+
36+
/// JWT extraction middleware.
37+
///
38+
/// Phase 1: base64-decode the payload section of a JWT and extract
39+
/// `sub`, `tenant_id`, and `roles` into an `ActorContext`.
40+
///
41+
/// No signature verification in Phase 1 — the token is trusted as-is.
42+
/// Phase 2 will add a `JwkSetProvider` trait for real verification.
43+
pub struct JwtMiddleware;
44+
45+
/// Deserialization target for the JWT payload claims we care about.
46+
#[derive(Deserialize)]
47+
struct JwtClaims {
48+
/// JWT `sub` claim — canonical actor identity.
49+
sub: Option<String>,
50+
/// Custom claim: tenant identifier.
51+
tenant_id: Option<u64>,
52+
/// Custom claim: actor roles. Optional; defaults to empty.
53+
#[serde(default)]
54+
roles: Vec<String>,
55+
}
56+
57+
impl JwtMiddleware {
58+
/// Extract `ActorContext` from a raw JWT token string.
59+
///
60+
/// The token should be in the standard `header.payload.signature`
61+
/// format. Only the payload section is decoded and parsed.
62+
///
63+
/// # Phase 1 Limitations
64+
///
65+
/// - **No signature verification.** The signature section is ignored.
66+
/// Deploy behind a reverse proxy or API gateway that validates
67+
/// signatures before traffic reaches this layer.
68+
/// - **No expiry checking.** `exp` / `nbf` / `iat` are ignored.
69+
/// Phase 2 will enforce temporal validity.
70+
///
71+
/// # Errors
72+
///
73+
/// - `AuthError::MalformedToken` — token doesn't have 3 dot-separated parts.
74+
/// - `AuthError::InvalidBase64` — payload isn't valid base64url.
75+
/// - `AuthError::MissingSub` — payload JSON is missing the `sub` claim.
76+
/// - `AuthError::InvalidPayload` — payload JSON can't be parsed.
77+
pub fn extract_actor(token: &str) -> Result<ActorContext, AuthError> {
78+
// Split into header.payload.signature
79+
let parts: Vec<&str> = token.split('.').collect();
80+
if parts.len() != 3 {
81+
return Err(AuthError::MalformedToken);
82+
}
83+
84+
// Decode payload (middle part)
85+
let payload_bytes = base64url_decode(parts[1])?;
86+
87+
// Parse JSON
88+
let claims: JwtClaims = serde_json::from_slice(&payload_bytes)
89+
.map_err(|e| AuthError::InvalidPayload(e.to_string()))?;
90+
91+
// Extract required fields
92+
let actor_id = claims.sub.ok_or(AuthError::MissingSub)?;
93+
if actor_id.is_empty() {
94+
return Err(AuthError::MissingSub);
95+
}
96+
97+
let tenant_id = claims.tenant_id.unwrap_or(0);
98+
99+
Ok(ActorContext::new(actor_id, tenant_id, claims.roles))
100+
}
101+
102+
/// Extract `ActorContext` from an `Authorization: Bearer <token>` header value.
103+
///
104+
/// Strips the `Bearer ` prefix if present, then delegates to `extract_actor`.
105+
pub fn extract_from_header(header_value: &str) -> Result<ActorContext, AuthError> {
106+
let token = header_value
107+
.strip_prefix("Bearer ")
108+
.or_else(|| header_value.strip_prefix("bearer "))
109+
.unwrap_or(header_value);
110+
Self::extract_actor(token)
111+
}
112+
}
113+
114+
// ═══════════════════════════════════════════════════════════════════════════
115+
// MINIMAL BASE64URL DECODER
116+
// ═══════════════════════════════════════════════════════════════════════════
117+
118+
/// Decode a base64url-encoded string (RFC 4648 §5) without padding.
119+
///
120+
/// JWT payloads use URL-safe base64 without padding characters (`=`).
121+
/// This decoder handles both padded and unpadded inputs.
122+
///
123+
/// ~40 lines, no external crate. Handles the full base64url alphabet
124+
/// (A-Z, a-z, 0-9, `-`, `_`).
125+
fn base64url_decode(input: &str) -> Result<Vec<u8>, AuthError> {
126+
// Base64url alphabet → 6-bit value
127+
fn char_to_sextet(c: u8) -> Result<u8, AuthError> {
128+
match c {
129+
b'A'..=b'Z' => Ok(c - b'A'),
130+
b'a'..=b'z' => Ok(c - b'a' + 26),
131+
b'0'..=b'9' => Ok(c - b'0' + 52),
132+
b'-' => Ok(62),
133+
b'_' => Ok(63),
134+
b'=' => Ok(0), // padding — value ignored
135+
_ => Err(AuthError::InvalidBase64),
136+
}
137+
}
138+
139+
// Strip padding for length calculation
140+
let stripped = input.trim_end_matches('=');
141+
let input_bytes = stripped.as_bytes();
142+
let len = input_bytes.len();
143+
144+
if len == 0 {
145+
return Ok(Vec::new());
146+
}
147+
148+
// Validate: base64 produces 3 output bytes per 4 input chars.
149+
// Without padding: len%4 can be 0, 2, or 3 (never 1).
150+
if len % 4 == 1 {
151+
return Err(AuthError::InvalidBase64);
152+
}
153+
154+
let out_len = len * 3 / 4;
155+
let mut out = Vec::with_capacity(out_len);
156+
157+
// Process full 4-char groups
158+
let full_groups = len / 4;
159+
for i in 0..full_groups {
160+
let base = i * 4;
161+
let a = char_to_sextet(input_bytes[base])?;
162+
let b = char_to_sextet(input_bytes[base + 1])?;
163+
let c = char_to_sextet(input_bytes[base + 2])?;
164+
let d = char_to_sextet(input_bytes[base + 3])?;
165+
166+
out.push((a << 2) | (b >> 4));
167+
out.push((b << 4) | (c >> 2));
168+
out.push((c << 6) | d);
169+
}
170+
171+
// Handle remaining 2 or 3 chars
172+
let remainder = len % 4;
173+
if remainder >= 2 {
174+
let base = full_groups * 4;
175+
let a = char_to_sextet(input_bytes[base])?;
176+
let b = char_to_sextet(input_bytes[base + 1])?;
177+
out.push((a << 2) | (b >> 4));
178+
179+
if remainder == 3 {
180+
let c = char_to_sextet(input_bytes[base + 2])?;
181+
out.push((b << 4) | (c >> 2));
182+
}
183+
}
184+
185+
Ok(out)
186+
}
187+
188+
/// Encode bytes as base64url without padding (for test helpers).
189+
#[cfg(test)]
190+
fn base64url_encode(input: &[u8]) -> String {
191+
const ALPHABET: &[u8; 64] =
192+
b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
193+
194+
let mut out = String::with_capacity((input.len() + 2) / 3 * 4);
195+
196+
for chunk in input.chunks(3) {
197+
let b0 = chunk[0] as usize;
198+
let b1 = chunk.get(1).copied().unwrap_or(0) as usize;
199+
let b2 = chunk.get(2).copied().unwrap_or(0) as usize;
200+
201+
out.push(ALPHABET[(b0 >> 2)] as char);
202+
out.push(ALPHABET[((b0 & 0x03) << 4) | (b1 >> 4)] as char);
203+
204+
if chunk.len() > 1 {
205+
out.push(ALPHABET[((b1 & 0x0F) << 2) | (b2 >> 6)] as char);
206+
}
207+
if chunk.len() > 2 {
208+
out.push(ALPHABET[(b2 & 0x3F)] as char);
209+
}
210+
}
211+
212+
out
213+
}
214+
215+
/// Build a minimal unsigned JWT from a JSON payload string (for tests).
216+
#[cfg(test)]
217+
fn make_test_jwt(payload_json: &str) -> String {
218+
let header = base64url_encode(b"{\"alg\":\"none\",\"typ\":\"JWT\"}");
219+
let payload = base64url_encode(payload_json.as_bytes());
220+
// No signature (Phase 1 doesn't verify)
221+
format!("{header}.{payload}.")
222+
}
223+
224+
// ── Tests ─────────────────────────────────────────────────────────────────────
225+
226+
#[cfg(test)]
227+
mod tests {
228+
use super::*;
229+
230+
// ── Base64url decoder tests ──
231+
232+
#[test]
233+
fn base64url_roundtrip() {
234+
let original = b"Hello, JWT world! \xF0\x9F\x94\x91";
235+
let encoded = base64url_encode(original);
236+
let decoded = base64url_decode(&encoded).unwrap();
237+
assert_eq!(decoded, original);
238+
}
239+
240+
#[test]
241+
fn base64url_empty() {
242+
assert_eq!(base64url_decode("").unwrap(), Vec::<u8>::new());
243+
}
244+
245+
#[test]
246+
fn base64url_padding_tolerance() {
247+
// "Hello" base64url = "SGVsbG8" (no padding) or "SGVsbG8=" (with padding)
248+
let expected = b"Hello";
249+
assert_eq!(base64url_decode("SGVsbG8").unwrap(), expected);
250+
assert_eq!(base64url_decode("SGVsbG8=").unwrap(), expected);
251+
}
252+
253+
#[test]
254+
fn base64url_invalid_char() {
255+
assert_eq!(base64url_decode("!!!"), Err(AuthError::InvalidBase64));
256+
}
257+
258+
#[test]
259+
fn base64url_invalid_length() {
260+
// len%4 == 1 is invalid
261+
assert_eq!(base64url_decode("A"), Err(AuthError::InvalidBase64));
262+
}
263+
264+
// ── JWT extraction tests ──
265+
266+
#[test]
267+
fn valid_jwt_full_claims() {
268+
let jwt = make_test_jwt(
269+
r#"{"sub":"user@example.com","tenant_id":42,"roles":["admin","viewer"]}"#,
270+
);
271+
let ctx = JwtMiddleware::extract_actor(&jwt).unwrap();
272+
assert_eq!(ctx.actor_id, "user@example.com");
273+
assert_eq!(ctx.tenant_id, 42);
274+
assert_eq!(ctx.roles, vec!["admin", "viewer"]);
275+
assert!(ctx.is_admin());
276+
}
277+
278+
#[test]
279+
fn valid_jwt_minimal_claims() {
280+
let jwt = make_test_jwt(r#"{"sub":"bot-123","tenant_id":1}"#);
281+
let ctx = JwtMiddleware::extract_actor(&jwt).unwrap();
282+
assert_eq!(ctx.actor_id, "bot-123");
283+
assert_eq!(ctx.tenant_id, 1);
284+
assert!(ctx.roles.is_empty());
285+
assert!(!ctx.is_admin());
286+
}
287+
288+
#[test]
289+
fn valid_jwt_empty_roles() {
290+
let jwt = make_test_jwt(r#"{"sub":"x","tenant_id":0,"roles":[]}"#);
291+
let ctx = JwtMiddleware::extract_actor(&jwt).unwrap();
292+
assert!(ctx.roles.is_empty());
293+
}
294+
295+
#[test]
296+
fn valid_jwt_missing_tenant_defaults_to_zero() {
297+
let jwt = make_test_jwt(r#"{"sub":"x"}"#);
298+
let ctx = JwtMiddleware::extract_actor(&jwt).unwrap();
299+
assert_eq!(ctx.tenant_id, 0);
300+
}
301+
302+
#[test]
303+
fn missing_sub_error() {
304+
let jwt = make_test_jwt(r#"{"tenant_id":1,"roles":["viewer"]}"#);
305+
assert_eq!(
306+
JwtMiddleware::extract_actor(&jwt),
307+
Err(AuthError::MissingSub)
308+
);
309+
}
310+
311+
#[test]
312+
fn empty_sub_error() {
313+
let jwt = make_test_jwt(r#"{"sub":"","tenant_id":1}"#);
314+
assert_eq!(
315+
JwtMiddleware::extract_actor(&jwt),
316+
Err(AuthError::MissingSub)
317+
);
318+
}
319+
320+
#[test]
321+
fn malformed_token_no_dots() {
322+
assert_eq!(
323+
JwtMiddleware::extract_actor("not-a-jwt"),
324+
Err(AuthError::MalformedToken)
325+
);
326+
}
327+
328+
#[test]
329+
fn malformed_token_two_parts() {
330+
assert_eq!(
331+
JwtMiddleware::extract_actor("header.payload"),
332+
Err(AuthError::MalformedToken)
333+
);
334+
}
335+
336+
#[test]
337+
fn malformed_token_four_parts() {
338+
assert_eq!(
339+
JwtMiddleware::extract_actor("a.b.c.d"),
340+
Err(AuthError::MalformedToken)
341+
);
342+
}
343+
344+
#[test]
345+
fn invalid_base64_payload() {
346+
// Valid structure (3 parts) but middle part is bad base64
347+
assert!(matches!(
348+
JwtMiddleware::extract_actor("header.!!!invalid.sig"),
349+
Err(AuthError::InvalidBase64)
350+
));
351+
}
352+
353+
#[test]
354+
fn invalid_json_payload() {
355+
let header = base64url_encode(b"{}");
356+
let payload = base64url_encode(b"not json at all {{{");
357+
let token = format!("{header}.{payload}.");
358+
assert!(matches!(
359+
JwtMiddleware::extract_actor(&token),
360+
Err(AuthError::InvalidPayload(_))
361+
));
362+
}
363+
364+
#[test]
365+
fn extract_from_bearer_header() {
366+
let jwt = make_test_jwt(r#"{"sub":"user@test.com","tenant_id":7}"#);
367+
let header = format!("Bearer {jwt}");
368+
let ctx = JwtMiddleware::extract_from_header(&header).unwrap();
369+
assert_eq!(ctx.actor_id, "user@test.com");
370+
assert_eq!(ctx.tenant_id, 7);
371+
}
372+
373+
#[test]
374+
fn extract_from_header_lowercase_bearer() {
375+
let jwt = make_test_jwt(r#"{"sub":"x","tenant_id":1}"#);
376+
let header = format!("bearer {jwt}");
377+
let ctx = JwtMiddleware::extract_from_header(&header).unwrap();
378+
assert_eq!(ctx.actor_id, "x");
379+
}
380+
381+
#[test]
382+
fn extract_from_header_no_prefix() {
383+
let jwt = make_test_jwt(r#"{"sub":"x","tenant_id":1}"#);
384+
let ctx = JwtMiddleware::extract_from_header(&jwt).unwrap();
385+
assert_eq!(ctx.actor_id, "x");
386+
}
387+
}

0 commit comments

Comments
 (0)