Skip to content

Commit 2291b83

Browse files
audience added to token (#29)
* audience added to token
1 parent 28ce068 commit 2291b83

10 files changed

Lines changed: 106 additions & 15 deletions

File tree

crates/api-snowflake-rest-sessions/src/error.rs

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use snafu::prelude::*;
1010
pub type Result<T> = std::result::Result<T, Error>;
1111

1212
#[derive(Snafu)]
13-
#[snafu(visibility(pub(crate)))]
13+
#[snafu(visibility(pub))]
1414
#[error_stack_trace::debug]
1515
pub enum Error {
1616
#[snafu(display("Can't add header to response: {error}"))]
@@ -42,6 +42,20 @@ pub enum Error {
4242
#[snafu(implicit)]
4343
location: Location,
4444
},
45+
46+
#[snafu(display("Can't authenticate request: Host is missing"))]
47+
MissingHost {
48+
#[snafu(implicit)]
49+
location: Location,
50+
},
51+
52+
#[snafu(display("Extension error: {error}"))]
53+
ExtensionRejection {
54+
#[snafu(source)]
55+
error: axum::extract::rejection::ExtensionRejection,
56+
#[snafu(implicit)]
57+
location: Location,
58+
},
4559
}
4660

4761
#[derive(Debug, Serialize, Deserialize)]
@@ -53,7 +67,12 @@ pub struct ErrorResponse {
5367
impl IntoResponse for Error {
5468
fn into_response(self) -> axum::response::Response<axum::body::Body> {
5569
let message = self.to_string();
56-
let code = StatusCode::INTERNAL_SERVER_ERROR;
70+
let code = match self {
71+
Self::BadAuthToken { .. }
72+
| Self::MissingHost { .. }
73+
| Self::ExtensionRejection { .. } => StatusCode::UNAUTHORIZED,
74+
_ => StatusCode::INTERNAL_SERVER_ERROR,
75+
};
5776

5877
let error = ErrorResponse {
5978
message,

crates/api-snowflake-rest-sessions/src/helpers.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,21 @@ use uuid::Uuid;
88
#[cfg_attr(test, derive(Debug))]
99
pub struct Claims {
1010
pub sub: String, // token issued to a particular user
11+
pub aud: String, // validate audience since as it can be deployed on multiple hosts
1112
pub iat: i64, // Issued At
1213
pub exp: i64, // Expiration Time
1314
pub session_id: String,
1415
}
1516

1617
#[must_use]
17-
pub fn jwt_claims(username: &str, expiration: Duration) -> Claims {
18+
pub fn jwt_claims(username: &str, audience: &str, expiration: Duration) -> Claims {
1819
let now = Local::now();
1920
let iat = now.timestamp();
2021
let exp = now.timestamp() + expiration.whole_seconds();
2122

2223
Claims {
2324
sub: username.to_string(),
25+
aud: audience.to_string(),
2426
iat,
2527
exp,
2628
session_id: Uuid::new_v4().to_string(),
@@ -29,11 +31,13 @@ pub fn jwt_claims(username: &str, expiration: Duration) -> Claims {
2931

3032
pub fn get_claims_validate_jwt_token(
3133
token: &str,
34+
audience: &str,
3235
jwt_secret: &str,
3336
) -> Result<Claims, jsonwebtoken::errors::Error> {
3437
let mut validation = Validation::default();
3538
validation.leeway = 5;
36-
validation.set_required_spec_claims(&["exp"]);
39+
validation.set_audience(&[audience]);
40+
validation.set_required_spec_claims(&["exp", "aud"]);
3741

3842
let decoding_key = DecodingKey::from_secret(jwt_secret.as_bytes());
3943

crates/api-snowflake-rest-sessions/src/layer.rs

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,44 @@
11
use crate::error as session_error;
2-
use crate::error::Result;
2+
use crate::error::{Error, Result};
33
use crate::session::{
44
DFSessionId, SESSION_ID_COOKIE_NAME, SessionStore, extract_token_from_cookie,
55
};
6-
use axum::extract::{Request, State};
6+
use axum::extract::{FromRequestParts, Request, State};
7+
use axum::http::{HeaderMap, HeaderName, request::Parts};
78
use axum::middleware::Next;
89
use axum::response::IntoResponse;
910
use http::header::SET_COOKIE;
10-
use http::{HeaderMap, HeaderName};
1111
use snafu::ResultExt;
1212
use tower_sessions::cookie::{Cookie, SameSite};
1313

14+
#[derive(Debug, Clone)]
15+
pub struct Host(pub String);
16+
17+
impl<S> FromRequestParts<S> for Host
18+
where
19+
S: Send + Sync,
20+
{
21+
type Rejection = Error;
22+
23+
#[allow(clippy::unwrap_used)]
24+
async fn from_request_parts(
25+
req: &mut Parts,
26+
state: &S,
27+
) -> std::result::Result<Self, Self::Rejection> {
28+
let headers = HeaderMap::from_request_parts(req, state)
29+
.await
30+
.map_err(|err| match err {})
31+
.unwrap(); // unwrap on Infallibe error is safe
32+
let host = headers.get("host");
33+
let host = host.and_then(|host| host.to_str().ok());
34+
if let Some(host) = host {
35+
Ok(Self(host.to_string()))
36+
} else {
37+
session_error::MissingHostSnafu.fail()
38+
}
39+
}
40+
}
41+
1442
#[allow(clippy::unwrap_used, clippy::cognitive_complexity)]
1543
pub async fn propagate_session_cookie(
1644
State(state): State<SessionStore>,

crates/api-snowflake-rest-sessions/src/session.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use http::header::COOKIE;
88
use http::request::Parts;
99
use http::{HeaderMap, HeaderName};
1010
use regex::Regex;
11-
use snafu::ResultExt;
11+
use snafu::{OptionExt, ResultExt};
1212
use std::{collections::HashMap, sync::Arc};
1313

1414
pub const SESSION_ID_COOKIE_NAME: &str = "session_id";
@@ -52,10 +52,24 @@ where
5252
async fn from_request_parts(req: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
5353
let execution_svc = state.get_execution_svc();
5454

55+
// Not using Host extractor as for some reason it extracts host without port
56+
// use axum::RequestPartsExt;
57+
// use axum::extract::Extension;
58+
// use crate::layer::Host;
59+
// let Extension(Host(host)) = req.extract::<Extension<Host>>()
60+
// .await
61+
// .context(session_error::ExtensionRejectionSnafu)?;
62+
// tracing::info!("Host '{host}' extracted from DFSessionId");
63+
5564
let (session_id, located_at) = if let Some(token) = extract_token_from_auth(&req.headers) {
65+
// host is require to check token audience claim
66+
let host = req.headers.get("host");
67+
let host = host.and_then(|host| host.to_str().ok());
68+
let host = host.context(session_error::MissingHostSnafu)?;
69+
5670
let jwt_secret = state.jwt_secret();
57-
let jwt_claims =
58-
get_claims_validate_jwt_token(&token, jwt_secret).context(BadAuthTokenSnafu)?;
71+
let jwt_claims = get_claims_validate_jwt_token(&token, host, jwt_secret)
72+
.context(BadAuthTokenSnafu)?;
5973

6074
(jwt_claims.session_id, "auth header")
6175
} else {

crates/api-snowflake-rest/src/server/error.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,14 @@ pub enum Error {
9393
#[snafu(implicit)]
9494
location: Location,
9595
},
96+
97+
#[snafu(display("Auth session error: {source}"))]
98+
AuthSession {
99+
#[snafu(source(from(api_snowflake_rest_sessions::error::Error, Box::new)))]
100+
source: Box<api_snowflake_rest_sessions::error::Error>,
101+
#[snafu(implicit)]
102+
location: Location,
103+
},
96104
}
97105

98106
impl IntoResponse for Error {
@@ -216,7 +224,8 @@ impl Error {
216224
| Self::InvalidAuthToken { .. }
217225
| Self::NoJwtSecret { .. }
218226
| Self::CreateJwt { .. }
219-
| Self::BadAuthToken { .. } => (
227+
| Self::BadAuthToken { .. }
228+
| Self::AuthSession { .. } => (
220229
http::StatusCode::UNAUTHORIZED,
221230
SqlState::Success,
222231
ErrorCode::Other,

crates/api-snowflake-rest/src/server/handlers.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::models::{
55
use crate::server::error::Result;
66
use crate::server::logic::{handle_login_request, handle_query_request};
77
use api_snowflake_rest_sessions::DFSessionId;
8+
use api_snowflake_rest_sessions::layer::Host;
89
use axum::Json;
910
use axum::extract::{ConnectInfo, Query, State};
1011
use executor::RunningQueryId;
@@ -23,10 +24,11 @@ pub struct SessionQueryParams {
2324

2425
#[tracing::instrument(name = "api_snowflake_rest::login", level = "debug", skip(state), err, ret(level = tracing::Level::TRACE))]
2526
pub async fn login(
27+
Host(host): Host,
2628
State(state): State<AppState>,
2729
Json(login_request): Json<LoginRequestBody>,
2830
) -> Result<Json<LoginResponse>> {
29-
let response = handle_login_request(&state, login_request.data).await?;
31+
let response = handle_login_request(&state, host, login_request.data).await?;
3032
Ok(Json(response))
3133
}
3234

crates/api-snowflake-rest/src/server/layer.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::server::error::{BadAuthTokenSnafu, NoJwtSecretSnafu};
33
use api_snowflake_rest_sessions::helpers::{
44
ensure_jwt_secret_is_valid, get_claims_validate_jwt_token,
55
};
6+
use api_snowflake_rest_sessions::layer::Host;
67
use api_snowflake_rest_sessions::session::extract_token_from_auth;
78
use axum::extract::{Request, State};
89
use axum::middleware::Next;
@@ -19,6 +20,7 @@ use snafu::{OptionExt, ResultExt};
1920
)]
2021
pub async fn require_auth(
2122
State(state): State<AppState>,
23+
Host(host): Host,
2224
req: Request,
2325
next: Next,
2426
) -> error::Result<impl IntoResponse> {
@@ -35,7 +37,7 @@ pub async fn require_auth(
3537
ensure_jwt_secret_is_valid(&state.config.auth.jwt_secret).context(NoJwtSecretSnafu)?;
3638

3739
let jwt_claims =
38-
get_claims_validate_jwt_token(&token, &jwt_secret).context(BadAuthTokenSnafu)?;
40+
get_claims_validate_jwt_token(&token, &host, &jwt_secret).context(BadAuthTokenSnafu)?;
3941

4042
// Record the result as part of the current span.
4143
tracing::Span::current().record("session_id", jwt_claims.session_id.as_str());

crates/api-snowflake-rest/src/server/logic.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ pub const JWT_TOKEN_EXPIRATION_SECONDS: u32 = 24 * 60 * 60;
2424
)]
2525
pub async fn handle_login_request(
2626
state: &AppState,
27+
host: String,
2728
credentials: LoginRequestData,
2829
) -> Result<LoginResponse> {
2930
let LoginRequestData {
@@ -36,14 +37,18 @@ pub async fn handle_login_request(
3637
return api_snowflake_rest_error::InvalidAuthDataSnafu.fail();
3738
}
3839

40+
// host is required to check token audience claim
3941
let jwt_secret = &*state.config.auth.jwt_secret;
4042
let _ = ensure_jwt_secret_is_valid(jwt_secret).context(NoJwtSecretSnafu)?;
4143

4244
let jwt_claims = jwt_claims(
4345
&login_name,
46+
&host,
4447
Duration::seconds(JWT_TOKEN_EXPIRATION_SECONDS.into()),
4548
);
4649

50+
tracing::info!("Host '{host}' for token creation");
51+
4752
let session_id = jwt_claims.session_id.clone();
4853
let _ = state.execution_svc.create_session(&session_id).await?;
4954

crates/embucket-lambda/src/main.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ use api_snowflake_rest::server::layer::require_auth;
77
use api_snowflake_rest::server::router::{create_auth_router, create_router};
88
use api_snowflake_rest::server::server_models::Config as SnowflakeServerConfig;
99
use api_snowflake_rest::server::state::AppState;
10+
use api_snowflake_rest_sessions::layer::Host;
1011
use api_snowflake_rest_sessions::session::{SESSION_EXPIRATION_SECONDS, SessionStore};
12+
use axum::Extension;
1113
use axum::body::Body as AxumBody;
1214
use axum::extract::connect_info::ConnectInfo;
1315
use axum::{Router, middleware};
@@ -119,10 +121,12 @@ impl LambdaApp {
119121
let snowflake_router = create_router()
120122
.with_state(state.clone())
121123
.layer(compression_layer.clone())
124+
.layer(Extension(Host(String::default())))
122125
.layer(middleware::from_fn_with_state(state.clone(), require_auth));
123126
let snowflake_auth_router = create_auth_router()
124127
.with_state(state.clone())
125-
.layer(compression_layer);
128+
.layer(compression_layer)
129+
.layer(Extension(Host(String::default())));
126130
let router = Router::new().merge(snowflake_router.merge(snowflake_auth_router));
127131

128132
Ok(Self { router, state })

crates/embucketd/src/main.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ use api_snowflake_rest::server::router::create_auth_router as create_snowflake_a
1212
use api_snowflake_rest::server::router::create_router as create_snowflake_router;
1313
use api_snowflake_rest::server::server_models::Config;
1414
use api_snowflake_rest::server::state::AppState as SnowflakeAppState;
15+
use api_snowflake_rest_sessions::layer::Host;
1516
use api_snowflake_rest_sessions::session::{SESSION_EXPIRATION_SECONDS, SessionStore};
17+
use axum::Extension;
1618
use axum::middleware;
1719
use axum::{
1820
Json, Router,
@@ -189,13 +191,15 @@ async fn async_main(
189191
let snowflake_router = create_snowflake_router()
190192
.with_state(snowflake_state.clone())
191193
.layer(compression_layer.clone())
194+
.layer(Extension(Host(String::default())))
192195
.layer(middleware::from_fn_with_state(
193196
snowflake_state.clone(),
194197
snowflake_require_auth,
195198
));
196199
let snowflake_auth_router = create_snowflake_auth_router()
197200
.with_state(snowflake_state.clone())
198-
.layer(compression_layer);
201+
.layer(compression_layer)
202+
.layer(Extension(Host(String::default())));
199203
let snowflake_router = snowflake_router.merge(snowflake_auth_router);
200204

201205
// --- OpenAPI specs ---

0 commit comments

Comments
 (0)