Skip to content

Commit e1c8c7b

Browse files
authored
example update (#1333)
1 parent d74d6a0 commit e1c8c7b

12 files changed

Lines changed: 402 additions & 424 deletions

File tree

examples/auth/Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@ edition = "2024"
66
[dependencies]
77
dotenvy = "0.15.7"
88
password-auth = "1.0.0"
9-
serde = {version="1.0.228", features=["derive"]}
9+
serde = {version="1.0.228", features=["derive"] }
1010
serde_json = "1.0.148"
1111
validator = { version = "0.20.0", features = ["derive"] }
1212
xitca-postgres = "0.3.0"
13-
xitca-web = {version="0.7.1", features=["json","rate-limit", "compress-br", "compress-de", "compress-gz"]}
13+
xitca-web = { version="0.7.1", features=["json", "rate-limit", "compress-br", "compress-de", "compress-gz"]}
14+
thiserror = "2"
1415
tokio = { version = "1", features = ["full"] }
1516
cuid2 = "0.1.4"
Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@ impl DbClient {
1313
/// Pool automatically manages connection lifecycle and reuse.
1414
pub async fn new(database_url: &str) -> Result<Self, Error> {
1515
let pool = Pool::builder(database_url).capacity(10).build()?;
16-
Ok(Self {
17-
pool: Arc::new(pool),
18-
})
16+
Ok(Self { pool: Arc::new(pool) })
1917
}
2018

2119
/// Returns a reference to the underlying connection pool.

examples/auth/src/db/mod.rs

Lines changed: 0 additions & 1 deletion
This file was deleted.

examples/auth/src/main.rs

Lines changed: 19 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,17 @@ mod routes;
33
mod utils;
44

55
use crate::{
6-
db::db::DbClient,
6+
db::DbClient,
77
routes::{login::login, register::register},
88
utils::structs::ApiResponse,
99
};
1010
use dotenvy::dotenv;
1111
use xitca_web::{
1212
App, WebContext,
1313
error::Error,
14-
handler::handler_service,
15-
http::{Response, StatusCode, WebResponse},
16-
middleware::{compress::Compress, decompress::Decompress, rate_limit::RateLimit},
14+
handler::{Responder, handler_service, json::Json},
15+
http::WebResponse,
16+
middleware::{compress::Compress, rate_limit::RateLimit},
1717
route::post,
1818
service::Service,
1919
};
@@ -25,59 +25,30 @@ pub struct AppState {
2525
}
2626

2727
/// Global error handler middleware.
28-
/// Converts all errors to appropriate HTTP responses with correct status codes.
29-
async fn error_handler<S, C, B>(
30-
service: &S,
31-
mut ctx: WebContext<'_, C, B>,
32-
) -> Result<WebResponse, Error>
28+
async fn error_handler<S, C>(service: &S, mut ctx: WebContext<'_, C>) -> Result<WebResponse, Error>
3329
where
3430
C: 'static,
35-
B: 'static,
36-
S: for<'r> Service<WebContext<'r, C, B>, Response = WebResponse, Error = Error>,
31+
S: for<'r> Service<WebContext<'r, C>, Response = WebResponse, Error = Error>,
3732
{
3833
match service.call(ctx.reborrow()).await {
3934
Ok(res) => Ok(res),
4035
Err(e) => {
41-
let error_msg = e.to_string();
36+
// print and convert error to http response
37+
eprintln!("{e}");
4238

43-
// Map error types to appropriate HTTP status codes.
44-
let status = if error_msg.contains("InvalidInput") || error_msg.contains("Validation") {
45-
StatusCode::BAD_REQUEST
46-
} else if error_msg.contains("NotFound") {
47-
StatusCode::NOT_FOUND
48-
} else if error_msg.contains("PermissionDenied")
49-
|| error_msg.contains("Invalid email or password")
50-
{
51-
StatusCode::UNAUTHORIZED
52-
} else if error_msg.contains("AlreadyExists")
53-
|| error_msg.contains("already registered")
54-
{
55-
StatusCode::CONFLICT
56-
} else {
57-
StatusCode::INTERNAL_SERVER_ERROR
58-
};
39+
let res = e.call(ctx.reborrow()).await?;
5940

60-
// If error message is already JSON, use it directly.
61-
// Otherwise, wrap in ApiResponse structure.
62-
let json_body = if error_msg.starts_with('{') && error_msg.ends_with('}') {
63-
error_msg
64-
} else {
65-
let error_response = ApiResponse::<()> {
41+
// override response body to json object.
42+
(
43+
res,
44+
Json(ApiResponse::<()> {
6645
success: false,
67-
message: error_msg,
46+
message: e.to_string(),
6847
data: None,
69-
};
70-
serde_json::to_string(&error_response).unwrap_or_else(|_| {
71-
r#"{"success":false,"message":"Internal error"}"#.to_string()
72-
})
73-
};
74-
75-
Ok(Response::builder()
76-
.status(status)
77-
.header("content-type", "application/json")
78-
.body(json_body.into())
79-
.unwrap()
80-
.into())
48+
}),
49+
)
50+
.respond(ctx)
51+
.await
8152
}
8253
}
8354
}
@@ -91,9 +62,7 @@ async fn main() -> std::io::Result<()> {
9162
let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
9263

9364
// Initialize connection pool.
94-
let db_client = DbClient::new(&database_url)
95-
.await
96-
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
65+
let db_client = DbClient::new(&database_url).await.map_err(std::io::Error::other)?;
9766

9867
println!("✓ Server running on http://localhost:8080");
9968

@@ -107,7 +76,6 @@ async fn main() -> std::io::Result<()> {
10776
.enclosed_fn(error_handler) // Global error handling
10877
.enclosed(RateLimit::per_minute(60)) // Rate limiting: 60 requests/minute
10978
.enclosed(Compress) // Response compression
110-
.enclosed(Decompress) // Request decompression
11179
.serve()
11280
.bind("localhost:8080")?
11381
.run()

examples/auth/src/routes/login.rs

Lines changed: 24 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use crate::{
22
AppState,
3-
utils::{error::DbError, structs::ApiResponse, validator::flatten_errors},
3+
utils::{
4+
error::{AuthError, DbError, ValidationError},
5+
structs::ApiResponse,
6+
},
47
};
58
use password_auth::verify_password;
69
use serde::{Deserialize, Serialize};
@@ -35,78 +38,37 @@ pub async fn login(
3538
Json(req): Json<Login>,
3639
) -> Result<Json<ApiResponse<UserData>>, Error> {
3740
// Validate input structure (email format, password length).
38-
// Using inspect_err for logging while preserving error handling.
39-
if let Err(e) = req
40-
.validate()
41-
.inspect_err(|e| eprintln!("Validation error: {:?}", e))
42-
{
43-
let error_body = flatten_errors(e);
44-
let response = ApiResponse {
45-
success: false,
46-
message: format!("Validation Failed: {:?}", error_body),
47-
data: None::<UserData>,
48-
};
49-
return Err(Error::from(std::io::Error::new(
50-
std::io::ErrorKind::InvalidInput,
51-
serde_json::to_string(&response).unwrap(),
52-
)));
53-
}
41+
req.validate().map_err(ValidationError)?;
5442

5543
let Login { email, password } = req;
5644

57-
// Get connection from pool - follows TechEmpower benchmark pattern.
58-
let conn = state.db_client.pool().get().await.map_err(DbError)?;
59-
6045
// Prepare statement on the connection (prevents SQL injection).
6146
let mut rows = Statement::named(
6247
"SELECT id, name, email, password FROM users WHERE email = $1",
6348
&[Type::TEXT],
6449
)
6550
.bind([&email])
66-
.query(&conn)
51+
.query(state.db_client.pool())
6752
.await
6853
.map_err(DbError)?;
6954

70-
if let Some(row) = rows.try_next().await.map_err(DbError)? {
71-
let user_id: String = row.get(0);
72-
let user_name: String = row.get(1);
73-
let user_email: String = row.get(2);
74-
let password_hash: String = row.get(3);
55+
let row = rows.try_next().await.map_err(DbError)?.ok_or(AuthError::NotFound)?;
56+
57+
let user_id: String = row.get(0);
58+
let user_name: String = row.get(1);
59+
let user_email: String = row.get(2);
60+
let password_hash: String = row.get(3);
61+
62+
// Verify password against stored hash using password_auth crate.
63+
verify_password(password, &password_hash).map_err(|_| AuthError::PermissionDenied)?;
7564

76-
// Verify password against stored hash using password_auth crate.
77-
match verify_password(password, &password_hash) {
78-
Ok(_) => Ok(Json(ApiResponse {
79-
success: true,
80-
message: "Login successful".to_string(),
81-
data: Some(UserData {
82-
id: user_id,
83-
name: user_name,
84-
email: user_email,
85-
}),
86-
})),
87-
Err(_) => {
88-
// Return generic error message to prevent user enumeration.
89-
let response = ApiResponse {
90-
success: false,
91-
message: "Invalid email or password".to_string(),
92-
data: None::<UserData>,
93-
};
94-
Err(Error::from(std::io::Error::new(
95-
std::io::ErrorKind::PermissionDenied,
96-
serde_json::to_string(&response).unwrap(),
97-
)))
98-
}
99-
}
100-
} else {
101-
// User not found - return same error message as invalid password.
102-
let response = ApiResponse {
103-
success: false,
104-
message: "Invalid email or password".to_string(),
105-
data: None::<UserData>,
106-
};
107-
Err(Error::from(std::io::Error::new(
108-
std::io::ErrorKind::NotFound,
109-
serde_json::to_string(&response).unwrap(),
110-
)))
111-
}
65+
Ok(Json(ApiResponse {
66+
success: true,
67+
message: "Login successful".to_string(),
68+
data: Some(UserData {
69+
id: user_id,
70+
name: user_name,
71+
email: user_email,
72+
}),
73+
}))
11274
}
Lines changed: 17 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use crate::{
22
AppState,
3-
utils::{error::DbError, structs::ApiResponse, validator::flatten_errors},
3+
utils::{
4+
error::{DbError, ValidationError},
5+
structs::ApiResponse,
6+
},
47
};
58
use cuid2;
69
use password_auth::generate_hash;
@@ -36,86 +39,33 @@ pub async fn register(
3639
Json(req): Json<RegisterRequest>,
3740
) -> Result<Json<ApiResponse<RegisterResponse>>, Error> {
3841
// Validate input structure.
39-
// Using inspect_err for logging while preserving error handling.
40-
if let Err(e) = req
41-
.validate()
42-
.inspect_err(|e| eprintln!("Validation error: {:?}", e))
43-
{
44-
let error_body = flatten_errors(e);
45-
let response = ApiResponse {
46-
success: false,
47-
message: format!("Validation Failed: {:?}", error_body),
48-
data: None::<RegisterResponse>,
49-
};
50-
return Err(Error::from(std::io::Error::new(
51-
std::io::ErrorKind::InvalidInput,
52-
serde_json::to_string(&response).unwrap(),
53-
)));
54-
}
42+
req.validate().map_err(ValidationError)?;
5543

56-
let RegisterRequest {
57-
name,
58-
email,
59-
password,
60-
} = req;
44+
let RegisterRequest { name, email, password } = req;
6145

6246
// Generate collision-resistant ID using CUID2.
6347
let user_id = cuid2::create_id();
6448

6549
// Hash password using password_auth (Argon2 by default).
6650
let hashed_password = generate_hash(password);
6751

68-
// Get connection from pool.
69-
let conn = state.db_client.pool().get().await.map_err(DbError)?;
70-
7152
// Prepare INSERT statement.
7253
// Bind parameters and execute query.
73-
let res = Statement::named(
54+
let mut rows = Statement::named(
7455
"INSERT INTO users (id, name, email, password) VALUES ($1, $2, $3, $4) RETURNING id",
7556
&[Type::TEXT, Type::TEXT, Type::TEXT, Type::TEXT],
7657
)
7758
.bind([&user_id, &name, &email, &hashed_password])
78-
.query(&conn)
79-
.await;
80-
81-
match res {
82-
Ok(mut rows) => {
83-
let mut registered_id = String::new();
84-
if let Some(row) = rows.try_next().await.map_err(DbError)? {
85-
registered_id = row.get(0);
86-
}
87-
88-
Ok(Json(ApiResponse {
89-
success: true,
90-
message: "Registration successful".to_string(),
91-
data: Some(RegisterResponse {
92-
user_id: registered_id,
93-
}),
94-
}))
95-
}
96-
Err(e) => {
97-
let db_error = e.to_string();
98-
99-
// Check for unique constraint violation (duplicate email).
100-
let (message, kind) = if db_error.contains("unique constraint") {
101-
(
102-
"Email is already registered",
103-
std::io::ErrorKind::AlreadyExists,
104-
)
105-
} else {
106-
("Internal server error", std::io::ErrorKind::Other)
107-
};
59+
.query(state.db_client.pool())
60+
.await
61+
.map_err(DbError)?;
10862

109-
let response = ApiResponse {
110-
success: false,
111-
message: message.to_string(),
112-
data: None::<RegisterResponse>,
113-
};
63+
let row = rows.try_next().await.map_err(DbError)?.expect("sql must return id");
64+
let user_id = row.get("id");
11465

115-
Err(Error::from(std::io::Error::new(
116-
kind,
117-
serde_json::to_string(&response).unwrap(),
118-
)))
119-
}
120-
}
66+
Ok(Json(ApiResponse {
67+
success: true,
68+
message: "Registration successful".to_string(),
69+
data: Some(RegisterResponse { user_id }),
70+
}))
12171
}

0 commit comments

Comments
 (0)