diff --git a/src/spanner/src/batch_read_only_transaction.rs b/src/spanner/src/batch_read_only_transaction.rs index 788e4b22ef..5124a04326 100644 --- a/src/spanner/src/batch_read_only_transaction.rs +++ b/src/spanner/src/batch_read_only_transaction.rs @@ -21,7 +21,11 @@ use crate::read_only_transaction::{ use crate::result_set::{ResultSet, StreamOperation}; use crate::statement::Statement; use crate::timestamp_bound::TimestampBound; +use google_cloud_gax::backoff_policy::BackoffPolicyArg; +use google_cloud_gax::options::RequestOptions as GaxRequestOptions; +use google_cloud_gax::retry_policy::RetryPolicyArg; use serde::{Deserialize, Serialize}; +use std::time::Duration; /// A builder for [BatchReadOnlyTransaction]. /// @@ -172,6 +176,7 @@ impl BatchReadOnlyTransaction { Partition { inner: PartitionedOperation::Query(req), + gax_options: GaxRequestOptions::default(), } }) .collect()) @@ -229,6 +234,7 @@ impl BatchReadOnlyTransaction { Partition { inner: PartitionedOperation::Read(req), + gax_options: GaxRequestOptions::default(), } }) .collect()) @@ -241,6 +247,8 @@ impl BatchReadOnlyTransaction { #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Partition { pub(crate) inner: PartitionedOperation, + #[serde(skip)] + pub(crate) gax_options: GaxRequestOptions, } impl Partition { @@ -270,6 +278,30 @@ impl Partition { self } + /// Sets the per-attempt timeout for this partition execution. + /// + /// **Note:** This field is **not serialized**. Each host that executes a partition must set its own attempt timeout. + pub fn with_attempt_timeout(mut self, timeout: Duration) -> Self { + self.gax_options.set_attempt_timeout(timeout); + self + } + + /// Sets the retry policy for this partition execution. + /// + /// **Note:** This field is **not serialized**. Each host that executes a partition must set its own retry policy. 
+ pub fn with_retry_policy(mut self, policy: impl Into<RetryPolicyArg>) -> Self { + self.gax_options.set_retry_policy(policy); + self + } + + /// Sets the backoff policy for this partition execution. + /// + /// **Note:** This field is **not serialized**. Each host that executes a partition must set its own backoff policy. + pub fn with_backoff_policy(mut self, policy: impl Into<BackoffPolicyArg>) -> Self { + self.gax_options.set_backoff_policy(policy); + self + } + /// Executes this partition and returns a [ResultSet] that /// contains the rows that belong to this partition. /// @@ -320,18 +352,23 @@ impl Partition { /// the database that the partitions belong to. pub async fn execute(&self, client: &DatabaseClient) -> crate::Result<ResultSet> { match &self.inner { - PartitionedOperation::Query(req) => Self::execute_query(client, req).await, - PartitionedOperation::Read(req) => Self::execute_read(client, req).await, + PartitionedOperation::Query(req) => { + Self::execute_query(client, req, self.gax_options.clone()).await + } + PartitionedOperation::Read(req) => { + Self::execute_read(client, req, self.gax_options.clone()).await + } } } async fn execute_query( client: &DatabaseClient, req: &crate::model::ExecuteSqlRequest, + gax_options: GaxRequestOptions, ) -> crate::Result<ResultSet> { let stream = client .spanner - .execute_streaming_sql(req.clone(), crate::RequestOptions::default()) + .execute_streaming_sql(req.clone(), gax_options.clone()) .send() .await?; @@ -345,16 +382,18 @@ impl Partition { client.clone(), req.session.clone(), StreamOperation::Query(req.clone()), + gax_options, )) } async fn execute_read( client: &DatabaseClient, req: &crate::model::ReadRequest, + gax_options: GaxRequestOptions, ) -> crate::Result<ResultSet> { let stream = client .spanner - .streaming_read(req.clone(), crate::RequestOptions::default()) + .streaming_read(req.clone(), gax_options.clone()) .send() .await?; @@ -368,6 +407,7 @@ impl Partition { client.clone(), req.session.clone(), StreamOperation::Read(req.clone()), + gax_options, )) } } @@ 
-401,6 +441,69 @@ pub(crate) mod tests { assert_impl_all!(Partition: Send, Sync, Debug); } + #[test] + fn serialize_partition_skips_gax_options() -> anyhow::Result<()> { + use std::time::Duration; + + let req = crate::model::ExecuteSqlRequest::new() + .set_sql("SELECT 1") + .set_partition_token(b"token".to_vec()); + + let mut gax_options = GaxRequestOptions::default(); + gax_options.set_attempt_timeout(Duration::from_secs(5)); + let partition = Partition { + inner: PartitionedOperation::Query(req), + gax_options, + }; + + let serialized = serde_json::to_string(&partition)?; + let deserialized: Partition = serde_json::from_str(&serialized)?; + + // Verify that gax_options was NOT preserved (it uses default, which is None timeout) + assert_eq!(*deserialized.gax_options.attempt_timeout(), None); + + Ok(()) + } + + #[tokio::test] + async fn partition_execute_respects_options() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use std::time::Duration; + + let mut mock = create_session_mock(); + + mock.expect_execute_streaming_sql().once().returning(|req| { + let timeout = req.metadata().get("grpc-timeout"); + assert!(timeout.is_some(), "Missing grpc-timeout header"); + assert_eq!(timeout.unwrap(), "5000000u"); // 5 seconds in micros + + let (_, rx) = tokio::sync::mpsc::channel(1); + Ok(Response::from(rx)) + }); + + let (db_client, _server) = setup_db_client(mock).await; + + let req = crate::model::ExecuteSqlRequest::new() + .set_session("projects/p/instances/i/databases/d/sessions/123") + .set_transaction(crate::model::TransactionSelector { + selector: Some(Selector::Id(b"tx_id_1".to_vec().into())), + ..Default::default() + }) + .set_sql("SELECT 1") + .set_partition_token(b"token".to_vec()); + + let partition = Partition { + inner: PartitionedOperation::Query(req), + gax_options: GaxRequestOptions::default(), + }; + + let partition = partition.with_attempt_timeout(Duration::from_secs(5)); + + let _result_set = partition.execute(&db_client).await?; + + Ok(()) 
+ } + #[test] fn serialize_partition_query() -> anyhow::Result<()> { let req = crate::model::ExecuteSqlRequest::new() @@ -416,6 +519,7 @@ pub(crate) mod tests { let partition = Partition { inner: PartitionedOperation::Query(req), + gax_options: GaxRequestOptions::default(), }; let serialized = serde_json::to_string(&partition)?; @@ -448,6 +552,7 @@ pub(crate) mod tests { let partition = Partition { inner: PartitionedOperation::Read(req), + gax_options: GaxRequestOptions::default(), }; let serialized = serde_json::to_string(&partition)?; @@ -498,6 +603,7 @@ pub(crate) mod tests { let partition = Partition { inner: PartitionedOperation::Query(req), + gax_options: GaxRequestOptions::default(), }; let _result_set = partition.execute(&db_client).await?; @@ -540,6 +646,7 @@ pub(crate) mod tests { let partition = Partition { inner: PartitionedOperation::Read(req), + gax_options: GaxRequestOptions::default(), }; let _result_set = partition.execute(&db_client).await?; @@ -700,6 +807,7 @@ pub(crate) mod tests { let partition = Partition { inner: PartitionedOperation::Query(req), + gax_options: GaxRequestOptions::default(), }; let _result_set = partition.with_data_boost(true).execute(&db_client).await?; @@ -732,6 +840,7 @@ pub(crate) mod tests { let partition = Partition { inner: PartitionedOperation::Read(req), + gax_options: GaxRequestOptions::default(), }; let _result_set = partition.with_data_boost(true).execute(&db_client).await?; diff --git a/src/spanner/src/client.rs b/src/spanner/src/client.rs index 5969681590..86390b2224 100644 --- a/src/spanner/src/client.rs +++ b/src/spanner/src/client.rs @@ -250,22 +250,31 @@ mod tests { use super::*; use crate::model::CreateSessionRequest; use crate::result_set::tests::adapt; + use gaxi::grpc::tonic::MetadataMap; use gaxi::grpc::tonic::{Code as GrpcCode, Response, Status}; use google_cloud_auth::credentials::anonymous::Builder as Anonymous; + use google_cloud_gax::backoff_policy::BackoffPolicy; use 
google_cloud_gax::error::rpc::Code; + use google_cloud_gax::retry_state::RetryState; + use google_cloud_test_macros::tokio_test_no_panics; use spanner_grpc_mock::google::rpc as mock_rpc; use spanner_grpc_mock::google::spanner::v1 as mock_v1; + use spanner_grpc_mock::google::spanner::v1::CommitResponse; + use spanner_grpc_mock::google::spanner::v1::ResultSet; + use spanner_grpc_mock::google::spanner::v1::ResultSetStats; use spanner_grpc_mock::google::spanner::v1::Session; + use spanner_grpc_mock::google::spanner::v1::result_set_stats::RowCount; use spanner_grpc_mock::{MockSpanner, start}; use static_assertions::{assert_impl_all, assert_not_impl_any}; + use std::sync::Arc; + use std::sync::atomic::{AtomicU64, Ordering}; use std::time::Duration; mockall::mock! { #[derive(Debug)] BackoffPolicy {} - impl BackoffPolicy for BackoffPolicy { - fn on_failure(&self, state: &google_cloud_gax::retry_state::RetryState) -> std::time::Duration; + impl BackoffPolicy for BackoffPolicy { + fn on_failure(&self, state: &RetryState) -> Duration; } } @@ -1089,6 +1098,275 @@ mod tests { Ok(()) } + fn parse_timeout(metadata: &MetadataMap) -> u64 { + let timeout = metadata + .get("grpc-timeout") + .expect("grpc-timeout header should be present"); + let timeout_str = timeout + .to_str() + .expect("grpc-timeout should be a valid string"); + if timeout_str.ends_with('u') { + timeout_str + .trim_end_matches('u') + .parse() + .expect("valid u64") + } else if timeout_str.ends_with('m') { + timeout_str + .trim_end_matches('m') + .parse::<u64>() + .expect("valid u64") + * 1000 + } else if timeout_str.ends_with('n') { + timeout_str + .trim_end_matches('n') + .parse::<u64>() + .expect("valid u64") + / 1000 + } else { + panic!("Unknown timeout unit in {}", timeout_str); + } + } + + #[tokio_test_no_panics] + async fn transaction_timeout_respected() -> anyhow::Result<()> { + use google_cloud_gax::retry_policy::{Aip194Strict, RetryPolicyExt}; + use 
spanner_grpc_mock::google::spanner::v1::Transaction; + + // 1. Setup Mock Server + let mut mock = MockSpanner::new(); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "projects/p/instances/i/databases/d/sessions/123".to_string(), + ..Default::default() + })) + }); + + mock.expect_begin_transaction().returning(|_| { + Ok(Response::new(Transaction { + id: vec![1, 2, 3], + ..Default::default() + })) + }); + + mock.expect_commit().once().returning(|_| { + Ok(Response::new(CommitResponse { + commit_timestamp: Some(prost_types::Timestamp { + seconds: 12345, + nanos: 0, + }), + ..Default::default() + })) + }); + + // Mock execute_sql to first fail and then succeed, checking timeout header on both + let mut seq = mockall::Sequence::new(); + + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(|req| { + let timeout_val = parse_timeout(req.metadata()); + assert!( + timeout_val <= 100000, + "Expected timeout to be <= 100ms, got {}", + timeout_val + ); + Err(Status::new(GrpcCode::ResourceExhausted, "quota exceeded")) + }); + + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(|req| { + let timeout_val = parse_timeout(req.metadata()); + assert!( + timeout_val <= 100000, + "Expected timeout to be <= 100ms, got {}", + timeout_val + ); + + let res = ResultSet { + stats: Some(ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + }; + Ok(Response::new(res)) + }); + + // 2. Initialize Client + let (address, _server) = start("127.0.0.1:0", mock).await?; + let client = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + let db = client + .database_client("projects/p/instances/i/databases/d") + .build() + .await?; + + // 3. 
Setup Transaction Runner with 100ms timeout + let runner = db + .read_write_transaction() + .with_transaction_timeout(Duration::from_millis(100)) + .build() + .await?; + + // 4. Run transaction and expect success after retry + let result = runner + .run(async |tx| { + let mut mock_backoff = MockBackoffPolicy::new(); + mock_backoff + .expect_on_failure() + .times(1) + .returning(|_| Duration::from_nanos(1)); + + let retry_policy = Aip194Strict.continue_on_too_many_requests(); + + let stmt = Statement::builder("SELECT 1") + .with_retry_policy(retry_policy) + .with_backoff_policy(mock_backoff) + .build(); + tx.execute_update(stmt).await?; + Ok(()) + }) + .await; + + result.expect("Transaction should have succeeded"); + + Ok(()) + } + + #[tokio::test] + async fn transaction_timeout_ticks_down() -> anyhow::Result<()> { + use spanner_grpc_mock::google::spanner::v1::Transaction; + + let mut mock = MockSpanner::new(); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "projects/p/instances/i/databases/d/sessions/123".to_string(), + ..Default::default() + })) + }); + + let mut seq = mockall::Sequence::new(); + + mock.expect_begin_transaction() + .once() + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(Transaction { + id: vec![1], + ..Default::default() + })) + }); + + let previous_timeout = Arc::new(AtomicU64::new(0)); + let prev_clone1 = previous_timeout.clone(); + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let timeout_val = parse_timeout(req.metadata()); + assert!( + timeout_val <= 500000, + "Expected timeout to be <= 500ms, got {}", + timeout_val + ); + prev_clone1.store(timeout_val, Ordering::SeqCst); + Err(Status::new(GrpcCode::Aborted, "Aborted")) + }); + + // Second attempt: Checks that timeout is <= previous + mock.expect_begin_transaction() + .once() + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(Transaction { + id: vec![2], + ..Default::default() + })) + 
}); + + let prev_clone2 = previous_timeout.clone(); + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let timeout_val = parse_timeout(req.metadata()); + let prev = prev_clone2.load(Ordering::SeqCst); + assert!( + timeout_val <= prev, + "Timeout should tick down between attempts or be equal, got {} and {}", + timeout_val, + prev + ); + prev_clone2.store(timeout_val, Ordering::SeqCst); // store for next check + + let res = ResultSet { + stats: Some(ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + }; + Ok(Response::new(res)) + }); + + let prev_clone3 = previous_timeout.clone(); + mock.expect_commit().once().returning(move |req| { + let timeout_val = parse_timeout(req.metadata()); + let prev = prev_clone3.load(Ordering::SeqCst); + assert!( + timeout_val < prev, + "Timeout should be smaller for commit, got {} and {}", + timeout_val, + prev + ); + + Ok(Response::new(CommitResponse { + commit_timestamp: Some(prost_types::Timestamp { + seconds: 12345, + nanos: 0, + }), + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + let client = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + let db = client + .database_client("projects/p/instances/i/databases/d") + .build() + .await?; + + let runner = db + .read_write_transaction() + .with_transaction_timeout(Duration::from_millis(500)) + .build() + .await?; + + let result = runner + .run(async |tx| { + let stmt = Statement::builder("SELECT 1").build(); + tx.execute_update(stmt).await?; + Ok(()) + }) + .await; + + result.expect("Transaction should have succeeded"); + + Ok(()) + } + #[test] fn test_parse_emulator_endpoint() { assert_eq!( diff --git a/src/spanner/src/read_only_transaction.rs b/src/spanner/src/read_only_transaction.rs index a7f3a1356d..04e877e90f 100644 --- a/src/spanner/src/read_only_transaction.rs 
+++ b/src/spanner/src/read_only_transaction.rs @@ -610,6 +610,7 @@ macro_rules! execute_stream_with_retry { $self.client.clone(), $self.session_name.clone(), $operation_variant($request), + $gax_options, )) }}; } diff --git a/src/spanner/src/read_write_transaction.rs b/src/spanner/src/read_write_transaction.rs index 3e7bc9105d..4eb6b73e77 100644 --- a/src/spanner/src/read_write_transaction.rs +++ b/src/spanner/src/read_write_transaction.rs @@ -33,8 +33,15 @@ use crate::precommit::PrecommitTokenTracker; use crate::read_only_transaction::ReadContext; use crate::result_set::ResultSet; use crate::statement::Statement; +use google_cloud_gax::error::Error as GaxError; +use google_cloud_gax::options::RequestOptions as GaxRequestOptions; +use google_cloud_gax::retry_policy::RetryPolicy; +use google_cloud_gax::retry_result::RetryResult; +use google_cloud_gax::retry_state::RetryState; use std::sync::Arc; use std::sync::atomic::{AtomicI64, Ordering}; +use std::time::Duration as StdDuration; +use tokio::time::Instant; use wkt::Duration; /// A builder for [ReadWriteTransaction]. 
@@ -99,7 +106,10 @@ impl ReadWriteTransactionBuilder { self } - pub(crate) async fn begin_transaction(&self) -> crate::Result<ReadWriteTransaction> { + pub(crate) async fn begin_transaction( + &self, + deadline: Option<Instant>, + ) -> crate::Result<ReadWriteTransaction> { let session_name = self.session_name.clone(); let mut request = BeginTransactionRequest::default() .set_session(session_name.clone()) @@ -111,10 +121,16 @@ } // TODO(#4972): make request options configurable + let mut options = RequestOptions::default(); + if let Some(d) = deadline { + let remaining = d.saturating_duration_since(Instant::now()); + options.set_attempt_timeout(remaining); + } + let response = self .client .spanner - .begin_transaction(request, RequestOptions::default()) + .begin_transaction(request, options) .await?; let transaction_selector = @@ -132,6 +148,7 @@ }, seqno: Arc::new(AtomicI64::new(1)), max_commit_delay: self.max_commit_delay, + deadline, }) } } @@ -142,6 +159,7 @@ pub struct ReadWriteTransaction { pub(crate) context: ReadContext, seqno: Arc<AtomicI64>, max_commit_delay: Option<Duration>, + pub(crate) deadline: Option<Instant>, } impl ReadWriteTransaction { @@ -150,7 +168,14 @@ &self, statement: T, ) -> crate::Result<ResultSet> { - self.context.execute_query(statement).await + if self.deadline.is_none() { + return self.context.execute_query(statement).await; + } + let stmt = statement.into(); + let mut gax_options = stmt.gax_options().clone(); + self.apply_transaction_timeout(&mut gax_options); + let stmt = stmt.with_gax_options(gax_options); + self.context.execute_query(stmt).await } /// Reads rows from the database using key lookups and scans, as a simple key/value style alternative to execute_query. 
@@ -158,14 +183,22 @@ &self, read: T, ) -> crate::Result<ResultSet> { - self.context.execute_read(read).await + if self.deadline.is_none() { + return self.context.execute_read(read).await; + } + let mut req = read.into(); + self.apply_transaction_timeout(&mut req.gax_options); + self.context.execute_read(req).await } /// Executes an update using this transaction. pub async fn execute_update<T: Into<Statement>>(&self, statement: T) -> crate::Result<i64> { - let seqno = self.seqno.fetch_add(1, Ordering::SeqCst); let statement = statement.into(); - let gax_options = statement.gax_options().clone(); + let mut gax_options = statement.gax_options().clone(); + if self.deadline.is_some() { + self.apply_transaction_timeout(&mut gax_options); + } + let seqno = self.seqno.fetch_add(1, Ordering::SeqCst); let mut request = statement .into_request() .set_session(self.context.session_name.clone()) @@ -260,6 +293,10 @@ /// # } /// ``` pub async fn execute_batch_update(&self, batch: BatchDml) -> crate::Result<Vec<i64>> { + let mut batch = batch; + if self.deadline.is_some() { + self.apply_transaction_timeout(&mut batch.gax_options); + } + let seqno = self.seqno.fetch_add(1, Ordering::SeqCst); let statements: Vec = batch @@ -313,11 +350,15 @@ .set_or_clear_request_options(self.context.amend_request_options(None)) .set_or_clear_max_commit_delay(self.max_commit_delay); + // TODO(#4972): make request options configurable + let mut gax_options = GaxRequestOptions::default(); + self.apply_transaction_timeout(&mut gax_options); + let response = self .context .client .spanner - .commit(request, RequestOptions::default()) + .commit(request, gax_options) .await?; let response = @@ -328,10 +369,14 @@ .set_precommit_token(*new_precommit_token) .set_or_clear_request_options(self.context.amend_request_options(None)); + // TODO(#4972): make request options configurable + let mut gax_options = GaxRequestOptions::default(); + 
self.apply_transaction_timeout(&mut gax_options); + self.context .client .spanner - .commit(retry_commit_req, RequestOptions::default()) + .commit(retry_commit_req, gax_options) .await? } else { response @@ -359,6 +404,40 @@ Ok(()) } + + fn apply_transaction_timeout(&self, options: &mut GaxRequestOptions) { + if let Some(deadline) = self.deadline { + let inner_policy = options + .retry_policy() + .clone() + .unwrap_or_else(|| Arc::new(google_cloud_gax::retry_policy::Aip194Strict)); + let bounded_policy = TransactionBoundedRetryPolicy { + inner: inner_policy, + deadline, + }; + options.set_retry_policy(bounded_policy); + } + } +} + +/// A retry policy that wraps another policy and bounds the total execution time +/// by a specific transaction deadline. +/// +/// This policy delegates `on_error` to the inner policy but overrides `remaining_time` +/// to ensure that it never exceeds the time left until the transaction deadline. +#[derive(Debug)] +struct TransactionBoundedRetryPolicy { + inner: Arc<dyn RetryPolicy>, + deadline: Instant, +} + +impl RetryPolicy for TransactionBoundedRetryPolicy { + fn on_error(&self, state: &RetryState, error: GaxError) -> RetryResult { + self.inner.on_error(state, error) + } + fn remaining_time(&self, _state: &RetryState) -> Option<StdDuration> { + Some(self.deadline.saturating_duration_since(Instant::now())) + } } #[cfg(test)] @@ -460,7 +539,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); @@ -527,7 +606,7 @@ let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); let count = tx @@ -564,7 +643,7 @@ let (db_client, _server) = setup_db_client(mock).await; let tx = 
ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); @@ -609,7 +688,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); @@ -668,7 +747,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client) - .begin_transaction() + .begin_transaction(None) .await?; let batch = BatchDml::builder() @@ -713,7 +792,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client) - .begin_transaction() + .begin_transaction(None) .await?; let batch = BatchDml::builder() @@ -771,7 +850,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); @@ -821,7 +900,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); @@ -873,7 +952,7 @@ mod tests { let _tx = ReadWriteTransactionBuilder::new(db_client.clone()) .with_isolation_level(IsolationLevel::Serializable) .with_read_lock_mode(ReadLockMode::Pessimistic) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); } @@ -897,7 +976,7 @@ mod tests { let _tx = ReadWriteTransactionBuilder::new(db_client.clone()) .with_exclude_txn_from_change_streams(true) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); } @@ -957,7 +1036,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = 
ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); @@ -1031,7 +1110,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); @@ -1072,7 +1151,7 @@ mod tests { let tx = ReadWriteTransactionBuilder::new(db_client.clone()) .with_max_commit_delay(Duration::new(0, 200_000_000).unwrap()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); diff --git a/src/spanner/src/result_set.rs b/src/spanner/src/result_set.rs index 8832fb6209..4564c1c40e 100644 --- a/src/spanner/src/result_set.rs +++ b/src/spanner/src/result_set.rs @@ -23,9 +23,16 @@ use crate::row::Row; use crate::server_streaming::stream::PartialResultSetStream; use bytes::Bytes; use gaxi::prost::FromProto; -use google_cloud_gax::error::rpc::Code; +use google_cloud_gax::backoff_policy::BackoffPolicy; +use google_cloud_gax::error::Error as GaxError; +use google_cloud_gax::exponential_backoff::ExponentialBackoffBuilder; +use google_cloud_gax::options::RequestOptions as GaxRequestOptions; +use google_cloud_gax::retry_policy::{Aip194Strict, RetryPolicyExt}; +use google_cloud_gax::retry_state::RetryState; use std::collections::VecDeque; use std::mem::take; +use std::sync::Arc; +use tokio::time::sleep; #[cfg(feature = "unstable-stream")] use futures::Stream; @@ -64,6 +71,7 @@ pub struct ResultSet { max_buffered_partial_result_sets: usize, retry_count: usize, transaction_selector: Option, + gax_options: GaxRequestOptions, } /// Errors that can occur when interacting with a [`ResultSet`]. 
@@ -95,7 +103,9 @@ impl ResultSet { client: DatabaseClient, session_name: String, operation: StreamOperation, + gax_options: GaxRequestOptions, ) -> Self { + let gax_options = Self::apply_defaults(gax_options); Self { stream: Some(stream), buffered_values: Vec::new(), @@ -114,9 +124,24 @@ retry_count: 0, transaction_selector, stats: None, + gax_options, } } + fn apply_defaults(mut gax_options: GaxRequestOptions) -> GaxRequestOptions { + if gax_options.retry_policy().is_none() { + gax_options.set_retry_policy(Aip194Strict.with_attempt_limit(10)); + } + if gax_options.backoff_policy().is_none() { + gax_options.set_backoff_policy(Self::default_backoff_policy()); + } + gax_options + } + + fn default_backoff_policy() -> Arc<dyn BackoffPolicy> { + Arc::new(ExponentialBackoffBuilder::default().clamp()) + } + /// Returns the metadata of the result set. /// /// # Example @@ -290,6 +315,15 @@ // Clear the buffer and restart the stream using the last // resume_token that we have seen. self.partial_result_sets_buffer.clear(); + + // Apply backoff delay if policy is present + if let Some(policy) = self.gax_options.backoff_policy() { + let state = + RetryState::new(self.safe_to_retry).set_attempt_count(self.retry_count as u32); + let delay = policy.on_failure(&state); + sleep(delay).await; + } + self.restart_stream().await?; return Ok(()); } @@ -461,7 +495,7 @@ let stream = self .client .spanner - .execute_streaming_sql(req.clone(), crate::RequestOptions::default()) + .execute_streaming_sql(req.clone(), self.gax_options.clone()) .send() .await?; self.stream = Some(stream); @@ -474,7 +508,7 @@ let stream = self .client .spanner - .streaming_read(req.clone(), crate::RequestOptions::default()) + .streaming_read(req.clone(), self.gax_options.clone()) .send() .await?; self.stream = Some(stream); @@ -483,13 +517,17 @@ Ok(()) } - // TODO(#5185): Make the retry policy configurable. 
fn should_retry(&self, e: &crate::Error) -> bool { - if self.retry_count >= 10 { - return false; + if let Some(policy) = self.gax_options.retry_policy() { + let state = + RetryState::new(self.safe_to_retry).set_attempt_count(self.retry_count as u32); + + if let Some(status) = e.status() { + let gax_error = GaxError::service(status.clone()); + return policy.on_error(&state, gax_error).is_continue(); + } } - e.status() - .is_some_and(|status| status.code == Code::Unavailable) + false } /// Converts the [`ResultSet`] into a [`Stream`]. @@ -586,8 +624,14 @@ impl ResultSet { pub(crate) mod tests { use super::*; use crate::client::Spanner; - use gaxi::grpc::tonic::Response; + use crate::client::Statement; + use crate::key::KeySet; + use crate::read::ReadRequest; + use gaxi::grpc::tonic::{Code as GrpcCode, Response}; use google_cloud_auth::credentials::anonymous::Builder as Anonymous; + use google_cloud_gax::backoff_policy::BackoffPolicy; + use google_cloud_gax::retry_state::RetryState; + use google_cloud_test_macros::tokio_test_no_panics; use prost_types::Value; use spanner_grpc_mock::MockSpanner; use spanner_grpc_mock::google::spanner::v1 as spanner_v1; @@ -596,6 +640,15 @@ pub(crate) mod tests { PartialResultSet, ResultSetMetadata, Session, StructType, }; use spanner_grpc_mock::start; + use std::time::Duration; + + mockall::mock! 
{ + #[derive(Debug)] + BackoffPolicy {} + impl BackoffPolicy for BackoffPolicy { + fn on_failure(&self, state: &RetryState) -> Duration; + } + } pub(crate) fn string_val(s: &str) -> Value { Value { @@ -863,6 +916,27 @@ pub(crate) mod tests { Ok(()) } + #[tokio::test] + async fn test_result_set_default_policies_applied() -> anyhow::Result<()> { + let rs = run_mock_query(vec![PartialResultSet { + metadata: metadata(2), + last: true, + ..Default::default() + }]) + .await; + + assert!( + rs.gax_options.retry_policy().is_some(), + "Default retry policy should be applied" + ); + assert!( + rs.gax_options.backoff_policy().is_some(), + "Default backoff policy should be applied" + ); + + Ok(()) + } + #[tokio::test] async fn test_result_set_retry_read_stream() -> anyhow::Result<()> { use gaxi::grpc::tonic::Response; @@ -920,8 +994,14 @@ pub(crate) mod tests { let db_client = client.database_client("db").build().await?; let tx = db_client.single_use().build(); + let mut mock_backoff = MockBackoffPolicy::new(); + mock_backoff + .expect_on_failure() + .returning(|_| Duration::from_nanos(1)); + let read_req = crate::read::ReadRequest::builder("table", vec!["Id", "Value"]) .with_keys(crate::key::KeySet::all()) + .with_backoff_policy(mock_backoff) .build(); let mut rs: ResultSet = tx.execute_read(read_req).await?; @@ -936,6 +1016,99 @@ pub(crate) mod tests { Ok(()) } + #[tokio::test] + async fn test_result_set_custom_retry_policy() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use gaxi::grpc::tonic::Status; + + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::start; + + use google_cloud_gax::retry_policy::Aip194Strict; + use google_cloud_gax::retry_policy::RetryPolicyExt; + + // Extend the default retry policy to also retry on ResourceExhausted. 
+ let retry_policy = Aip194Strict.continue_on_too_many_requests(); + + let mut mock = MockSpanner::new(); + let mut seq = mockall::Sequence::new(); + + // Fail with RESOURCE_EXHAUSTED on first call + mock.expect_streaming_read() + .times(1) + .in_sequence(&mut seq) + .returning(|_request| { + let stream = adapt([ + Ok(PartialResultSet { + metadata: metadata(2), + values: vec![string_val("row1"), string_val("b")], + resume_token: b"token1".to_vec(), + ..Default::default() + }), + Err(Status::new(GrpcCode::ResourceExhausted, "Quota exceeded")), + ]); + Ok(Response::from(stream)) + }); + + // Succeed on second call + mock.expect_streaming_read() + .times(1) + .in_sequence(&mut seq) + .returning(|_request| { + let stream = adapt([Ok(PartialResultSet { + values: vec![string_val("row2"), string_val("d")], + resume_token: b"token2".to_vec(), + last: true, + ..Default::default() + })]); + Ok(Response::from(stream)) + }); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + let tx = db_client.single_use().build(); + + let mut mock_backoff = MockBackoffPolicy::new(); + mock_backoff + .expect_on_failure() + .times(1) + .returning(|_| Duration::from_nanos(1)); + + let read_req = ReadRequest::builder("table", vec!["Id", "Value"]) + .with_keys(KeySet::all()) + .with_retry_policy(retry_policy) + .with_backoff_policy(mock_backoff) + .build(); + + let mut rs: ResultSet = tx.execute_read(read_req).await?; + + let row1 = rs.next().await.expect("Stream ended unexpectedly")?; + assert_eq!(row1.raw_values()[0].0, string_val("row1")); + + // This next() call should trigger the retry because the previous stream 
ended with error! + let row2 = rs.next().await.expect("Stream ended unexpectedly")?; + assert_eq!(row2.raw_values()[0].0, string_val("row2")); + + assert!(rs.next().await.is_none()); + + Ok(()) + } + #[tokio::test] async fn test_result_set_one_row() { let mut rs = run_mock_query(vec![PartialResultSet { @@ -1364,7 +1537,15 @@ pub(crate) mod tests { let db_client = client.database_client("db").build().await?; let tx = db_client.single_use().build(); - let mut rs = tx.execute_query("SELECT 1").await?; + let mut mock_backoff = MockBackoffPolicy::new(); + mock_backoff + .expect_on_failure() + .returning(|_| Duration::from_nanos(1)); + + let stmt = Statement::builder("SELECT 1") + .with_backoff_policy(mock_backoff) + .build(); + let mut rs = tx.execute_query(stmt).await?; let row1 = rs.next().await.expect("Stream ended unexpectedly")?; assert_eq!(row1.raw_values()[0].0, string_val("row1")); @@ -1571,7 +1752,15 @@ pub(crate) mod tests { let db_client = client.database_client("db").build().await?; let tx = db_client.single_use().build(); - let mut rs = tx.execute_query("SELECT 1").await?; + let mut mock_backoff = MockBackoffPolicy::new(); + mock_backoff + .expect_on_failure() + .returning(|_| Duration::from_nanos(1)); + + let stmt = Statement::builder("SELECT 1") + .with_backoff_policy(mock_backoff) + .build(); + let mut rs = tx.execute_query(stmt).await?; let row1 = rs.next().await.expect("Expected row1")?; assert_eq!(row1.raw_values()[0].0, string_val("row1_retry")); @@ -1579,7 +1768,7 @@ pub(crate) mod tests { Ok(()) } - #[tokio::test] + #[tokio_test_no_panics] async fn test_result_set_retry_under_limit_no_resume_token() -> anyhow::Result<()> { use gaxi::grpc::tonic::Response; use gaxi::grpc::tonic::Status; @@ -1646,7 +1835,15 @@ pub(crate) mod tests { let db_client = client.database_client("db").build().await?; let tx = db_client.single_use().build(); - let mut rs = tx.execute_query("SELECT 1").await?; + let mut mock_backoff = MockBackoffPolicy::new(); + mock_backoff + 
.expect_on_failure() + .returning(|_| Duration::from_nanos(1)); + + let stmt = Statement::builder("SELECT 1") + .with_backoff_policy(mock_backoff) + .build(); + let mut rs = tx.execute_query(stmt).await?; // Set max buffer size to 3 (so 2 messages is under the limit) rs.set_max_buffered_partial_result_sets(3); @@ -1695,7 +1892,16 @@ pub(crate) mod tests { let db_client = client.database_client("db").build().await?; let tx = db_client.single_use().build(); - let mut rs = tx.execute_query("SELECT 1").await?; + let mut mock_backoff = MockBackoffPolicy::new(); + mock_backoff + .expect_on_failure() + .times(10) + .returning(|_| Duration::from_nanos(1)); + + let stmt = Statement::builder("SELECT 1") + .with_backoff_policy(mock_backoff) + .build(); + let mut rs = tx.execute_query(stmt).await?; let res = rs.next().await; assert!(res.is_some(), "Expected an error but got None"); @@ -1711,7 +1917,7 @@ pub(crate) mod tests { Ok(()) } - #[tokio::test] + #[tokio_test_no_panics] async fn result_set_inline_begin_stream_error_fallback() -> anyhow::Result<()> { use gaxi::grpc::tonic::Response; use gaxi::grpc::tonic::Status; @@ -1806,7 +2012,7 @@ pub(crate) mod tests { Ok(()) } - #[tokio::test] + #[tokio_test_no_panics] async fn result_set_retry_inline_begin_transient_error() -> anyhow::Result<()> { use gaxi::grpc::tonic::Response; use gaxi::grpc::tonic::Status; @@ -1891,7 +2097,7 @@ pub(crate) mod tests { Ok(()) } - #[tokio::test] + #[tokio_test_no_panics] async fn result_set_retry_inline_begin_id_recovered() -> anyhow::Result<()> { use gaxi::grpc::tonic::Response; use gaxi::grpc::tonic::Status; diff --git a/src/spanner/src/statement.rs b/src/spanner/src/statement.rs index fa12ad459c..90be21d7fe 100644 --- a/src/spanner/src/statement.rs +++ b/src/spanner/src/statement.rs @@ -233,6 +233,12 @@ impl Statement { &self.gax_options } + /// Returns a new `Statement` with the given `GaxRequestOptions`. 
+ pub(crate) fn with_gax_options(mut self, options: GaxRequestOptions) -> Self { + self.gax_options = options; + self + } + /// Sets the query mode to use for this statement. /// /// # Example diff --git a/src/spanner/src/transaction_runner.rs b/src/spanner/src/transaction_runner.rs index b96b0da549..820cc5ee37 100644 --- a/src/spanner/src/transaction_runner.rs +++ b/src/spanner/src/transaction_runner.rs @@ -19,6 +19,7 @@ use crate::read_write_transaction::{ReadWriteTransaction, ReadWriteTransactionBu use crate::transaction_retry_policy::{ BasicTransactionRetryPolicy, TransactionRetryPolicy, backoff_if_aborted, is_aborted, }; +use std::time::Duration as StdDuration; use wkt::Duration; /// A builder for a [TransactionRunner] for a read/write transaction. @@ -45,6 +46,7 @@ use wkt::Duration; pub struct TransactionRunnerBuilder { builder: ReadWriteTransactionBuilder, retry_policy: Box<dyn TransactionRetryPolicy>, + timeout: Option<StdDuration>, } impl TransactionRunnerBuilder { @@ -52,9 +54,35 @@ impl TransactionRunnerBuilder { Self { builder: ReadWriteTransactionBuilder::new(client), retry_policy: Box::new(BasicTransactionRetryPolicy::default()), + timeout: None, } } + /// Sets the timeout for the entire transaction. + /// + /// # Example + /// ``` + /// # use google_cloud_spanner::client::Spanner; + /// # use std::time::Duration; + /// # async fn run(client: Spanner) -> Result<(), google_cloud_spanner::Error> { + /// # let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?; + /// let runner = db_client.read_write_transaction() + /// .with_transaction_timeout(Duration::from_secs(5)) + /// .build() + /// .await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// This timeout applies to the total time spent executing the transaction, including + /// all statements and automatic retries. Each individual RPC within the transaction + /// is automatically assigned a deadline derived from the remaining time of this + /// overall timeout.
+ pub fn with_transaction_timeout(mut self, timeout: StdDuration) -> Self { + self.timeout = Some(timeout); + self + } + /// Sets the isolation level for the transaction. /// /// # Example @@ -228,6 +256,7 @@ impl TransactionRunnerBuilder { Ok(TransactionRunner { builder: self.builder, retry_policy: self.retry_policy, + timeout: self.timeout, }) } } @@ -236,6 +265,7 @@ impl TransactionRunnerBuilder { pub struct TransactionRunner { builder: ReadWriteTransactionBuilder, retry_policy: Box<dyn TransactionRetryPolicy>, + timeout: Option<StdDuration>, } impl TransactionRunner { @@ -271,13 +301,14 @@ impl TransactionRunner { let start_time = tokio::time::Instant::now(); let mut attempts: u32 = 0; let backoff = crate::transaction_retry_policy::default_retry_backoff(); + let deadline = self.timeout.map(|t| start_time + t); loop { attempts += 1; let mut current_tx_id = None; let attempt_result = async { - let transaction = self.builder.begin_transaction().await?; + let transaction = self.builder.begin_transaction(deadline).await?; current_tx_id = transaction.transaction_id().ok(); let result = match work(transaction.clone()).await {