diff --git a/tokenserver-db-common/src/lib.rs b/tokenserver-db-common/src/lib.rs index 50cccc3bb4..f501f4a150 100644 --- a/tokenserver-db-common/src/lib.rs +++ b/tokenserver-db-common/src/lib.rs @@ -43,8 +43,18 @@ impl Clone for Box { } } +#[cfg(debug_assertions)] +pub trait Db: BaseDb + TestDb {} +#[cfg(debug_assertions)] +impl Db for T {} + +#[cfg(not(debug_assertions))] +pub trait Db: BaseDb {} +#[cfg(not(debug_assertions))] +impl Db for T {} + #[async_trait(?Send)] -pub trait Db { +pub trait BaseDb { /// Return the Db instance timeout duration. fn timeout(&self) -> Option { None @@ -279,50 +289,44 @@ pub trait Db { created_at, }) } +} - // Internal methods used by the db tests - - #[cfg(debug_assertions)] +#[cfg(debug_assertions)] +#[async_trait(?Send)] +/// Internal methods used by the db tests +pub trait TestDb { async fn set_user_created_at( &mut self, params: params::SetUserCreatedAt, ) -> DbResult; /// Update users replaced_at attribute based on user uid. - #[cfg(debug_assertions)] async fn set_user_replaced_at( &mut self, params: params::SetUserReplacedAt, ) -> DbResult; /// Get full user object based on passed user ID. - #[cfg(debug_assertions)] async fn get_user(&mut self, params: params::GetUser) -> DbResult; /// Create a complete node and return insert id from node. - #[cfg(debug_assertions)] async fn post_node(&mut self, params: params::PostNode) -> DbResult; /// Get complete node entry based on passed id. - #[cfg(debug_assertions)] async fn get_node(&mut self, params: params::GetNode) -> DbResult; /// Based on Node ID, unassign node from `users`. - #[cfg(debug_assertions)] async fn unassign_node( &mut self, params: params::UnassignNode, ) -> DbResult; /// Remove Node based on Node ID - #[cfg(debug_assertions)] async fn remove_node(&mut self, params: params::RemoveNode) -> DbResult; - #[cfg(debug_assertions)] /// Creates new service and returns new service_id. async fn post_service(&mut self, params: params::PostService) -> DbResult; - #[cfg(debug_assertions)] fn set_spanner_node_id(&mut self, params: params::SpannerNodeId); } diff --git a/tokenserver-db/src/mock.rs b/tokenserver-db/src/mock.rs index 09dee954af..6696b4f1f1 100644 --- a/tokenserver-db/src/mock.rs +++ b/tokenserver-db/src/mock.rs @@ -5,7 +5,9 @@ use std::sync::{Arc, LazyLock, Mutex}; use async_trait::async_trait; use syncserver_common::Metrics; use syncserver_db_common::GetPoolStatus; -use tokenserver_db_common::{Db, DbError, DbPool, params, results}; +#[cfg(debug_assertions)] +use tokenserver_db_common::TestDb; +use tokenserver_db_common::{BaseDb, Db, DbError, DbPool, params, results}; #[derive(Clone, Default)] pub struct CallLog { @@ -70,7 +72,7 @@ impl MockDb { } #[async_trait(?Send)] -impl Db for MockDb { +impl BaseDb for MockDb { async fn replace_user( &mut self, _params: params::ReplaceUser, @@ -164,7 +166,11 @@ impl Db for MockDb { static METRICS: LazyLock = LazyLock::new(Metrics::noop); &METRICS } +} +#[cfg(debug_assertions)] +#[async_trait(?Send)] +impl TestDb for MockDb { #[cfg(debug_assertions)] async fn set_user_created_at( &mut self, diff --git a/tokenserver-mysql/src/db/db_impl.rs b/tokenserver-mysql/src/db/db_impl.rs index 56e169e990..a69f8971a7 100644 --- a/tokenserver-mysql/src/db/db_impl.rs +++ b/tokenserver-mysql/src/db/db_impl.rs @@ -9,12 +9,14 @@ use diesel::{ use diesel_async::RunQueryDsl; use http::StatusCode; use syncserver_common::Metrics; -use tokenserver_db_common::{Db, DbError, DbResult, params, results}; +#[cfg(debug_assertions)] +use tokenserver_db_common::TestDb; +use tokenserver_db_common::{BaseDb, DbError, DbResult, params, results}; use super::TokenserverDb; #[async_trait(?Send)] -impl Db for TokenserverDb { +impl BaseDb for TokenserverDb { async fn get_node_id(&mut self, params: params::GetNodeId) -> DbResult { const QUERY: &str = r#" SELECT id @@ -375,7 +377,32 @@ impl Db for TokenserverDb { self.timeout } - #[cfg(debug_assertions)] + async fn insert_sync15_node(&mut self, params: params::Sync15Node) -> DbResult { + let query = format!( + r#" + INSERT IGNORE INTO nodes (service, node, available, current_load, capacity, downed, backoff) + VALUES ( + (SELECT id FROM services WHERE service = '{}'), + ?, ?, 0, ?, 0, 0 + ) + "#, + params::Sync15Node::SERVICE_NAME + ); + + let affected_rows = diesel::sql_query(query) + .bind::(¶ms.node) + .bind::(params.capacity) + .bind::(params.capacity) + .execute(&mut self.conn) + .await?; + + Ok(affected_rows == 1) + } +} + +#[cfg(debug_assertions)] +#[async_trait(?Send)] +impl TestDb for TokenserverDb { async fn set_user_created_at( &mut self, params: params::SetUserCreatedAt, @@ -393,7 +420,6 @@ impl Db for TokenserverDb { Ok(()) } - #[cfg(debug_assertions)] async fn set_user_replaced_at( &mut self, params: params::SetUserReplacedAt, @@ -411,7 +437,6 @@ impl Db for TokenserverDb { Ok(()) } - #[cfg(debug_assertions)] async fn get_user(&mut self, params: params::GetUser) -> DbResult { const QUERY: &str = r#" SELECT service, email, generation, client_state, replaced_at, nodeid, keys_changed_at @@ -426,29 +451,6 @@ impl Db for TokenserverDb { Ok(result) } - async fn insert_sync15_node(&mut self, params: params::Sync15Node) -> DbResult { - let query = format!( - r#" - INSERT IGNORE INTO nodes (service, node, available, current_load, capacity, downed, backoff) - VALUES ( - (SELECT id FROM services WHERE service = '{}'), - ?, ?, 0, ?, 0, 0 - ) - "#, - params::Sync15Node::SERVICE_NAME - ); - - let affected_rows = diesel::sql_query(query) - .bind::(¶ms.node) - .bind::(params.capacity) - .bind::(params.capacity) - .execute(&mut self.conn) - .await?; - - Ok(affected_rows == 1) - } - - #[cfg(debug_assertions)] async fn post_node(&mut self, params: params::PostNode) -> DbResult { const QUERY: &str = r#" INSERT INTO nodes (service, node, available, current_load, capacity, downed, backoff) @@ -471,7 +473,6 @@ impl Db for TokenserverDb { Ok(result) } - #[cfg(debug_assertions)] async fn get_node(&mut self, params: params::GetNode) -> DbResult { const QUERY: &str = r#" SELECT * @@ -486,7 +487,6 @@ impl Db for TokenserverDb { Ok(result) } - #[cfg(debug_assertions)] async fn unassign_node( &mut self, params: params::UnassignNode, @@ -507,7 +507,6 @@ impl Db for TokenserverDb { Ok(()) } - #[cfg(debug_assertions)] async fn remove_node(&mut self, params: params::RemoveNode) -> DbResult { const QUERY: &str = "DELETE FROM nodes WHERE id = ?"; @@ -518,7 +517,6 @@ impl Db for TokenserverDb { Ok(()) } - #[cfg(debug_assertions)] async fn post_service( &mut self, params: params::PostService, @@ -540,7 +538,6 @@ impl Db for TokenserverDb { Ok(result) } - #[cfg(debug_assertions)] fn set_spanner_node_id(&mut self, params: params::SpannerNodeId) { self.spanner_node_id = params; } diff --git a/tokenserver-postgres/src/db/db_impl.rs b/tokenserver-postgres/src/db/db_impl.rs index a7b163cc47..2434f14e67 100644 --- a/tokenserver-postgres/src/db/db_impl.rs +++ b/tokenserver-postgres/src/db/db_impl.rs @@ -13,12 +13,14 @@ use diesel_async::RunQueryDsl; use http::StatusCode; use syncserver_common::Metrics; -use tokenserver_db_common::{Db, DbError, DbResult, params, results}; +#[cfg(debug_assertions)] +use tokenserver_db_common::TestDb; +use tokenserver_db_common::{BaseDb, DbError, DbResult, params, results}; use super::TokenserverPgDb; #[async_trait(?Send)] -impl Db for TokenserverPgDb { +impl BaseDb for TokenserverPgDb { // Services Table Methods /// Acquire service_id through passed in service string. @@ -43,27 +45,6 @@ impl Db for TokenserverPgDb { } } - // Create a new service, given a provided service string and pattern. - // Returns a service_id. - #[cfg(debug_assertions)] - async fn post_service( - &mut self, - params: params::PostService, - ) -> DbResult { - const INSERT_SERVICE_QUERY: &str = r#" - INSERT INTO services (service, pattern) - VALUES ($1, $2) - RETURNING id - "#; - - let result = diesel::sql_query(INSERT_SERVICE_QUERY) - .bind::(¶ms.service) - .bind::(¶ms.pattern) - .get_result::(&mut self.conn) - .await?; - Ok(result) - } - // Nodes Table Methods /// Upsert the initial node record for Sync 1.5. @@ -89,24 +70,6 @@ impl Db for TokenserverPgDb { Ok(affected_rows == 1) } - /// Get Node with complete metadata, given a provided Node ID. - /// Returns a complete Node, including id, service_id, node string identifier - /// availability, and current load. - #[cfg(debug_assertions)] - async fn get_node(&mut self, params: params::GetNode) -> DbResult { - const QUERY: &str = r#" - SELECT * - FROM nodes - WHERE id = $1 - "#; - - let result = diesel::sql_query(QUERY) - .bind::(params.id) - .get_result::(&mut self.conn) - .await?; - Ok(result) - } - /// Get the specific Node ID, given a provided service string and node. /// Returns a node_id. async fn get_node_id(&mut self, params: params::GetNodeId) -> DbResult { @@ -221,28 +184,6 @@ impl Db for TokenserverPgDb { } } - /// Create and Insert a new node. - /// Returns the last inserted `id` of the newly created node. - #[cfg(debug_assertions)] - async fn post_node(&mut self, params: params::PostNode) -> DbResult { - const QUERY: &str = r#" - INSERT INTO nodes (service, node, available, current_load, capacity, downed, backoff) - VALUES ($1, $2, $3, $4, $5, $6, $7) - RETURNING id - "#; - let result = diesel::sql_query(QUERY) - .bind::(params.service_id) - .bind::(params.node) - .bind::(params.available) - .bind::(params.current_load) - .bind::(params.capacity) - .bind::(params.downed) - .bind::(params.backoff) - .get_result::(&mut self.conn) - .await?; - Ok(result) - } - /// Update the current load count of a node, passing in the service string and node string. /// This represents the addition of a user to a node, while not defining which user specifically. /// Does not return anything. @@ -283,37 +224,8 @@ impl Db for TokenserverPgDb { Ok(()) } - /// Remove a node given the node ID. - #[cfg(debug_assertions)] - async fn remove_node(&mut self, params: params::RemoveNode) -> DbResult { - const QUERY: &str = "DELETE FROM nodes WHERE id = $1"; - - diesel::sql_query(QUERY) - .bind::(params.node_id) - .execute(&mut self.conn) - .await?; - Ok(()) - } - // Users Table Methods - /// Given a user id, return a single user (GetUser) struct. - /// Contains all data relevant to particular user. - #[cfg(debug_assertions)] - async fn get_user(&mut self, params: params::GetUser) -> DbResult { - const QUERY: &str = r#" - SELECT service, email, generation, client_state, replaced_at, nodeid, keys_changed_at - FROM users - WHERE uid = $1 - "#; - - let result = diesel::sql_query(QUERY) - .bind::(params.id) - .get_result::(&mut self.conn) - .await?; - Ok(result) - } - /// Given a service_id and email, return all matching users (up to 20). /// Returns vector of matching `GetUser` structs, a type alias for `GetRawUsers` async fn get_users(&mut self, params: params::GetUsers) -> DbResult { @@ -498,32 +410,28 @@ impl Db for TokenserverPgDb { Ok(()) } - /// Given ONLY a particular `node_id`, update the users table to indicate an unassigned - /// node by updating the `replaced_at` field with the current time since Unix Epoch. - #[cfg(debug_assertions)] - async fn unassign_node( - &mut self, - params: params::UnassignNode, - ) -> DbResult { - const QUERY: &str = r#" - UPDATE users - SET replaced_at = $1 - WHERE nodeid = $2 - "#; - - let current_time = Utc::now().timestamp_millis(); - - diesel::sql_query(QUERY) - .bind::(current_time) - .bind::(params.node_id) + /// Simple check function to ensure database liveliness. + async fn check(&mut self) -> DbResult { + diesel::sql_query("SELECT 1") .execute(&mut self.conn) .await?; - Ok(()) + Ok(true) + } + + fn timeout(&self) -> Option { + self.timeout + } + + fn metrics(&self) -> &Metrics { + &self.metrics } +} +#[cfg(debug_assertions)] +#[async_trait(?Send)] +impl TestDb for TokenserverPgDb { /// Given ONLY a particular `uid`, update the users table `created_at` value /// with the passed parameter. - #[cfg(debug_assertions)] async fn set_user_created_at( &mut self, params: params::SetUserCreatedAt, @@ -544,7 +452,6 @@ impl Db for TokenserverPgDb { /// Given ONLY a particular `uid`, update the users table `replaced_at` value /// with the passed parameter. - #[cfg(debug_assertions)] async fn set_user_replaced_at( &mut self, params: params::SetUserReplacedAt, @@ -563,24 +470,113 @@ impl Db for TokenserverPgDb { Ok(()) } - /// Simple check function to ensure database liveliness. - async fn check(&mut self) -> DbResult { - diesel::sql_query("SELECT 1") + /// Given a user id, return a single user (GetUser) struct. + /// Contains all data relevant to particular user. + async fn get_user(&mut self, params: params::GetUser) -> DbResult { + const QUERY: &str = r#" + SELECT service, email, generation, client_state, replaced_at, nodeid, keys_changed_at + FROM users + WHERE uid = $1 + "#; + + let result = diesel::sql_query(QUERY) + .bind::(params.id) + .get_result::(&mut self.conn) + .await?; + Ok(result) + } + + /// Create and Insert a new node. + /// Returns the last inserted `id` of the newly created node. + async fn post_node(&mut self, params: params::PostNode) -> DbResult { + const QUERY: &str = r#" + INSERT INTO nodes (service, node, available, current_load, capacity, downed, backoff) + VALUES ($1, $2, $3, $4, $5, $6, $7) + RETURNING id + "#; + let result = diesel::sql_query(QUERY) + .bind::(params.service_id) + .bind::(params.node) + .bind::(params.available) + .bind::(params.current_load) + .bind::(params.capacity) + .bind::(params.downed) + .bind::(params.backoff) + .get_result::(&mut self.conn) + .await?; + Ok(result) + } + + /// Get Node with complete metadata, given a provided Node ID. + /// Returns a complete Node, including id, service_id, node string identifier + /// availability, and current load. + async fn get_node(&mut self, params: params::GetNode) -> DbResult { + const QUERY: &str = r#" + SELECT * + FROM nodes + WHERE id = $1 + "#; + + let result = diesel::sql_query(QUERY) + .bind::(params.id) + .get_result::(&mut self.conn) + .await?; + Ok(result) + } + + /// Given ONLY a particular `node_id`, update the users table to indicate an unassigned + /// node by updating the `replaced_at` field with the current time since Unix Epoch. + async fn unassign_node( + &mut self, + params: params::UnassignNode, + ) -> DbResult { + const QUERY: &str = r#" + UPDATE users + SET replaced_at = $1 + WHERE nodeid = $2 + "#; + + let current_time = Utc::now().timestamp_millis(); + + diesel::sql_query(QUERY) + .bind::(current_time) + .bind::(params.node_id) .execute(&mut self.conn) .await?; - Ok(true) + Ok(()) } - fn timeout(&self) -> Option { - self.timeout + /// Remove a node given the node ID. + async fn remove_node(&mut self, params: params::RemoveNode) -> DbResult { + const QUERY: &str = "DELETE FROM nodes WHERE id = $1"; + + diesel::sql_query(QUERY) + .bind::(params.node_id) + .execute(&mut self.conn) + .await?; + Ok(()) } - fn metrics(&self) -> &Metrics { - &self.metrics + /// Create a new service, given a provided service string and pattern. + /// Returns a service_id. + async fn post_service( + &mut self, + params: params::PostService, + ) -> DbResult { + const INSERT_SERVICE_QUERY: &str = r#" + INSERT INTO services (service, pattern) + VALUES ($1, $2) + RETURNING id + "#; + + let result = diesel::sql_query(INSERT_SERVICE_QUERY) + .bind::(¶ms.service) + .bind::(¶ms.pattern) + .get_result::(&mut self.conn) + .await?; + Ok(result) } - #[allow(dead_code)] - #[cfg(debug_assertions)] fn set_spanner_node_id(&mut self, params: params::SpannerNodeId) { self.spanner_node_id = params; }