Skip to content

Commit dbc49b1

Browse files
authored
Add AuthCtx to ReducerContext for rust (#3288)
# Description of Changes This exposes client credentials in reducer calls for rust. # API and ABI breaking changes API Changes: The main API change is the addition of `AuthCtx` and the `sender_auth` in `ReducerContext`. This also adds JwtClaims, which has some helpers for getting commonly used claims. ABI Changes: This adds one new functions `get_jwt`. This uses `st_connection_credentials` to look up the credentials associated with a connection id. This adds ABI version 10.2. # Expected complexity level and risk 2. This adds new ABI functions # Testing I've done some manual testing with modified versions of the quickstart. We should add some examples that use the new API.
1 parent 542d26d commit dbc49b1

39 files changed

Lines changed: 1793 additions & 1654 deletions

Cargo.lock

Lines changed: 1208 additions & 1317 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/auth/src/identity.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ pub struct ConnectionAuthCtx {
1515
impl TryFrom<SpacetimeIdentityClaims> for ConnectionAuthCtx {
1616
type Error = anyhow::Error;
1717
fn try_from(claims: SpacetimeIdentityClaims) -> Result<Self, Self::Error> {
18-
let payload =
19-
serde_json::to_string(&claims).map_err(|e| anyhow::anyhow!("Failed to serialize claims: {}", e))?;
18+
let payload = serde_json::to_string(&claims).map_err(|e| anyhow::anyhow!("Failed to serialize claims: {e}"))?;
2019
Ok(ConnectionAuthCtx {
2120
claims,
2221
jwt_payload: payload,
@@ -111,9 +110,7 @@ impl TryInto<SpacetimeIdentityClaims> for IncomingClaims {
111110
if let Some(token_identity) = self.identity {
112111
if token_identity != computed_identity {
113112
return Err(anyhow::anyhow!(
114-
"Identity mismatch: token identity {:?} does not match computed identity {:?}",
115-
token_identity,
116-
computed_identity,
113+
"Identity mismatch: token identity {token_identity:?} does not match computed identity {computed_identity:?}",
117114
));
118115
}
119116
}

crates/bindings-sys/src/lib.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,7 @@ pub mod raw {
588588
///
589589
/// - `out_ptr` is NULL or `out` is not in bounds of WASM memory.
590590
pub fn identity(out_ptr: *mut u8);
591+
591592
}
592593

593594
// See comment on previous `extern "C"` block re: ABI version.
@@ -619,6 +620,31 @@ pub mod raw {
619620
pub fn bytes_source_remaining_length(source: BytesSource, out: *mut u32) -> i16;
620621
}
621622

623+
// See comment on previous `extern "C"` block re: ABI version.
624+
#[link(wasm_import_module = "spacetime_10.2")]
625+
extern "C" {
626+
/// Finds the JWT payload associated with `connection_id`.
627+
/// A `[ByteSourceId]` for the payload will be written to `target_ptr`.
628+
/// If nothing is found for the connection, `[ByteSourceId::INVALID]` (zero) is written to `target_ptr`.
629+
///
630+
/// This must be called inside a transaction (because it reads from a system table).
631+
///
632+
/// # Errors
633+
///
634+
/// Returns an error:
635+
///
636+
/// - `NOT_IN_TRANSACTION`, when called outside of a transaction.
637+
///
638+
/// # Traps
639+
///
640+
/// Traps if:
641+
///
642+
/// - `connection_id` does not point to a valid little-endian `ConnectionId`.
643+
/// - `target_ptr` is NULL or `target_ptr[..size_of::<u32>()]` is not in bounds of WASM memory.
644+
/// - The `ByteSourceId` to be written to `target_ptr` would overflow [`u32::MAX`].
645+
pub fn get_jwt(connection_id_ptr: *const u8, bytes_source_id: *mut BytesSource) -> u16;
646+
}
647+
622648
/// What strategy does the database index use?
623649
///
624650
/// See also: <https://www.postgresql.org/docs/current/sql-createindex.html>
@@ -1118,6 +1144,29 @@ pub fn identity() -> [u8; 32] {
11181144
buf
11191145
}
11201146

1147+
/// Finds the JWT payload associated with `connection_id`.
1148+
/// If nothing is found for the connection, this returns None.
1149+
/// If a payload is found, this will return a valid [`raw::BytesSource`].
1150+
///
1151+
/// This must be called inside a transaction (because it reads from a system table).
1152+
///
1153+
/// # Errors
1154+
///
1155+
/// This panics on any error. You can see details about errors in [`raw::get_jwt`].
1156+
#[inline]
1157+
pub fn get_jwt(connection_id: [u8; 16]) -> Option<raw::BytesSource> {
1158+
let source = unsafe {
1159+
call(|out| raw::get_jwt(connection_id.as_ptr(), out))
1160+
.unwrap_or_else(|errno: Errno| panic!("Error getting jwt: {errno}"))
1161+
};
1162+
1163+
if source == raw::BytesSource::INVALID {
1164+
None // No JWT found.
1165+
} else {
1166+
Some(source)
1167+
}
1168+
}
1169+
11211170
pub struct RowIter {
11221171
raw: raw::RowIter,
11231172
}

crates/bindings/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ rand08 = { workspace = true, optional = true }
3535
# if someone tries to use rand's ThreadRng, it will fail to link
3636
# because no one defined __getrandom_custom
3737
getrandom02 = { workspace = true, optional = true, features = ["custom"] }
38+
serde_json.workspace = true
3839

3940
[dev-dependencies]
4041
insta.workspace = true

crates/bindings/src/lib.rs

Lines changed: 172 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@ pub mod rt;
1212
#[doc(hidden)]
1313
pub mod table;
1414

15-
use spacetimedb_lib::bsatn;
16-
use std::cell::RefCell;
17-
1815
pub use log;
1916
#[cfg(feature = "rand")]
2017
pub use rand08 as rand;
18+
use spacetimedb_lib::bsatn;
19+
use std::cell::LazyCell;
20+
use std::cell::{OnceCell, RefCell};
21+
use std::ops::Deref;
2122

2223
#[cfg(feature = "unstable")]
2324
pub use client_visibility_filter::Filter;
@@ -751,6 +752,8 @@ pub struct ReducerContext {
751752
/// See the [`#[table]`](macro@crate::table) macro for more information.
752753
pub db: Local,
753754

755+
sender_auth: AuthCtx,
756+
754757
#[cfg(feature = "rand08")]
755758
rng: std::cell::OnceCell<StdbRng>,
756759
}
@@ -763,10 +766,33 @@ impl ReducerContext {
763766
sender: Identity::__dummy(),
764767
timestamp: Timestamp::UNIX_EPOCH,
765768
connection_id: None,
769+
sender_auth: AuthCtx::internal(),
766770
rng: std::cell::OnceCell::new(),
767771
}
768772
}
769773

774+
#[doc(hidden)]
775+
fn new(db: Local, sender: Identity, connection_id: Option<ConnectionId>, timestamp: Timestamp) -> Self {
776+
let sender_auth = match connection_id {
777+
Some(cid) => AuthCtx::from_connection_id(cid),
778+
None => AuthCtx::internal(),
779+
};
780+
Self {
781+
db,
782+
sender,
783+
timestamp,
784+
connection_id,
785+
sender_auth,
786+
#[cfg(feature = "rand08")]
787+
rng: std::cell::OnceCell::new(),
788+
}
789+
}
790+
791+
/// Returns the authorization information for the caller of this reducer.
792+
pub fn sender_auth(&self) -> &AuthCtx {
793+
&self.sender_auth
794+
}
795+
770796
/// Read the current module's [`Identity`].
771797
pub fn identity(&self) -> Identity {
772798
// Hypothetically, we *could* read the module identity out of the system tables.
@@ -829,6 +855,115 @@ impl DbContext for ReducerContext {
829855
#[non_exhaustive]
830856
pub struct Local {}
831857

858+
#[non_exhaustive]
859+
pub struct JwtClaims {
860+
payload: String,
861+
parsed: OnceCell<serde_json::Value>,
862+
audience: OnceCell<Vec<String>>,
863+
}
864+
865+
/// Authentication information for the caller of a reducer.
866+
pub struct AuthCtx {
867+
is_internal: bool,
868+
// NOTE(jsdt): cannot directly use a LazyLock without making this struct generic.
869+
jwt: Box<dyn Deref<Target = Option<JwtClaims>>>,
870+
}
871+
872+
impl AuthCtx {
873+
fn new(is_internal: bool, jwt_fn: impl FnOnce() -> Option<JwtClaims> + 'static) -> Self {
874+
AuthCtx {
875+
is_internal,
876+
jwt: Box::new(LazyCell::new(jwt_fn)),
877+
}
878+
}
879+
880+
/// Create an [`AuthCtx`] for an internal call, with no JWT.
881+
/// This represents a scheduled reducer.
882+
pub fn internal() -> AuthCtx {
883+
Self::new(true, || None)
884+
}
885+
886+
/// Creates an [`AuthCtx`] using the json claims from a JWT.
887+
/// This can be used to write unit tests.
888+
pub fn from_jwt_payload(jwt_payload: String) -> AuthCtx {
889+
Self::new(false, move || Some(JwtClaims::new(jwt_payload)))
890+
}
891+
892+
/// Creates an [`AuthCtx`] that reads the JWT for the given connection id.
893+
fn from_connection_id(connection_id: ConnectionId) -> AuthCtx {
894+
Self::new(false, move || rt::get_jwt(connection_id).map(JwtClaims::new))
895+
}
896+
897+
/// Returns whether this reducer was spawned from inside the database.
898+
pub fn is_internal(&self) -> bool {
899+
self.is_internal
900+
}
901+
902+
/// Check if there is a JWT without loading it.
903+
/// If [`AuthCtx::is_internal`] is true, this will return false.
904+
pub fn has_jwt(&self) -> bool {
905+
self.jwt.is_some()
906+
}
907+
908+
/// Load the jwt.
909+
pub fn jwt(&self) -> Option<&JwtClaims> {
910+
self.jwt.as_ref().deref().as_ref()
911+
}
912+
}
913+
914+
impl JwtClaims {
915+
fn new(jwt: String) -> Self {
916+
Self {
917+
payload: jwt,
918+
parsed: OnceCell::new(),
919+
audience: OnceCell::new(),
920+
}
921+
}
922+
923+
fn get_parsed(&self) -> &serde_json::Value {
924+
self.parsed
925+
.get_or_init(|| serde_json::from_str(&self.payload).expect("Failed to parse JWT payload"))
926+
}
927+
928+
/// Returns the tokens subject, from the sub claim.
929+
pub fn subject(&self) -> &str {
930+
self.get_parsed()
931+
.get("sub")
932+
.expect("Missing 'sub' claim")
933+
.as_str()
934+
.expect("Token 'sub' claim is not a string")
935+
}
936+
937+
/// Returns the issuer for these credentials, from the iss claim.
938+
pub fn issuer(&self) -> &str {
939+
self.get_parsed().get("iss").unwrap().as_str().unwrap()
940+
}
941+
942+
fn extract_audience(&self) -> Vec<String> {
943+
let aud = self.get_parsed().get("aud").unwrap();
944+
match aud {
945+
serde_json::Value::String(s) => vec![s.clone()],
946+
serde_json::Value::Array(arr) => arr.iter().filter_map(|v| v.as_str().map(String::from)).collect(),
947+
_ => panic!("Unexpected type for 'aud' claim in JWT"),
948+
}
949+
}
950+
951+
/// Returns the audience for these credentials, from the aud claim.
952+
pub fn audience(&self) -> &[String] {
953+
self.audience.get_or_init(|| self.extract_audience())
954+
}
955+
956+
/// Returns the identity for these credentials, which is
957+
/// based on the iss and sub claims.
958+
pub fn identity(&self) -> Identity {
959+
Identity::from_claims(self.issuer(), self.subject())
960+
}
961+
962+
/// Get the whole JWT payload as a json string.
963+
pub fn raw_payload(&self) -> &str {
964+
&self.payload
965+
}
966+
}
832967
/// The read-only version of [`Local`]
833968
#[non_exhaustive]
834969
pub struct LocalReadOnly {}
@@ -937,3 +1072,37 @@ macro_rules! __volatile_nonatomic_schedule_immediate_impl {
9371072
}
9381073
};
9391074
}
1075+
1076+
#[cfg(test)]
1077+
mod tests {
1078+
use super::*;
1079+
1080+
#[test]
1081+
fn parse_single_audience() {
1082+
let example_payload = r#"
1083+
{
1084+
"iss": "https://securetoken.google.com/my-project-id",
1085+
"aud": "my-project-id",
1086+
"auth_time": 1695560000,
1087+
"user_id": "abc123XYZ",
1088+
"sub": "abc123XYZ",
1089+
"iat": 1695560100,
1090+
"exp": 1695563700,
1091+
"email": "user@example.com",
1092+
"email_verified": true,
1093+
"firebase": {
1094+
"identities": {
1095+
"email": ["user@example.com"]
1096+
},
1097+
"sign_in_provider": "password"
1098+
},
1099+
"name": "Jane Doe",
1100+
"picture": "https://lh3.googleusercontent.com/a-/profile.jpg"
1101+
}
1102+
"#;
1103+
let auth = AuthCtx::from_jwt_payload(example_payload.to_string());
1104+
let audience = auth.jwt().unwrap().audience();
1105+
assert_eq!(audience.len(), 1);
1106+
assert_eq!(audience, &["my-project-id".to_string()]);
1107+
}
1108+
}

crates/bindings/src/rt.rs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -488,13 +488,7 @@ extern "C" fn __call_reducer__(
488488

489489
// Assemble the `ReducerContext`.
490490
let timestamp = Timestamp::from_micros_since_unix_epoch(timestamp as i64);
491-
let ctx = ReducerContext {
492-
db: crate::Local {},
493-
sender,
494-
timestamp,
495-
connection_id: conn_id,
496-
rng: std::cell::OnceCell::new(),
497-
};
491+
let ctx = ReducerContext::new(crate::Local {}, sender, conn_id, timestamp);
498492

499493
// Fetch reducer function.
500494
let reducers = REDUCERS.get().unwrap();
@@ -531,6 +525,17 @@ fn with_read_args<R>(args: BytesSource, logic: impl FnOnce(&[u8]) -> R) -> R {
531525
const NO_SPACE: u16 = errno::NO_SPACE.get();
532526
const NO_SUCH_BYTES: u16 = errno::NO_SUCH_BYTES.get();
533527

528+
/// Look up the jwt associated with `connection_id`.
529+
pub fn get_jwt(connection_id: ConnectionId) -> Option<String> {
530+
let mut buf = IterBuf::take();
531+
let source = sys::get_jwt(connection_id.as_le_byte_array())?;
532+
if source == BytesSource::INVALID {
533+
return None;
534+
}
535+
read_bytes_source_into(source, &mut buf);
536+
Some(std::str::from_utf8(&buf).unwrap().to_string())
537+
}
538+
534539
/// Read `source` from the host fully into `buf`.
535540
fn read_bytes_source_into(source: BytesSource, buf: &mut Vec<u8>) {
536541
const INVALID: i16 = NO_SUCH_BYTES as i16;
@@ -565,8 +570,8 @@ fn read_bytes_source_into(source: BytesSource, buf: &mut Vec<u8>) {
565570
let buf_ptr = buf_ptr.as_mut_ptr().cast();
566571
let ret = unsafe { sys::raw::bytes_source_read(source, buf_ptr, &mut buf_len) };
567572
if ret <= 0 {
568-
// SAFETY: `bytes_source_read` just appended `spare_len` bytes to `buf`.
569-
unsafe { buf.set_len(buf.len() + spare_len) };
573+
// SAFETY: `bytes_source_read` just appended `buf_len` bytes to `buf`.
574+
unsafe { buf.set_len(buf.len() + buf_len) };
570575
}
571576
match ret {
572577
// Host side source exhausted, we're done.

crates/bindings/tests/snapshots/deps__spacetimedb_bindings_dependencies.snap

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
source: crates/bindings/tests/deps.rs
33
expression: "cargo tree -p spacetimedb -e no-dev --color never --target wasm32-unknown-unknown -f {lib}"
44
---
5-
total crates: 60
5+
total crates: 66
66
spacetimedb
77
├── bytemuck
88
├── derive_more
@@ -29,6 +29,11 @@ spacetimedb
2929
│ │ └── getrandom (*)
3030
│ └── rand_core (*)
3131
├── scoped_tls
32+
├── serde_json
33+
│ ├── itoa
34+
│ ├── memchr
35+
│ ├── ryu
36+
│ └── serde_core
3237
├── spacetimedb_bindings_macro
3338
│ ├── heck
3439
│ ├── humantime
@@ -64,6 +69,7 @@ spacetimedb
6469
│ │ └── constant_time_eq
6570
│ │ [build-dependencies]
6671
│ │ └── cc
72+
│ │ ├── find_msvc_tools
6773
│ │ └── shlex
6874
│ ├── chrono
6975
│ │ └── num_traits
@@ -90,6 +96,7 @@ spacetimedb
9096
│ │ ├── enum_as_inner (*)
9197
│ │ ├── ethnum
9298
│ │ │ └── serde
99+
│ │ │ └── serde_core
93100
│ │ ├── hex
94101
│ │ ├── itertools (*)
95102
│ │ ├── second_stack

0 commit comments

Comments
 (0)