Skip to content

Commit 7ee65cf

Browse files
committed
Port to the new tower-sessions crate
1 parent cefa4c8 commit 7ee65cf

8 files changed

Lines changed: 219 additions & 270 deletions

File tree

Cargo.lock

Lines changed: 108 additions & 183 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ aws-smithy-http = "*"
6060
axum = {version = "~0.6.20", features = ["multipart", "macros", "headers"]}
6161
axum-extra = {version = "~0.8.0", features = ["cookie"]}
6262
axum-server = {version = "0.5.1", features = ["tls-rustls"]}
63-
axum-sessions = "~0.5"
6463
base64 = "0.21.5"
6564
bcrypt = "0.15.0"
6665
bumpalo = "3.14.0"
@@ -70,7 +69,7 @@ chrono = "0.4"
7069
chrono-tz = "0.8"
7170
csrf = "0.4.1"
7271
dyn-clone = "1.0.14"
73-
futures = "0.3.28"
72+
futures = "0.3.29"
7473
hex = "0.4.3"
7574
html-escape = "0.2.13"
7675
http = "0.2.9"
@@ -129,10 +128,11 @@ serde_path_to_error = "*"
129128
sha1 = "0.10.6"
130129
stripe = {package = "async-stripe", version = "0.25.1", default-features = false, features = ["runtime-tokio-hyper-rustls", "checkout", "chrono", "connect", "webhook-events"]}
131130
syn = "2.0.38"
132-
time = "0.3.29"
131+
time = "0.3.30"
133132
tokio = {version = "~1.33.0", features = ["rt-multi-thread", "macros", "signal", "tracing"]}
134133
tower = {version = "0.4.13", features = ["limit", "make"]}
135134
tower-http = {version = "0.4.4", features = ["catch-panic", "compression-br", "compression-deflate", "compression-gzip", "trace"]}
135+
tower-sessions = "~0.3.3"
136136
tracing = "0.1.37"
137137
tuple-conv = "~1.0.1"
138138
twilio = "1.0.3"
@@ -146,15 +146,13 @@ tokio-console = ["dep:console-subscriber", "tokio/full", "tokio/tracing"]
146146

147147
[dependencies]
148148
async-graphql = {workspace = true}
149-
async-graphql-value = {workspace = true}
150149
axum = {workspace = true}
151150
chrono = {workspace = true}
152151
chrono-tz = {workspace = true}
153152
clap = {version = "4.4.7", features = ["derive", "cargo"]}
154153
console-subscriber = {version = "0.2.0", optional = true}
155154
dhat = {version = "0.3.2", optional = true}
156155
dotenv = "0.15.0"
157-
futures = {workspace = true}
158156
http = {workspace = true}
159157
indicatif = {workspace = true}
160158
intercode_cms = {workspace = true}
@@ -163,14 +161,12 @@ intercode_graphql = {workspace = true}
163161
intercode_graphql_core = {workspace = true}
164162
intercode_graphql_loaders = {workspace = true}
165163
intercode_graphql_presend = {workspace = true}
166-
intercode_liquid = {workspace = true}
167164
intercode_liquid_drops = {workspace = true}
168165
intercode_policies = {workspace = true}
169166
intercode_reporting = {workspace = true}
170167
intercode_server = {workspace = true}
171168
intercode_signups = {workspace = true}
172169
intercode_users = {workspace = true}
173-
itertools = {workspace = true}
174170
liquid = {workspace = true}
175171
once_cell = {workspace = true}
176172
opentelemetry = {workspace = true}
@@ -179,10 +175,7 @@ oxide-auth = {workspace = true}
179175
oxide-auth-axum = "~0.3.0"
180176
regex = {workspace = true}
181177
sea-orm = {workspace = true}
182-
seawater = {workspace = true}
183-
serde = {workspace = true}
184178
serde_json = {workspace = true}
185-
serde_path_to_error = {workspace = true}
186179
time = {workspace = true}
187180
tokio = {workspace = true}
188181
tonic = "~0.9.0"

crates/intercode_server/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ async-trait = {workspace = true}
1111
axum = {workspace = true}
1212
axum-extra = {workspace = true}
1313
axum-server = {workspace = true}
14-
axum-sessions = {workspace = true}
1514
base64 = {workspace = true}
1615
chrono = {workspace = true}
1716
chrono-tz = {workspace = true}
@@ -36,4 +35,5 @@ time = {workspace = true}
3635
tokio = {workspace = true}
3736
tower = {workspace = true}
3837
tower-http = {workspace = true}
38+
tower-sessions = {workspace = true}
3939
tracing = {workspace = true}

crates/intercode_server/src/app.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,15 @@ use std::env;
22
use std::sync::Arc;
33

44
use async_graphql::Result;
5+
use axum::error_handling::HandleErrorLayer;
56
use axum::extract::FromRef;
7+
use axum::BoxError;
68
use axum::{middleware::from_fn_with_state, routing::IntoMakeService, Extension, Router};
9+
use http::StatusCode;
710
use hyper::body::HttpBody;
811
use sea_orm::DatabaseConnection;
912
use tower::limit::ConcurrencyLimitLayer;
13+
use tower::ServiceBuilder;
1014
use tower_http::catch_panic::CatchPanicLayer;
1115
use tower_http::compression::CompressionLayer;
1216
use tower_http::trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer};
@@ -33,15 +37,21 @@ where
3337
});
3438

3539
let csrf_config = CsrfConfig::new(&secret);
36-
let session_layer = SessionWithDbStoreFromTxLayer::new(secret);
40+
let session_layer = SessionWithDbStoreFromTxLayer::new();
3741

3842
let app: Router<S, B> = Router::new();
3943
let app = build_routes(app);
4044

45+
let session_service = ServiceBuilder::new()
46+
.layer(HandleErrorLayer::new(|err: BoxError| async move {
47+
(StatusCode::BAD_REQUEST, err.to_string())
48+
}))
49+
.layer(session_layer);
50+
4151
let app = app
4252
.layer(axum::middleware::from_fn(csrf_middleware))
4353
.layer(Extension(csrf_config))
44-
.layer(session_layer)
54+
.layer(session_service)
4555
.layer(from_fn_with_state(state.clone(), request_bound_transaction))
4656
.layer(ConcurrencyLimitLayer::new(
4757
env::var("MAX_CONCURRENCY")

crates/intercode_server/src/db_sessions.rs

Lines changed: 85 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1-
use axum::{async_trait, response::Response};
2-
use axum_sessions::{
3-
async_session::{MemoryStore, Session, SessionStore},
4-
SessionLayer,
5-
};
6-
use base64::Engine;
1+
use std::{error::Error, fmt::Display};
2+
3+
use axum::{async_trait, BoxError};
74
use chrono::Utc;
85
use futures::future::BoxFuture;
96
use http::Request;
107
use intercode_entities::sessions;
11-
use sea_orm::{sea_query::OnConflict, ColumnTrait, EntityTrait, QueryFilter};
8+
use sea_orm::{sea_query::OnConflict, ColumnTrait, DbErr, EntityTrait, QueryFilter};
129
use seawater::ConnectionWrapper;
1310
use tower::{Layer, Service};
11+
use tower_sessions::{
12+
session::SessionId, MemoryStore, Session, SessionManager, SessionManagerLayer, SessionRecord,
13+
SessionStore,
14+
};
1415
use tracing::log::error;
1516

1617
#[derive(Clone, Debug)]
@@ -24,36 +25,54 @@ impl DbSessionStore {
2425
}
2526
}
2627

27-
#[async_trait]
28-
impl SessionStore for DbSessionStore {
29-
async fn load_session(
30-
&self,
31-
cookie_value: String,
32-
) -> axum_sessions::async_session::Result<Option<Session>> {
33-
let session_id = Session::id_from_cookie_value(&cookie_value)?;
34-
let engine = base64::engine::general_purpose::STANDARD_NO_PAD;
28+
#[derive(Debug)]
29+
pub enum DbSessionError {
30+
DbErr(DbErr),
31+
SerializationError(serde_json::Error),
32+
}
3533

36-
sessions::Entity::find()
37-
.filter(sessions::Column::SessionId.eq(session_id.clone()))
38-
.one(self.db.as_ref())
39-
.await
40-
.map(|find_result| {
41-
find_result
42-
.and_then(|record| record.data)
43-
.and_then(|encoded| engine.decode(encoded).ok())
44-
.and_then(|bytes| String::from_utf8(bytes).ok())
45-
.and_then(|data| serde_json::from_str::<Session>(&data).ok())
46-
})
47-
.map_err(|err| err.into())
34+
impl Display for DbSessionError {
35+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36+
match self {
37+
DbSessionError::DbErr(err) => err.fmt(f),
38+
DbSessionError::SerializationError(err) => err.fmt(f),
39+
}
40+
}
41+
}
42+
43+
impl Error for DbSessionError {
44+
fn source(&self) -> Option<&(dyn Error + 'static)> {
45+
None
46+
}
47+
48+
fn description(&self) -> &str {
49+
"description() is deprecated; use Display"
50+
}
51+
52+
fn cause(&self) -> Option<&dyn Error> {
53+
self.source()
54+
}
55+
}
56+
57+
impl From<DbErr> for DbSessionError {
58+
fn from(value: DbErr) -> Self {
59+
Self::DbErr(value)
60+
}
61+
}
62+
63+
impl From<serde_json::Error> for DbSessionError {
64+
fn from(value: serde_json::Error) -> Self {
65+
Self::SerializationError(value)
4866
}
67+
}
68+
69+
#[async_trait]
70+
impl SessionStore for DbSessionStore {
71+
type Error = DbSessionError;
4972

50-
async fn store_session(
51-
&self,
52-
session: Session,
53-
) -> axum_sessions::async_session::Result<Option<String>> {
54-
let engine = base64::engine::general_purpose::STANDARD_NO_PAD;
55-
let session_id = session.id().to_string();
56-
let encoded_data = engine.encode(serde_json::to_string(&session)?);
73+
async fn save(&self, session_record: &SessionRecord) -> Result<(), Self::Error> {
74+
let session_id = session_record.id().to_string();
75+
let encoded_data = serde_json::to_string(&session_record)?;
5776
let model = sessions::ActiveModel {
5877
id: sea_orm::ActiveValue::NotSet,
5978
created_at: sea_orm::ActiveValue::Set(Some(Utc::now().naive_utc())),
@@ -69,23 +88,28 @@ impl SessionStore for DbSessionStore {
6988
)
7089
.exec(self.db.as_ref())
7190
.await?;
72-
Ok(session.into_cookie_value())
91+
Ok(())
7392
}
7493

75-
async fn destroy_session(
76-
&self,
77-
session: axum_sessions::async_session::Session,
78-
) -> axum_sessions::async_session::Result {
79-
sessions::Entity::delete_many()
80-
.filter(sessions::Column::SessionId.eq(session.id()))
81-
.exec(self.db.as_ref())
82-
.await?;
83-
84-
Ok(())
94+
async fn load(&self, session_id: &SessionId) -> Result<Option<Session>, Self::Error> {
95+
sessions::Entity::find()
96+
.filter(sessions::Column::SessionId.eq(session_id.0.to_string()))
97+
.one(self.db.as_ref())
98+
.await
99+
.map(|find_result| {
100+
find_result
101+
.and_then(|record| record.data)
102+
// .and_then(|encoded| engine.decode(encoded).ok())
103+
// .and_then(|bytes| String::from_utf8(bytes).ok())
104+
.and_then(|data| serde_json::from_str::<SessionRecord>(&data).ok())
105+
.and_then(|rec| Some(Session::from(rec)))
106+
})
107+
.map_err(DbSessionError::from)
85108
}
86109

87-
async fn clear_store(&self) -> axum_sessions::async_session::Result {
110+
async fn delete(&self, session_id: &SessionId) -> Result<(), Self::Error> {
88111
sessions::Entity::delete_many()
112+
.filter(sessions::Column::SessionId.eq(session_id.0.to_string()))
89113
.exec(self.db.as_ref())
90114
.await?;
91115

@@ -94,70 +118,67 @@ impl SessionStore for DbSessionStore {
94118
}
95119

96120
#[derive(Clone)]
97-
pub struct SessionWithDbStoreFromTxLayer {
98-
secret: [u8; 64],
99-
}
121+
pub struct SessionWithDbStoreFromTxLayer;
100122

101123
impl SessionWithDbStoreFromTxLayer {
102-
pub fn new(secret: [u8; 64]) -> Self {
103-
Self { secret }
124+
pub fn new() -> Self {
125+
Self {}
104126
}
105127
}
106128

107129
impl<S> Layer<S> for SessionWithDbStoreFromTxLayer {
108130
type Service = SessionWithDbStoreFromTxService<S>;
109131

110132
fn layer(&self, inner: S) -> Self::Service {
111-
SessionWithDbStoreFromTxService {
112-
secret: self.secret,
113-
inner,
114-
}
133+
SessionWithDbStoreFromTxService { inner }
115134
}
116135
}
117136

118137
#[derive(Clone)]
119138
pub struct SessionWithDbStoreFromTxService<S> {
120-
secret: [u8; 64],
121139
inner: S,
122140
}
123141

124142
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for SessionWithDbStoreFromTxService<S>
125143
where
126-
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
144+
S: Service<Request<ReqBody>, Response = http::Response<ResBody>> + Clone + Send + 'static,
127145
ResBody: Send + 'static,
128146
ReqBody: Send + 'static,
129147
S::Future: Send + 'static,
148+
S::Error: Error + Send + Sync,
130149
{
131-
type Response = Response<ResBody>;
132-
type Error = S::Error;
150+
type Response = <SessionManager<S, DbSessionStore> as Service<Request<ReqBody>>>::Response;
151+
type Error = BoxError;
133152
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
134153

135154
fn poll_ready(
136155
&mut self,
137156
cx: &mut std::task::Context<'_>,
138157
) -> std::task::Poll<Result<(), Self::Error>> {
139-
self.inner.poll_ready(cx)
158+
self
159+
.inner
160+
.poll_ready(cx)
161+
.map_err(|err| Box::new(err) as BoxError)
140162
}
141163

142164
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
143165
let inner = self.inner.clone();
144-
let secret = self.secret;
145166
Box::pin(async move {
146167
let (parts, body) = req.into_parts();
147168
let db = parts.extensions.get::<ConnectionWrapper>();
148169

149170
match db {
150171
Some(wrapper) => {
151172
let store = DbSessionStore::new(wrapper.clone());
152-
let layer = SessionLayer::new(store, &secret);
173+
let layer = SessionManagerLayer::new(store);
153174
let mut service = layer.layer(inner);
154175
let req = Request::from_parts(parts, body);
155176
service.call(req).await
156177
}
157178
None => {
158179
error!("Couldn't get ConnectionWrapper from request extensions");
159-
let store = MemoryStore::new();
160-
let layer = SessionLayer::new(store, &secret);
180+
let store = MemoryStore::default();
181+
let layer = SessionManagerLayer::new(store);
161182
let mut service = layer.layer(inner);
162183
let req = Request::from_parts(parts, body);
163184
service.call(req).await

crates/intercode_server/src/middleware.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ use axum::{
22
async_trait,
33
extract::{FromRequestParts, Host},
44
};
5-
use axum_sessions::SessionHandle;
65
use http::{request::Parts, StatusCode};
76
use intercode_entities::{
87
cms_parent::CmsParent, conventions, root_sites, user_con_profiles, users,
@@ -13,6 +12,7 @@ use once_cell::sync::Lazy;
1312
use regex::Regex;
1413
use sea_orm::{ColumnTrait, DbErr, EntityTrait, QueryFilter};
1514
use seawater::ConnectionWrapper;
15+
use tower_sessions::Session;
1616
use tracing::{error, warn};
1717

1818
static PORT_REGEX: Lazy<Regex> = Lazy::new(|| Regex::new(":\\d+$").unwrap());
@@ -78,8 +78,7 @@ impl<S: Sync> FromRequestParts<S> for QueryDataFromRequest {
7878
let (cms_parent, convention) = cms_parent_from_request_parts(parts, &db)
7979
.await
8080
.map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
81-
let session_handle = parts.extensions.get::<SessionHandle>().unwrap();
82-
let session = session_handle.read().await;
81+
let session = parts.extensions.get::<Session>().unwrap();
8382

8483
let Some(cms_parent) = cms_parent else {
8584
return Err((
@@ -93,7 +92,9 @@ impl<S: Sync> FromRequestParts<S> for QueryDataFromRequest {
9392
.get("X-Intercode-User-Timezone")
9493
.and_then(|header| header.to_str().ok());
9594

96-
let current_user_id: Option<i64> = session.get("current_user_id");
95+
let current_user_id: Option<i64> = session
96+
.get("current_user_id")
97+
.map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
9798
let current_user = if let Some(current_user_id) = current_user_id {
9899
users::Entity::find_by_id(current_user_id)
99100
.one(db.as_ref())

0 commit comments

Comments
 (0)