1- use axum:: { async_trait, response:: Response } ;
2- use axum_sessions:: {
3- async_session:: { MemoryStore , Session , SessionStore } ,
4- SessionLayer ,
5- } ;
6- use base64:: Engine ;
1+ use std:: { error:: Error , fmt:: Display } ;
2+
3+ use axum:: { async_trait, BoxError } ;
74use chrono:: Utc ;
85use futures:: future:: BoxFuture ;
96use http:: Request ;
107use intercode_entities:: sessions;
11- use sea_orm:: { sea_query:: OnConflict , ColumnTrait , EntityTrait , QueryFilter } ;
8+ use sea_orm:: { sea_query:: OnConflict , ColumnTrait , DbErr , EntityTrait , QueryFilter } ;
129use seawater:: ConnectionWrapper ;
1310use tower:: { Layer , Service } ;
11+ use tower_sessions:: {
12+ session:: SessionId , MemoryStore , Session , SessionManager , SessionManagerLayer , SessionRecord ,
13+ SessionStore ,
14+ } ;
1415use tracing:: log:: error;
1516
1617#[ derive( Clone , Debug ) ]
@@ -24,36 +25,54 @@ impl DbSessionStore {
2425 }
2526}
2627
27- #[ async_trait]
28- impl SessionStore for DbSessionStore {
29- async fn load_session (
30- & self ,
31- cookie_value : String ,
32- ) -> axum_sessions:: async_session:: Result < Option < Session > > {
33- let session_id = Session :: id_from_cookie_value ( & cookie_value) ?;
34- let engine = base64:: engine:: general_purpose:: STANDARD_NO_PAD ;
28+ #[ derive( Debug ) ]
29+ pub enum DbSessionError {
30+ DbErr ( DbErr ) ,
31+ SerializationError ( serde_json:: Error ) ,
32+ }
3533
36- sessions:: Entity :: find ( )
37- . filter ( sessions:: Column :: SessionId . eq ( session_id. clone ( ) ) )
38- . one ( self . db . as_ref ( ) )
39- . await
40- . map ( |find_result| {
41- find_result
42- . and_then ( |record| record. data )
43- . and_then ( |encoded| engine. decode ( encoded) . ok ( ) )
44- . and_then ( |bytes| String :: from_utf8 ( bytes) . ok ( ) )
45- . and_then ( |data| serde_json:: from_str :: < Session > ( & data) . ok ( ) )
46- } )
47- . map_err ( |err| err. into ( ) )
34+ impl Display for DbSessionError {
35+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
36+ match self {
37+ DbSessionError :: DbErr ( err) => err. fmt ( f) ,
38+ DbSessionError :: SerializationError ( err) => err. fmt ( f) ,
39+ }
40+ }
41+ }
42+
43+ impl Error for DbSessionError {
44+ fn source ( & self ) -> Option < & ( dyn Error + ' static ) > {
45+ None
46+ }
47+
48+ fn description ( & self ) -> & str {
49+ "description() is deprecated; use Display"
50+ }
51+
52+ fn cause ( & self ) -> Option < & dyn Error > {
53+ self . source ( )
54+ }
55+ }
56+
57+ impl From < DbErr > for DbSessionError {
58+ fn from ( value : DbErr ) -> Self {
59+ Self :: DbErr ( value)
60+ }
61+ }
62+
63+ impl From < serde_json:: Error > for DbSessionError {
64+ fn from ( value : serde_json:: Error ) -> Self {
65+ Self :: SerializationError ( value)
4866 }
67+ }
68+
69+ #[ async_trait]
70+ impl SessionStore for DbSessionStore {
71+ type Error = DbSessionError ;
4972
50- async fn store_session (
51- & self ,
52- session : Session ,
53- ) -> axum_sessions:: async_session:: Result < Option < String > > {
54- let engine = base64:: engine:: general_purpose:: STANDARD_NO_PAD ;
55- let session_id = session. id ( ) . to_string ( ) ;
56- let encoded_data = engine. encode ( serde_json:: to_string ( & session) ?) ;
73+ async fn save ( & self , session_record : & SessionRecord ) -> Result < ( ) , Self :: Error > {
74+ let session_id = session_record. id ( ) . to_string ( ) ;
75+ let encoded_data = serde_json:: to_string ( & session_record) ?;
5776 let model = sessions:: ActiveModel {
5877 id : sea_orm:: ActiveValue :: NotSet ,
5978 created_at : sea_orm:: ActiveValue :: Set ( Some ( Utc :: now ( ) . naive_utc ( ) ) ) ,
@@ -69,23 +88,28 @@ impl SessionStore for DbSessionStore {
6988 )
7089 . exec ( self . db . as_ref ( ) )
7190 . await ?;
72- Ok ( session . into_cookie_value ( ) )
91+ Ok ( ( ) )
7392 }
7493
75- async fn destroy_session (
76- & self ,
77- session : axum_sessions:: async_session:: Session ,
78- ) -> axum_sessions:: async_session:: Result {
79- sessions:: Entity :: delete_many ( )
80- . filter ( sessions:: Column :: SessionId . eq ( session. id ( ) ) )
81- . exec ( self . db . as_ref ( ) )
82- . await ?;
83-
84- Ok ( ( ) )
94+ async fn load ( & self , session_id : & SessionId ) -> Result < Option < Session > , Self :: Error > {
95+ sessions:: Entity :: find ( )
96+ . filter ( sessions:: Column :: SessionId . eq ( session_id. 0 . to_string ( ) ) )
97+ . one ( self . db . as_ref ( ) )
98+ . await
99+ . map ( |find_result| {
100+ find_result
101+ . and_then ( |record| record. data )
102+ // .and_then(|encoded| engine.decode(encoded).ok())
103+ // .and_then(|bytes| String::from_utf8(bytes).ok())
104+ . and_then ( |data| serde_json:: from_str :: < SessionRecord > ( & data) . ok ( ) )
105+ . and_then ( |rec| Some ( Session :: from ( rec) ) )
106+ } )
107+ . map_err ( DbSessionError :: from)
85108 }
86109
87- async fn clear_store ( & self ) -> axum_sessions :: async_session :: Result {
110+ async fn delete ( & self , session_id : & SessionId ) -> Result < ( ) , Self :: Error > {
88111 sessions:: Entity :: delete_many ( )
112+ . filter ( sessions:: Column :: SessionId . eq ( session_id. 0 . to_string ( ) ) )
89113 . exec ( self . db . as_ref ( ) )
90114 . await ?;
91115
@@ -94,70 +118,67 @@ impl SessionStore for DbSessionStore {
94118}
95119
96120#[ derive( Clone ) ]
97- pub struct SessionWithDbStoreFromTxLayer {
98- secret : [ u8 ; 64 ] ,
99- }
121+ pub struct SessionWithDbStoreFromTxLayer ;
100122
101123impl SessionWithDbStoreFromTxLayer {
102- pub fn new ( secret : [ u8 ; 64 ] ) -> Self {
103- Self { secret }
124+ pub fn new ( ) -> Self {
125+ Self { }
104126 }
105127}
106128
107129impl < S > Layer < S > for SessionWithDbStoreFromTxLayer {
108130 type Service = SessionWithDbStoreFromTxService < S > ;
109131
110132 fn layer ( & self , inner : S ) -> Self :: Service {
111- SessionWithDbStoreFromTxService {
112- secret : self . secret ,
113- inner,
114- }
133+ SessionWithDbStoreFromTxService { inner }
115134 }
116135}
117136
118137#[ derive( Clone ) ]
119138pub struct SessionWithDbStoreFromTxService < S > {
120- secret : [ u8 ; 64 ] ,
121139 inner : S ,
122140}
123141
124142impl < S , ReqBody , ResBody > Service < Request < ReqBody > > for SessionWithDbStoreFromTxService < S >
125143where
126- S : Service < Request < ReqBody > , Response = Response < ResBody > > + Clone + Send + ' static ,
144+ S : Service < Request < ReqBody > , Response = http :: Response < ResBody > > + Clone + Send + ' static ,
127145 ResBody : Send + ' static ,
128146 ReqBody : Send + ' static ,
129147 S :: Future : Send + ' static ,
148+ S :: Error : Error + Send + Sync ,
130149{
131- type Response = Response < ResBody > ;
132- type Error = S :: Error ;
150+ type Response = < SessionManager < S , DbSessionStore > as Service < Request < ReqBody > > > :: Response ;
151+ type Error = BoxError ;
133152 type Future = BoxFuture < ' static , Result < Self :: Response , Self :: Error > > ;
134153
135154 fn poll_ready (
136155 & mut self ,
137156 cx : & mut std:: task:: Context < ' _ > ,
138157 ) -> std:: task:: Poll < Result < ( ) , Self :: Error > > {
139- self . inner . poll_ready ( cx)
158+ self
159+ . inner
160+ . poll_ready ( cx)
161+ . map_err ( |err| Box :: new ( err) as BoxError )
140162 }
141163
142164 fn call ( & mut self , req : Request < ReqBody > ) -> Self :: Future {
143165 let inner = self . inner . clone ( ) ;
144- let secret = self . secret ;
145166 Box :: pin ( async move {
146167 let ( parts, body) = req. into_parts ( ) ;
147168 let db = parts. extensions . get :: < ConnectionWrapper > ( ) ;
148169
149170 match db {
150171 Some ( wrapper) => {
151172 let store = DbSessionStore :: new ( wrapper. clone ( ) ) ;
152- let layer = SessionLayer :: new ( store, & secret ) ;
173+ let layer = SessionManagerLayer :: new ( store) ;
153174 let mut service = layer. layer ( inner) ;
154175 let req = Request :: from_parts ( parts, body) ;
155176 service. call ( req) . await
156177 }
157178 None => {
158179 error ! ( "Couldn't get ConnectionWrapper from request extensions" ) ;
159- let store = MemoryStore :: new ( ) ;
160- let layer = SessionLayer :: new ( store, & secret ) ;
180+ let store = MemoryStore :: default ( ) ;
181+ let layer = SessionManagerLayer :: new ( store) ;
161182 let mut service = layer. layer ( inner) ;
162183 let req = Request :: from_parts ( parts, body) ;
163184 service. call ( req) . await
0 commit comments