Skip to content
Merged
4 changes: 4 additions & 0 deletions aggregation_mode/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions aggregation_mode/db/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ edition = "2021"

[dependencies]
tokio = { version = "1"}
# TODO: enable tls
sqlx = { version = "0.8", features = [ "runtime-tokio", "postgres", "migrate" ] }
sqlx = { version = "0.8", features = [ "runtime-tokio", "postgres", "migrate", "chrono" ] }


[[bin]]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
-- Track when a task's status last changed; used to detect and reclaim
-- stalled `processing` tasks after a timeout. Defaults to the row's
-- creation/alteration time via now().
ALTER TABLE tasks ADD COLUMN status_updated_at TIMESTAMPTZ DEFAULT now();
6 changes: 5 additions & 1 deletion aggregation_mode/db/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use sqlx::{
prelude::FromRow,
types::{BigDecimal, Uuid},
types::{
chrono::{DateTime, Utc},
BigDecimal, Uuid,
},
Type,
};

Expand All @@ -21,6 +24,7 @@ pub struct Task {
pub program_commitment: Vec<u8>,
pub merkle_path: Option<Vec<u8>>,
pub status: TaskStatus,
pub status_updated_at: Option<DateTime<Utc>>,
}

#[derive(Debug, Clone, FromRow)]
Expand Down
44 changes: 38 additions & 6 deletions aggregation_mode/proof_aggregator/src/backend/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,18 @@ impl Db {
Ok(Self { pool })
}

pub async fn get_pending_tasks_and_mark_them_as_processing(
/// Fetches tasks that are ready to be processed and atomically updates their status.
///
/// This function selects up to `limit` tasks for the given `proving_system_id` that are
/// either:
/// - in `pending` status, or
/// - in `processing` status but whose `status_updated_at` timestamp is older than 12 hours
/// (to recover tasks that may have been abandoned or stalled).
///
/// The selected rows are locked using `FOR UPDATE SKIP LOCKED` to ensure safe concurrent
/// processing by multiple workers. All selected tasks have their status set to
/// `processing` and their `status_updated_at` updated to `now()` before being returned.
pub async fn get_tasks_to_process_and_update_their_status(
&self,
proving_system_id: i32,
limit: i64,
Expand All @@ -32,12 +43,19 @@ impl Db {
"WITH selected AS (
SELECT task_id
FROM tasks
WHERE proving_system_id = $1 AND status = 'pending'
WHERE proving_system_id = $1
AND (
status = 'pending'
OR (
status = 'processing'
AND status_updated_at <= now() - interval '12 hours'
)
)
LIMIT $2
FOR UPDATE SKIP LOCKED
)
UPDATE tasks t
SET status = 'processing'
SET status = 'processing', status_updated_at = now()
FROM selected s
WHERE t.task_id = s.task_id
RETURNING t.*;",
Expand All @@ -61,7 +79,7 @@ impl Db {

for (task_id, merkle_path) in updates {
if let Err(e) = sqlx::query(
"UPDATE tasks SET merkle_path = $1, status = 'verified', proof = NULL WHERE task_id = $2",
"UPDATE tasks SET merkle_path = $1, status = 'verified', status_updated_at = now(), proof = NULL WHERE task_id = $2",
)
.bind(merkle_path)
.bind(task_id)
Expand All @@ -83,6 +101,20 @@ impl Db {
Ok(())
}

// TODO: this should be used when rolling back processing proofs on unexpected errors
pub async fn mark_tasks_as_pending(&self) {}
/// Reverts the given tasks to `pending` status so they can be fetched and
/// processed again.
///
/// Intended for rolling back tasks after a processing failure. Only rows
/// whose current status is `processing` are updated, so a task that was
/// concurrently completed (e.g. already marked `verified`) is left untouched.
/// The `status_updated_at` timestamp is refreshed to `now()` on every
/// reverted row.
///
/// # Errors
///
/// Returns [`DbError::Query`] if the underlying `UPDATE` fails.
pub async fn mark_tasks_as_pending(&self, tasks_id: &[Uuid]) -> Result<(), DbError> {
// Skip the database round-trip entirely for an empty id list.
if tasks_id.is_empty() {
return Ok(());
}

sqlx::query(
"UPDATE tasks SET status = 'pending', status_updated_at = now()
WHERE task_id = ANY($1) AND status = 'processing'",
)
.bind(tasks_id)
.execute(&self.pool)
.await
.map_err(|e| DbError::Query(e.to_string()))?;

Ok(())
}
}
2 changes: 1 addition & 1 deletion aggregation_mode/proof_aggregator/src/backend/fetcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl ProofsFetcher {
) -> Result<(Vec<AlignedProof>, Vec<Uuid>), ProofsFetcherError> {
let tasks = self
.db
.get_pending_tasks_and_mark_them_as_processing(engine.proving_system_id() as i32, limit)
.get_tasks_to_process_and_update_their_status(engine.proving_system_id() as i32, limit)
.await
.map_err(ProofsFetcherError::Query)?;

Expand Down
34 changes: 24 additions & 10 deletions aggregation_mode/proof_aggregator/src/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,28 +119,42 @@ impl ProofAggregator {
info!("Starting proof aggregator service");

info!("About to aggregate and submit proof to be verified on chain");
let res = self.aggregate_and_submit_proofs_on_chain().await;

let (proofs, tasks_id) = match self
.fetcher
.fetch_pending_proofs(self.engine.clone(), self.config.total_proofs_limit as i64)
.await
.map_err(AggregatedProofSubmissionError::FetchingProofs)
{
Ok(res) => res,
Err(e) => {
error!("Error while aggregating and submitting proofs: {:?}", e);
return;
}
};

let res = self
.aggregate_and_submit_proofs_on_chain((proofs, &tasks_id))
.await;

match res {
Ok(()) => {
info!("Process finished successfully");
}
Err(err) => {
error!("Error while aggregating and submitting proofs: {:?}", err);
warn!("Marking tasks back to pending after failure");
if let Err(e) = self.db.mark_tasks_as_pending(&tasks_id).await {
error!("Error while marking proofs to pending again: {:?}", e);
};
}
}
}

// TODO: on failure, mark proofs as pending again
async fn aggregate_and_submit_proofs_on_chain(
&mut self,
(proofs, tasks_id): (Vec<AlignedProof>, &[Uuid]),
) -> Result<(), AggregatedProofSubmissionError> {
let (proofs, tasks_id) = self
.fetcher
.fetch_pending_proofs(self.engine.clone(), self.config.total_proofs_limit as i64)
.await
.map_err(AggregatedProofSubmissionError::FetchingProofs)?;

if proofs.is_empty() {
warn!("No proofs collected, skipping aggregation...");
return Ok(());
Expand Down Expand Up @@ -215,7 +229,7 @@ impl ProofAggregator {

info!("Storing merkle paths for each task...",);
let mut merkle_paths_for_tasks: Vec<(Uuid, Vec<u8>)> = vec![];
for (idx, task_id) in tasks_id.into_iter().enumerate() {
for (idx, task_id) in tasks_id.iter().enumerate() {
let Some(proof) = merkle_tree.get_proof_by_pos(idx) else {
warn!("Proof not found for task id {task_id}");
continue;
Expand All @@ -226,7 +240,7 @@ impl ProofAggregator {
.flat_map(|e| e.to_vec())
.collect::<Vec<_>>();

merkle_paths_for_tasks.push((task_id, proof_bytes))
merkle_paths_for_tasks.push((*task_id, proof_bytes))
}
self.db
.insert_tasks_merkle_path_and_mark_them_as_verified(merkle_paths_for_tasks)
Expand Down
Loading