Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions tokenserver-db-common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,18 @@ impl Clone for Box<dyn DbPool> {
}
}

#[cfg(debug_assertions)]
pub trait Db: BaseDb + TestDb {}
#[cfg(debug_assertions)]
impl<T: BaseDb + TestDb> Db for T {}

#[cfg(not(debug_assertions))]
pub trait Db: BaseDb {}
#[cfg(not(debug_assertions))]
impl<T: BaseDb> Db for T {}

#[async_trait(?Send)]
pub trait Db {
pub trait BaseDb {
/// Return the Db instance timeout duration.
fn timeout(&self) -> Option<Duration> {
None
Expand Down Expand Up @@ -279,50 +289,44 @@ pub trait Db {
created_at,
})
}
}

// Internal methods used by the db tests

#[cfg(debug_assertions)]
#[cfg(debug_assertions)]
#[async_trait(?Send)]
/// Internal methods used by the db tests
pub trait TestDb {
async fn set_user_created_at(
&mut self,
params: params::SetUserCreatedAt,
) -> DbResult<results::SetUserCreatedAt>;

/// Update users replaced_at attribute based on user uid.
#[cfg(debug_assertions)]
async fn set_user_replaced_at(
&mut self,
params: params::SetUserReplacedAt,
) -> DbResult<results::SetUserReplacedAt>;

/// Get full user object based on passed user ID.
#[cfg(debug_assertions)]
async fn get_user(&mut self, params: params::GetUser) -> DbResult<results::GetUser>;

/// Create a complete node and return insert id from node.
#[cfg(debug_assertions)]
async fn post_node(&mut self, params: params::PostNode) -> DbResult<results::PostNode>;

/// Get complete node entry based on passed id.
#[cfg(debug_assertions)]
async fn get_node(&mut self, params: params::GetNode) -> DbResult<results::GetNode>;

/// Based on Node ID, unassign node from `users`.
#[cfg(debug_assertions)]
async fn unassign_node(
&mut self,
params: params::UnassignNode,
) -> DbResult<results::UnassignNode>;

/// Remove Node based on Node ID
#[cfg(debug_assertions)]
async fn remove_node(&mut self, params: params::RemoveNode) -> DbResult<results::RemoveNode>;

#[cfg(debug_assertions)]
/// Creates new service and returns new service_id.
async fn post_service(&mut self, params: params::PostService)
-> DbResult<results::PostService>;

#[cfg(debug_assertions)]
fn set_spanner_node_id(&mut self, params: params::SpannerNodeId);
}
10 changes: 8 additions & 2 deletions tokenserver-db/src/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use std::sync::{Arc, LazyLock, Mutex};
use async_trait::async_trait;
use syncserver_common::Metrics;
use syncserver_db_common::GetPoolStatus;
use tokenserver_db_common::{Db, DbError, DbPool, params, results};
#[cfg(debug_assertions)]
use tokenserver_db_common::TestDb;
use tokenserver_db_common::{BaseDb, Db, DbError, DbPool, params, results};

#[derive(Clone, Default)]
pub struct CallLog {
Expand Down Expand Up @@ -70,7 +72,7 @@ impl MockDb {
}

#[async_trait(?Send)]
impl Db for MockDb {
impl BaseDb for MockDb {
async fn replace_user(
&mut self,
_params: params::ReplaceUser,
Expand Down Expand Up @@ -164,7 +166,11 @@ impl Db for MockDb {
static METRICS: LazyLock<Metrics> = LazyLock::new(Metrics::noop);
&METRICS
}
}

#[cfg(debug_assertions)]
#[async_trait(?Send)]
impl TestDb for MockDb {
#[cfg(debug_assertions)]
async fn set_user_created_at(
&mut self,
Expand Down
63 changes: 30 additions & 33 deletions tokenserver-mysql/src/db/db_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ use diesel::{
use diesel_async::RunQueryDsl;
use http::StatusCode;
use syncserver_common::Metrics;
use tokenserver_db_common::{Db, DbError, DbResult, params, results};
#[cfg(debug_assertions)]
use tokenserver_db_common::TestDb;
use tokenserver_db_common::{BaseDb, DbError, DbResult, params, results};

use super::TokenserverDb;

#[async_trait(?Send)]
impl Db for TokenserverDb {
impl BaseDb for TokenserverDb {
async fn get_node_id(&mut self, params: params::GetNodeId) -> DbResult<results::GetNodeId> {
const QUERY: &str = r#"
SELECT id
Expand Down Expand Up @@ -375,7 +377,32 @@ impl Db for TokenserverDb {
self.timeout
}

#[cfg(debug_assertions)]
async fn insert_sync15_node(&mut self, params: params::Sync15Node) -> DbResult<bool> {
let query = format!(
r#"
INSERT IGNORE INTO nodes (service, node, available, current_load, capacity, downed, backoff)
VALUES (
(SELECT id FROM services WHERE service = '{}'),
?, ?, 0, ?, 0, 0
)
"#,
params::Sync15Node::SERVICE_NAME
);

let affected_rows = diesel::sql_query(query)
.bind::<Text, _>(&params.node)
.bind::<Integer, _>(params.capacity)
.bind::<Integer, _>(params.capacity)
.execute(&mut self.conn)
.await?;

Ok(affected_rows == 1)
}
}

#[cfg(debug_assertions)]
#[async_trait(?Send)]
impl TestDb for TokenserverDb {
async fn set_user_created_at(
&mut self,
params: params::SetUserCreatedAt,
Expand All @@ -393,7 +420,6 @@ impl Db for TokenserverDb {
Ok(())
}

#[cfg(debug_assertions)]
async fn set_user_replaced_at(
&mut self,
params: params::SetUserReplacedAt,
Expand All @@ -411,7 +437,6 @@ impl Db for TokenserverDb {
Ok(())
}

#[cfg(debug_assertions)]
async fn get_user(&mut self, params: params::GetUser) -> DbResult<results::GetUser> {
const QUERY: &str = r#"
SELECT service, email, generation, client_state, replaced_at, nodeid, keys_changed_at
Expand All @@ -426,29 +451,6 @@ impl Db for TokenserverDb {
Ok(result)
}

async fn insert_sync15_node(&mut self, params: params::Sync15Node) -> DbResult<bool> {
let query = format!(
r#"
INSERT IGNORE INTO nodes (service, node, available, current_load, capacity, downed, backoff)
VALUES (
(SELECT id FROM services WHERE service = '{}'),
?, ?, 0, ?, 0, 0
)
"#,
params::Sync15Node::SERVICE_NAME
);

let affected_rows = diesel::sql_query(query)
.bind::<Text, _>(&params.node)
.bind::<Integer, _>(params.capacity)
.bind::<Integer, _>(params.capacity)
.execute(&mut self.conn)
.await?;

Ok(affected_rows == 1)
}

#[cfg(debug_assertions)]
async fn post_node(&mut self, params: params::PostNode) -> DbResult<results::PostNode> {
const QUERY: &str = r#"
INSERT INTO nodes (service, node, available, current_load, capacity, downed, backoff)
Expand All @@ -471,7 +473,6 @@ impl Db for TokenserverDb {
Ok(result)
}

#[cfg(debug_assertions)]
async fn get_node(&mut self, params: params::GetNode) -> DbResult<results::GetNode> {
const QUERY: &str = r#"
SELECT *
Expand All @@ -486,7 +487,6 @@ impl Db for TokenserverDb {
Ok(result)
}

#[cfg(debug_assertions)]
async fn unassign_node(
&mut self,
params: params::UnassignNode,
Expand All @@ -507,7 +507,6 @@ impl Db for TokenserverDb {
Ok(())
}

#[cfg(debug_assertions)]
async fn remove_node(&mut self, params: params::RemoveNode) -> DbResult<results::RemoveNode> {
const QUERY: &str = "DELETE FROM nodes WHERE id = ?";

Expand All @@ -518,7 +517,6 @@ impl Db for TokenserverDb {
Ok(())
}

#[cfg(debug_assertions)]
async fn post_service(
&mut self,
params: params::PostService,
Expand All @@ -540,7 +538,6 @@ impl Db for TokenserverDb {
Ok(result)
}

#[cfg(debug_assertions)]
fn set_spanner_node_id(&mut self, params: params::SpannerNodeId) {
self.spanner_node_id = params;
}
Expand Down
Loading