Skip to content

Commit 218c172

Browse files
MarcosNicolau, maximopalopoli, JuArce
authored
feat(aggregation-mode): support multiple dbs (#2214)
Co-authored-by: Maximo Palopoli <96491141+maximopalopoli@users.noreply.github.com>
Co-authored-by: JuArce <52429267+JuArce@users.noreply.github.com>
1 parent dcfb204 commit 218c172

File tree

23 files changed

+536
-188
lines changed

23 files changed

+536
-188
lines changed

aggregation_mode/Cargo.lock

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

aggregation_mode/db/Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@ version = "0.1.0"
44
edition = "2021"
55

66
[dependencies]
7+
serde = { workspace = true }
78
tokio = { version = "1"}
8-
sqlx = { version = "0.8", features = [ "runtime-tokio", "postgres", "migrate", "chrono" ] }
9-
9+
sqlx = { version = "0.8", features = [ "runtime-tokio", "postgres", "migrate", "chrono", "uuid", "bigdecimal"] }
10+
tracing = { version = "0.1", features = ["log"] }
1011

1112
[[bin]]
1213
name = "migrate"

aggregation_mode/db/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1+
pub mod orchestrator;
2+
pub mod retry;
13
pub mod types;
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
use std::{future::Future, sync::Arc, time::Duration};
2+
3+
use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
4+
5+
use crate::retry::{next_backoff_delay, RetryConfig, RetryError};
6+
7+
/// A single DB node, wrapping one PostgreSQL connection pool.
///
/// NOTE(review): an earlier doc said this also carries "shared health flags
/// (used to prioritize nodes)", but only the pool is stored here — confirm
/// whether health tracking was dropped or is still planned.
#[derive(Debug)]
struct DbNode {
    // Lazily-connected pool for this node (built in `DbOrchestrator::try_new`).
    pool: Pool<Postgres>,
}
12+
13+
/// Database orchestrator for running reads/writes across multiple PostgreSQL nodes with retry/backoff.
///
/// `DbOrchestrator` holds a list of database nodes (connection pools) and will
/// retry transient failures with exponential backoff based on `retry_config`.
///
/// ## Thread-safe `Clone`
/// This type is cheap and thread-safe to clone:
/// - `nodes` is `Vec<Arc<DbNode>>`, so cloning only increments `Arc` ref-counts and shares the same pools/nodes,
/// - `sqlx::Pool<Postgres>` is internally reference-counted and designed to be cloned and used concurrently.
///
/// Clones share the same underlying pools, so every clone queries the same set
/// of nodes in the same declaration order.
///
/// NOTE(review): previous docs claimed `AtomicBool` health flags and shared
/// "preferred node" ordering, but no such state exists in `DbNode` or here —
/// nodes are simply tried in order on every call. Confirm intent.
#[derive(Debug, Clone)]
pub struct DbOrchestrator {
    // One entry per connection URL passed to `try_new`; order is significant
    // (earlier nodes are tried first).
    nodes: Vec<Arc<DbNode>>,
    // Backoff tuning used by `query` between full passes over the nodes.
    retry_config: RetryConfig,
}
31+
32+
/// Errors produced while constructing a `DbOrchestrator`.
#[derive(Debug)]
pub enum DbOrchestratorError {
    /// `try_new` was called with an empty list of connection URLs.
    InvalidNumberOfConnectionUrls,
    /// A connection URL was rejected while building its pool.
    Sqlx(sqlx::Error),
}
37+
38+
impl std::fmt::Display for DbOrchestratorError {
39+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40+
match self {
41+
Self::InvalidNumberOfConnectionUrls => {
42+
write!(f, "invalid number of connection URLs")
43+
}
44+
Self::Sqlx(e) => write!(f, "{e}"),
45+
}
46+
}
47+
}
48+
49+
impl DbOrchestrator {
50+
pub fn try_new(
51+
connection_urls: &[String],
52+
retry_config: RetryConfig,
53+
) -> Result<Self, DbOrchestratorError> {
54+
if connection_urls.is_empty() {
55+
return Err(DbOrchestratorError::InvalidNumberOfConnectionUrls);
56+
}
57+
58+
let nodes = connection_urls
59+
.iter()
60+
.map(|url| {
61+
let pool = PgPoolOptions::new().max_connections(5).connect_lazy(url)?;
62+
63+
Ok(Arc::new(DbNode { pool }))
64+
})
65+
.collect::<Result<Vec<_>, sqlx::Error>>()
66+
.map_err(DbOrchestratorError::Sqlx)?;
67+
68+
Ok(Self {
69+
nodes,
70+
retry_config,
71+
})
72+
}
73+
74+
pub async fn query<T, Q, Fut>(&self, query_fn: Q) -> Result<T, sqlx::Error>
75+
where
76+
Q: Fn(Pool<Postgres>) -> Fut,
77+
Fut: Future<Output = Result<T, sqlx::Error>>,
78+
{
79+
let mut attempts = 0;
80+
let mut delay = Duration::from_millis(self.retry_config.min_delay_millis);
81+
82+
loop {
83+
match self.execute_once(&query_fn).await {
84+
Ok(value) => return Ok(value),
85+
Err(RetryError::Permanent(err)) => return Err(err),
86+
Err(RetryError::Transient(err)) => {
87+
if attempts >= self.retry_config.max_times {
88+
return Err(err);
89+
}
90+
91+
tracing::warn!(attempt = attempts, delay_millis = delay.as_millis(), error = ?err, "retrying after backoff");
92+
tokio::time::sleep(delay).await;
93+
delay = next_backoff_delay(delay, self.retry_config.clone());
94+
attempts += 1;
95+
}
96+
}
97+
}
98+
}
99+
100+
async fn execute_once<T, Q, Fut>(&self, query_fn: &Q) -> Result<T, RetryError<sqlx::Error>>
101+
where
102+
Q: Fn(Pool<Postgres>) -> Fut,
103+
Fut: Future<Output = Result<T, sqlx::Error>>,
104+
{
105+
let mut last_error = None;
106+
107+
for (idx, node) in self.nodes.iter().enumerate() {
108+
let pool = node.pool.clone();
109+
110+
match query_fn(pool).await {
111+
Ok(res) => {
112+
return Ok(res);
113+
}
114+
Err(err) => {
115+
if Self::is_connection_error(&err) {
116+
tracing::warn!(node_index = idx, error = ?err, "database query failed");
117+
last_error = Some(err);
118+
} else {
119+
return Err(RetryError::Permanent(err));
120+
}
121+
}
122+
};
123+
}
124+
125+
Err(RetryError::Transient(
126+
last_error.expect("write_op attempted without database nodes"),
127+
))
128+
}
129+
130+
fn is_connection_error(error: &sqlx::Error) -> bool {
131+
matches!(
132+
error,
133+
sqlx::Error::Io(_)
134+
| sqlx::Error::Tls(_)
135+
| sqlx::Error::Protocol(_)
136+
| sqlx::Error::PoolTimedOut
137+
| sqlx::Error::PoolClosed
138+
| sqlx::Error::WorkerCrashed
139+
| sqlx::Error::BeginFailed
140+
| sqlx::Error::Database(_)
141+
)
142+
}
143+
}

aggregation_mode/db/src/retry.rs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
use std::time::Duration;
2+
3+
#[derive(Debug)]
4+
pub(super) enum RetryError<E> {
5+
Transient(E),
6+
Permanent(E),
7+
}
8+
9+
impl<E: std::fmt::Display> std::fmt::Display for RetryError<E> {
10+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
11+
match self {
12+
RetryError::Transient(e) => write!(f, "{e}"),
13+
RetryError::Permanent(e) => write!(f, "{e}"),
14+
}
15+
}
16+
}
17+
18+
impl<E: std::fmt::Display> std::error::Error for RetryError<E> where E: std::fmt::Debug {}
19+
20+
/// Tuning knobs for the exponential-backoff retry loop.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Initial delay before the first retry attempt (in milliseconds).
    pub min_delay_millis: u64,
    /// Exponential backoff multiplier applied to the delay after each retry.
    pub factor: f32,
    /// Maximum number of retry attempts.
    pub max_times: usize,
    /// Hard cap on the delay between retry attempts (in seconds).
    pub max_delay_seconds: u64,
}

// Exponential backoff with a hard cap.
//
// Each retry multiplies the previous delay by `retry_config.factor`,
// then clamps it to `max_delay_seconds`. This yields:
//
//   d_{n+1} = min(max, d_n * factor)  =>  d_n = min(max, d_initial * factor^n)
//
// Example starting at 500ms with factor = 2.0 (no jitter):
//   retry 0: 0.5s, retry 1: 1.0s, retry 2: 2.0s, retry 3: 4.0s, retry 4: 8.0s, ...
// until the delay reaches `max_delay_seconds`, after which it stays at that max.
// see reference: https://en.wikipedia.org/wiki/Exponential_backoff
// and here: https://docs.aws.amazon.com/prescriptive-guidance/latest/cloud-design-patterns/retry-backoff.html
pub fn next_backoff_delay(current_delay: Duration, retry_config: RetryConfig) -> Duration {
    let max: Duration = Duration::from_secs(retry_config.max_delay_seconds);
    // Defensive: factor should be >= 1.0 for backoff; clamp it to avoid shrinking/NaN.
    let factor = f64::from(retry_config.factor).max(1.0);

    let scaled_secs = current_delay.as_secs_f64() * factor;
    if !scaled_secs.is_finite() {
        return max;
    }

    // Clamp in f64 space *before* constructing the Duration: a finite but huge
    // product (large factor * large delay) would make `Duration::from_secs_f64`
    // panic on overflow, whereas the previous `scaled.min(max)` only clamped
    // after construction.
    Duration::from_secs_f64(scaled_secs.min(max.as_secs_f64()))
}

aggregation_mode/db/src/types.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use sqlx::{
77
Type,
88
};
99

10-
#[derive(Debug, Clone, Copy, PartialEq, Eq, Type)]
10+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Type, serde::Serialize)]
1111
#[sqlx(type_name = "task_status", rename_all = "lowercase")]
1212
pub enum TaskStatus {
1313
Pending,

aggregation_mode/gateway/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ serde_yaml = { workspace = true }
1010
agg_mode_sdk = { path = "../sdk"}
1111
aligned-sdk = { workspace = true }
1212
sp1-sdk = { workspace = true }
13+
db = { workspace = true }
1314
tracing = { version = "0.1", features = ["log"] }
1415
tracing-subscriber = { version = "0.3.0", features = ["env-filter"] }
1516
bincode = "1.3.3"

aggregation_mode/gateway/src/config.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize};
66
pub struct Config {
77
pub ip: String,
88
pub port: u16,
9-
pub db_connection_url: String,
9+
pub db_connection_urls: Vec<String>,
1010
pub network: String,
1111
pub max_daily_proofs_per_user: i64,
1212
}

0 commit comments

Comments
 (0)