From 6a12086f82582ef03ad2f660687bf5f1ecd1a589 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 24 Apr 2026 10:52:27 +0200 Subject: [PATCH 1/6] feat(spanner): support timeout, retry- and backoff policy per statement Adds support for setting a timeout, retry policy, and backoff policy per statement and read. This allows an application to override the defaults for these options on a per statement-basis. Note: This change adds the API for all types of statements. However, the retry and backoff policies are not respected for the streaming RPCs ExecuteStreamingSql and StreamingRead yet. This will be added in a follow-up pull request. --- src/spanner/src/batch_dml.rs | 25 ++ src/spanner/src/client.rs | 263 +++++++++++++++++++++- src/spanner/src/read.rs | 60 ++++- src/spanner/src/read_only_transaction.rs | 30 ++- src/spanner/src/read_write_transaction.rs | 7 +- src/spanner/src/statement.rs | 60 ++++- 6 files changed, 423 insertions(+), 22 deletions(-) diff --git a/src/spanner/src/batch_dml.rs b/src/spanner/src/batch_dml.rs index 33f1a6bd80..8e70f67dc0 100644 --- a/src/spanner/src/batch_dml.rs +++ b/src/spanner/src/batch_dml.rs @@ -16,14 +16,19 @@ use crate::client::Statement; use crate::error::{BatchUpdateError, internal_error}; use crate::model::result_set_stats::RowCount; use crate::model::{ExecuteBatchDmlResponse, RequestOptions}; +use google_cloud_gax::backoff_policy::BackoffPolicyArg; use google_cloud_gax::error::rpc::Code; use google_cloud_gax::error::rpc::Status as RpcStatus; +use google_cloud_gax::options::RequestOptions as GaxRequestOptions; +use google_cloud_gax::retry_policy::RetryPolicyArg; +use std::time::Duration; /// A builder for [BatchDml]. #[derive(Clone, Default, Debug)] pub struct BatchDmlBuilder { statements: Vec, request_options: Option, + gax_options: GaxRequestOptions, } impl BatchDmlBuilder { @@ -59,11 +64,30 @@ impl BatchDmlBuilder { self } + /// Sets the timeout for this batch DML request. + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.gax_options.set_attempt_timeout(timeout); + self + } + + /// Sets the retry policy for this batch DML request. + pub fn with_retry_policy(mut self, policy: impl Into) -> Self { + self.gax_options.set_retry_policy(policy); + self + } + + /// Sets the backoff policy for this batch DML request. + pub fn with_backoff_policy(mut self, policy: impl Into) -> Self { + self.gax_options.set_backoff_policy(policy); + self + } + /// Builds and returns the finalized BatchDml object. pub fn build(self) -> BatchDml { BatchDml { statements: self.statements, request_options: self.request_options, + gax_options: self.gax_options, } } } @@ -73,6 +97,7 @@ impl BatchDmlBuilder { pub struct BatchDml { pub(crate) statements: Vec, pub(crate) request_options: Option, + pub(crate) gax_options: GaxRequestOptions, } impl BatchDml { diff --git a/src/spanner/src/client.rs b/src/spanner/src/client.rs index be48d288b9..a59ad48785 100644 --- a/src/spanner/src/client.rs +++ b/src/spanner/src/client.rs @@ -250,7 +250,7 @@ mod tests { use super::*; use crate::model::CreateSessionRequest; use crate::result_set::tests::adapt; - use gaxi::grpc::tonic::{Response, Status}; + use gaxi::grpc::tonic::{Code as GrpcCode, Response, Status}; use google_cloud_auth::credentials::anonymous::Builder as Anonymous; use google_cloud_gax::error::rpc::Code; use spanner_grpc_mock::google::rpc as mock_rpc; @@ -258,6 +258,9 @@ mod tests { use spanner_grpc_mock::google::spanner::v1::Session; use spanner_grpc_mock::{MockSpanner, start}; use static_assertions::{assert_impl_all, assert_not_impl_any}; + use std::sync::Arc; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::time::Duration; #[test] fn auto_traits() { @@ -838,6 +841,264 @@ mod tests { Ok(()) } + #[tokio::test] + async fn timeout_respected() -> anyhow::Result<()> { + use crate::batch_dml::BatchDml; + use std::time::Duration; + + // 1. Setup Mock Server + let mut mock = MockSpanner::new(); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "projects/p/instances/i/databases/d/sessions/123".to_string(), + ..Default::default() + })) + }); + + mock.expect_begin_transaction().returning(|_| { + Ok(Response::new(mock_v1::Transaction { + id: vec![42], + ..Default::default() + })) + }); + + mock.expect_execute_streaming_sql().once().returning(|req| { + let metadata = req.metadata(); + let timeout = metadata.get("grpc-timeout"); + assert!( + timeout.is_some(), + "grpc-timeout header should be present for query" + ); + + let (_tx, rx) = tokio::sync::mpsc::channel(1); + Ok(Response::new(rx)) + }); + + mock.expect_streaming_read().once().returning(|req| { + let metadata = req.metadata(); + let timeout = metadata.get("grpc-timeout"); + assert!( + timeout.is_some(), + "grpc-timeout header should be present for read" + ); + + let (_tx, rx) = tokio::sync::mpsc::channel(1); + Ok(Response::new(rx)) + }); + + mock.expect_execute_sql().once().returning(|req| { + let metadata = req.metadata(); + let timeout = metadata.get("grpc-timeout"); + assert!( + timeout.is_some(), + "grpc-timeout header should be present for single DML" + ); + + Ok(Response::new(mock_v1::ResultSet { + stats: Some(mock_v1::ResultSetStats { + row_count: Some(mock_v1::result_set_stats::RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + })) + }); + + mock.expect_execute_batch_dml().once().returning(|req| { + let metadata = req.metadata(); + let timeout = metadata.get("grpc-timeout"); + assert!( + timeout.is_some(), + "grpc-timeout header should be present for batch dml" + ); + + Ok(Response::new(mock_v1::ExecuteBatchDmlResponse { + result_sets: vec![mock_v1::ResultSet { + stats: Some(mock_v1::ResultSetStats { + row_count: Some(mock_v1::result_set_stats::RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + }], + ..Default::default() + })) + }); + + mock.expect_commit().returning(|_| { + Ok(Response::new(mock_v1::CommitResponse { + commit_timestamp: Some(prost_types::Timestamp { + seconds: 1234, + nanos: 0, + }), + ..Default::default() + })) + }); + + // 2. Start mock server + let (address, _server) = start("0.0.0.0:0", mock).await?; + + // 3. Configure Client + let client = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db = client + .database_client("projects/p/instances/i/databases/d") + .build() + .await?; + let runner = db.read_write_transaction().build().await?; + + // 4. Run transaction + runner + .run(async |tx| { + // Query + let stmt = Statement::builder("SELECT 1") + .with_timeout(Duration::from_secs(10)) + .build(); + let _ = tx.execute_query(stmt).await?; + + // Read + let req = ReadRequest::builder("Table", vec!["Col"]) + .with_keys(crate::key::KeySet::all()) + .with_timeout(Duration::from_secs(5)) + .build(); + let _ = tx.execute_read(req).await?; + + // Single DML + let dml = Statement::builder("UPDATE t SET c = 1") + .with_timeout(Duration::from_secs(7)) + .build(); + let _ = tx.execute_update(dml).await?; + + // Batch DML + let batch = BatchDml::builder() + .add_statement("UPDATE t SET c = 2") + .with_timeout(Duration::from_secs(8)) + .build(); + let _ = tx.execute_batch_update(batch).await?; + + Ok(()) + }) + .await?; + + Ok(()) + } + + #[tokio::test] + async fn retry_policy_respected() -> anyhow::Result<()> { + use google_cloud_gax::backoff_policy::BackoffPolicy; + use google_cloud_gax::retry_policy::{Aip194Strict, RetryPolicyExt}; + use google_cloud_gax::retry_state::RetryState; + + // Extend the default retry policy to also retry on ResourceExhausted. + let retry_policy = Aip194Strict.continue_on_too_many_requests(); + + // Custom Backoff Policy with a counter that allows us to verify that + // it was used. + #[derive(Debug)] + struct ConstantBackoff(Duration, Arc); + + impl BackoffPolicy for ConstantBackoff { + fn on_failure(&self, _state: &RetryState) -> Duration { + self.1.fetch_add(1, Ordering::SeqCst); + self.0 + } + } + + // 1. Setup Mock Server + let mut mock = MockSpanner::new(); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "projects/p/instances/i/databases/d/sessions/123".to_string(), + ..Default::default() + })) + }); + + mock.expect_begin_transaction().returning(|_| { + Ok(Response::new(mock_v1::Transaction { + id: vec![42], + ..Default::default() + })) + }); + + // Mock ExecuteSql to first return RESOURCE_EXHAUSTED and then succeed. + let mut seq = mockall::Sequence::new(); + + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(|_| Err(Status::new(GrpcCode::ResourceExhausted, "quota exceeded"))); + + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(mock_v1::ResultSet { + stats: Some(mock_v1::ResultSetStats { + row_count: Some(mock_v1::result_set_stats::RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + })) + }); + + mock.expect_commit().returning(|_| { + Ok(Response::new(mock_v1::CommitResponse { + commit_timestamp: Some(prost_types::Timestamp { + seconds: 1234, + nanos: 0, + }), + ..Default::default() + })) + }); + + // 2. Start mock server + let (address, _server) = start("0.0.0.0:0", mock).await?; + + // 3. Configure Client + let client = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db = client + .database_client("projects/p/instances/i/databases/d") + .build() + .await?; + let runner = db.read_write_transaction().build().await?; + + // 4. Call execute_update with custom retry and backoff + let backoff_count = Arc::new(AtomicUsize::new(0)); + let stmt = Statement::builder("UPDATE t SET c = 1") + .with_retry_policy(retry_policy) + .with_backoff_policy(ConstantBackoff( + Duration::from_nanos(1), + backoff_count.clone(), + )) + .build(); + + let result = runner + .run(async |tx| { + let count = tx.execute_update(stmt.clone()).await?; + Ok(count) + }) + .await?; + + // 5. Verify success after retry and that backoff was called + assert_eq!(result, 1); + assert_eq!( + backoff_count.load(Ordering::SeqCst), + 1, + "Backoff policy should have been called once" + ); + + Ok(()) + } + #[test] fn test_parse_emulator_endpoint() { assert_eq!( diff --git a/src/spanner/src/read.rs b/src/spanner/src/read.rs index 4e086372b8..857db4d94a 100644 --- a/src/spanner/src/read.rs +++ b/src/spanner/src/read.rs @@ -14,6 +14,10 @@ use crate::key::KeySet; use crate::model::DirectedReadOptions; +use google_cloud_gax::backoff_policy::BackoffPolicyArg; +use google_cloud_gax::options::RequestOptions as GaxRequestOptions; +use google_cloud_gax::retry_policy::RetryPolicyArg; +use std::time::Duration; /// Represents an incomplete read operation that requires specifying keys. /// @@ -63,6 +67,7 @@ impl ReadRequestBuilder { limit: None, request_options: None, directed_read_options: None, + gax_options: GaxRequestOptions::default(), } } @@ -90,12 +95,13 @@ impl ReadRequestBuilder { limit: None, request_options: None, directed_read_options: None, + gax_options: GaxRequestOptions::default(), } } } /// A fully configured read request that is ready to be built. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub struct ConfiguredReadRequestBuilder { table: String, index: Option, @@ -104,6 +110,7 @@ pub struct ConfiguredReadRequestBuilder { limit: Option, request_options: Option, directed_read_options: Option, + gax_options: GaxRequestOptions, } impl ConfiguredReadRequestBuilder { @@ -163,6 +170,24 @@ impl ConfiguredReadRequestBuilder { self } + /// Sets the timeout for this read request. + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.gax_options.set_attempt_timeout(timeout); + self + } + + /// Sets the retry policy for this read request. + pub fn with_retry_policy(mut self, policy: impl Into) -> Self { + self.gax_options.set_retry_policy(policy); + self + } + + /// Sets the backoff policy for this read request. + pub fn with_backoff_policy(mut self, policy: impl Into) -> Self { + self.gax_options.set_backoff_policy(policy); + self + } + /// Builds the configured `ReadRequest`. pub fn build(self) -> ReadRequest { ReadRequest { @@ -173,6 +198,7 @@ impl ConfiguredReadRequestBuilder { limit: self.limit, request_options: self.request_options, directed_read_options: self.directed_read_options, + gax_options: self.gax_options, } } } @@ -181,7 +207,7 @@ impl ConfiguredReadRequestBuilder { /// /// Contains the table, optional index, keys, and columns. /// Allows configuring optional parameters on the read operation, such as a limit. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub struct ReadRequest { pub(crate) table: String, pub(crate) index: Option, @@ -190,6 +216,7 @@ pub struct ReadRequest { pub(crate) limit: Option, pub(crate) request_options: Option, pub(crate) directed_read_options: Option, + pub(crate) gax_options: GaxRequestOptions, } impl ReadRequest { @@ -238,9 +265,9 @@ mod tests { #[test] fn auto_traits() { - static_assertions::assert_impl_all!(ReadRequestBuilder: Send, Sync, Clone, std::fmt::Debug, PartialEq); - static_assertions::assert_impl_all!(ConfiguredReadRequestBuilder: Send, Sync, Clone, std::fmt::Debug, PartialEq); - static_assertions::assert_impl_all!(ReadRequest: Send, Sync, Clone, std::fmt::Debug, PartialEq); + static_assertions::assert_impl_all!(ReadRequestBuilder: Send, Sync, Clone, std::fmt::Debug); + static_assertions::assert_impl_all!(ConfiguredReadRequestBuilder: Send, Sync, Clone, std::fmt::Debug); + static_assertions::assert_impl_all!(ReadRequest: Send, Sync, Clone, std::fmt::Debug); } #[test] @@ -301,4 +328,27 @@ mod tests { .build(); assert_eq!(req.directed_read_options, Some(dro)); } + + #[test] + fn with_gax_options() -> anyhow::Result<()> { + use google_cloud_gax::exponential_backoff::ExponentialBackoff; + use google_cloud_gax::retry_policy::NeverRetry; + use std::time::Duration; + + let req = ReadRequest::builder("MyTable", vec!["col1"]) + .with_keys(KeySet::all()) + .with_timeout(Duration::from_secs(10)) + .with_retry_policy(NeverRetry) + .with_backoff_policy(ExponentialBackoff::default()) + .build(); + + assert_eq!( + req.gax_options.attempt_timeout(), + &Some(Duration::from_secs(10)) + ); + assert!(req.gax_options.retry_policy().is_some()); + assert!(req.gax_options.backoff_policy().is_some()); + + Ok(()) + } } diff --git a/src/spanner/src/read_only_transaction.rs b/src/spanner/src/read_only_transaction.rs index b4e0ce8c58..a7f3a1356d 100644 --- a/src/spanner/src/read_only_transaction.rs +++ b/src/spanner/src/read_only_transaction.rs @@ -579,12 +579,11 @@ impl ReadContext { /// Helper macro to execute a streaming SQL or streaming read RPC with retry logic. macro_rules! execute_stream_with_retry { - ($self:expr, $request:ident, $rpc_method:ident, $operation_variant:path) => {{ + ($self:expr, $request:ident, $gax_options:ident, $rpc_method:ident, $operation_variant:path) => {{ let stream = match $self .client .spanner - // TODO(#4972): make request options configurable - .$rpc_method($request.clone(), crate::RequestOptions::default()) + .$rpc_method($request.clone(), $gax_options.clone()) .send() .await { @@ -595,8 +594,7 @@ macro_rules! execute_stream_with_retry { $self .client .spanner - // TODO(#4972): make request options configurable - .$rpc_method($request.clone(), crate::RequestOptions::default()) + .$rpc_method($request.clone(), $gax_options.clone()) .send() .await? } else { @@ -621,28 +619,42 @@ impl ReadContext { &self, statement: T, ) -> crate::Result { + let statement = statement.into(); + let gax_options = statement.gax_options().clone(); let mut request = statement - .into() .into_request() .set_session(self.session_name.clone()) .set_transaction(self.transaction_selector.selector()); request.request_options = self.amend_request_options(request.request_options); - execute_stream_with_retry!(self, request, execute_streaming_sql, StreamOperation::Query) + execute_stream_with_retry!( + self, + request, + gax_options, + execute_streaming_sql, + StreamOperation::Query + ) } pub(crate) async fn execute_read>( &self, read: T, ) -> crate::Result { + let read = read.into(); + let gax_options = read.gax_options.clone(); let mut request = read - .into() .into_request() .set_session(self.session_name.clone()) .set_transaction(self.transaction_selector.selector()); request.request_options = self.amend_request_options(request.request_options); - execute_stream_with_retry!(self, request, streaming_read, StreamOperation::Read) + execute_stream_with_retry!( + self, + request, + gax_options, + streaming_read, + StreamOperation::Read + ) } } diff --git a/src/spanner/src/read_write_transaction.rs b/src/spanner/src/read_write_transaction.rs index fd3c7b28cc..3e7bc9105d 100644 --- a/src/spanner/src/read_write_transaction.rs +++ b/src/spanner/src/read_write_transaction.rs @@ -164,8 +164,9 @@ impl ReadWriteTransaction { /// Executes an update using this transaction. pub async fn execute_update>(&self, statement: T) -> crate::Result { let seqno = self.seqno.fetch_add(1, Ordering::SeqCst); + let statement = statement.into(); + let gax_options = statement.gax_options().clone(); let mut request = statement - .into() .into_request() .set_session(self.context.session_name.clone()) .set_transaction(self.context.transaction_selector.selector()) @@ -176,7 +177,7 @@ impl ReadWriteTransaction { .context .client .spanner - .execute_sql(request, RequestOptions::default()) + .execute_sql(request, gax_options) .await?; self.context .precommit_token_tracker @@ -280,7 +281,7 @@ impl ReadWriteTransaction { .context .client .spanner - .execute_batch_dml(request, RequestOptions::default()) + .execute_batch_dml(request, batch.gax_options) .await; match response_result { diff --git a/src/spanner/src/statement.rs b/src/spanner/src/statement.rs index f3d882b5b5..e10309c1d4 100644 --- a/src/spanner/src/statement.rs +++ b/src/spanner/src/statement.rs @@ -18,7 +18,11 @@ use crate::model::execute_sql_request::QueryOptions; use crate::to_value::ToValue; use crate::types::Type; use crate::value::Value; +use google_cloud_gax::backoff_policy::BackoffPolicyArg; +use google_cloud_gax::options::RequestOptions as GaxRequestOptions; +use google_cloud_gax::retry_policy::RetryPolicyArg; use std::collections::BTreeMap; +use std::time::Duration; /// A builder for [Statement]. /// @@ -29,7 +33,7 @@ use std::collections::BTreeMap; /// .add_param("id", &42) /// .build(); /// ``` -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub struct StatementBuilder { sql: String, params: BTreeMap, @@ -38,6 +42,7 @@ pub struct StatementBuilder { directed_read_options: Option, query_options: Option, query_mode: Option, + gax_options: GaxRequestOptions, } impl StatementBuilder { @@ -50,6 +55,7 @@ impl StatementBuilder { directed_read_options: None, query_options: None, query_mode: None, + gax_options: GaxRequestOptions::default(), } } @@ -149,6 +155,24 @@ impl StatementBuilder { self } + /// Sets the timeout for this statement. + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.gax_options.set_attempt_timeout(timeout); + self + } + + /// Sets the retry policy for this statement. + pub fn with_retry_policy(mut self, policy: impl Into) -> Self { + self.gax_options.set_retry_policy(policy); + self + } + + /// Sets the backoff policy for this statement. + pub fn with_backoff_policy(mut self, policy: impl Into) -> Self { + self.gax_options.set_backoff_policy(policy); + self + } + /// Builds and returns the finalized Statement object. pub fn build(self) -> Statement { Statement { @@ -159,6 +183,7 @@ impl StatementBuilder { directed_read_options: self.directed_read_options, query_options: self.query_options, query_mode: self.query_mode, + gax_options: self.gax_options, } } } @@ -186,7 +211,7 @@ impl StatementBuilder { /// # Ok(()) /// # } /// ``` -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub struct Statement { pub sql: String, pub(crate) params: BTreeMap, @@ -195,6 +220,7 @@ pub struct Statement { pub(crate) directed_read_options: Option, pub(crate) query_options: Option, pub(crate) query_mode: Option, + gax_options: GaxRequestOptions, } impl Statement { @@ -203,6 +229,10 @@ impl Statement { StatementBuilder::new(sql) } + pub(crate) fn gax_options(&self) -> &GaxRequestOptions { + &self.gax_options + } + /// Sets the query mode to use for this statement. /// /// # Example @@ -308,8 +338,8 @@ mod tests { #[test] fn test_auto_traits() { - static_assertions::assert_impl_all!(Statement: Clone, std::fmt::Debug, PartialEq, Send, Sync); - static_assertions::assert_impl_all!(StatementBuilder: Clone, std::fmt::Debug, PartialEq, Send, Sync); + static_assertions::assert_impl_all!(Statement: Clone, std::fmt::Debug, Send, Sync); + static_assertions::assert_impl_all!(StatementBuilder: Clone, std::fmt::Debug, Send, Sync); } #[test] @@ -476,4 +506,26 @@ mod tests { assert_eq!(req.query_mode, QueryMode::Profile); Ok(()) } + + #[test] + fn with_gax_options() -> anyhow::Result<()> { + use google_cloud_gax::exponential_backoff::ExponentialBackoff; + use google_cloud_gax::retry_policy::NeverRetry; + use std::time::Duration; + + let stmt = Statement::builder("SELECT * FROM users") + .with_timeout(Duration::from_secs(10)) + .with_retry_policy(NeverRetry) + .with_backoff_policy(ExponentialBackoff::default()) + .build(); + + assert_eq!( + stmt.gax_options.attempt_timeout(), + &Some(Duration::from_secs(10)) + ); + assert!(stmt.gax_options.retry_policy().is_some()); + assert!(stmt.gax_options.backoff_policy().is_some()); + + Ok(()) + } } From 65afc6ba27b0a83597a682c9926b03498a4cfbb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 24 Apr 2026 11:44:38 +0200 Subject: [PATCH 2/6] feat(spanner): support timeout, retry policy and backoff policy for streaming RPCs Add support for custom timeouts, retry policies, and backoff policies for streaming RPCs that return a ResultSet. This also refactors ResultSet to store its internal defaults as a RetryPolicy and BackoffPolicy, instead of hardcoded custom values. --- .../src/batch_read_only_transaction.rs | 3 + src/spanner/src/read_only_transaction.rs | 1 + src/spanner/src/result_set.rs | 234 +++++++++++++++++- 3 files changed, 225 insertions(+), 13 deletions(-) diff --git a/src/spanner/src/batch_read_only_transaction.rs b/src/spanner/src/batch_read_only_transaction.rs index 788e4b22ef..95911d4302 100644 --- a/src/spanner/src/batch_read_only_transaction.rs +++ b/src/spanner/src/batch_read_only_transaction.rs @@ -21,6 +21,7 @@ use crate::read_only_transaction::{ use crate::result_set::{ResultSet, StreamOperation}; use crate::statement::Statement; use crate::timestamp_bound::TimestampBound; +use google_cloud_gax::options::RequestOptions as GaxRequestOptions; use serde::{Deserialize, Serialize}; /// A builder for [BatchReadOnlyTransaction]. @@ -345,6 +346,7 @@ impl Partition { client.clone(), req.session.clone(), StreamOperation::Query(req.clone()), + GaxRequestOptions::default(), )) } @@ -368,6 +370,7 @@ impl Partition { client.clone(), req.session.clone(), StreamOperation::Read(req.clone()), + GaxRequestOptions::default(), )) } } diff --git a/src/spanner/src/read_only_transaction.rs b/src/spanner/src/read_only_transaction.rs index a7f3a1356d..04e877e90f 100644 --- a/src/spanner/src/read_only_transaction.rs +++ b/src/spanner/src/read_only_transaction.rs @@ -610,6 +610,7 @@ macro_rules! execute_stream_with_retry { $self.client.clone(), $self.session_name.clone(), $operation_variant($request), + $gax_options, )) }}; } diff --git a/src/spanner/src/result_set.rs b/src/spanner/src/result_set.rs index 8832fb6209..fcacf408d5 100644 --- a/src/spanner/src/result_set.rs +++ b/src/spanner/src/result_set.rs @@ -23,9 +23,16 @@ use crate::row::Row; use crate::server_streaming::stream::PartialResultSetStream; use bytes::Bytes; use gaxi::prost::FromProto; -use google_cloud_gax::error::rpc::Code; +use google_cloud_gax::backoff_policy::BackoffPolicy; +use google_cloud_gax::error::Error as GaxError; +use google_cloud_gax::exponential_backoff::ExponentialBackoffBuilder; +use google_cloud_gax::options::RequestOptions as GaxRequestOptions; +use google_cloud_gax::retry_policy::{Aip194Strict, RetryPolicyExt}; +use google_cloud_gax::retry_state::RetryState; use std::collections::VecDeque; use std::mem::take; +use std::sync::Arc; +use tokio::time::sleep; #[cfg(feature = "unstable-stream")] use futures::Stream; @@ -64,6 +71,7 @@ pub struct ResultSet { max_buffered_partial_result_sets: usize, retry_count: usize, transaction_selector: Option, + gax_options: GaxRequestOptions, } /// Errors that can occur when interacting with a [`ResultSet`]. @@ -95,7 +103,9 @@ impl ResultSet { client: DatabaseClient, session_name: String, operation: StreamOperation, + gax_options: GaxRequestOptions, ) -> Self { + let gax_options = Self::apply_defaults(gax_options); Self { stream: Some(stream), buffered_values: Vec::new(), @@ -114,9 +124,24 @@ impl ResultSet { retry_count: 0, transaction_selector, stats: None, + gax_options, } } + fn apply_defaults(mut gax_options: GaxRequestOptions) -> GaxRequestOptions { + if gax_options.retry_policy().is_none() { + gax_options.set_retry_policy(Aip194Strict.with_attempt_limit(10)); + } + if gax_options.backoff_policy().is_none() { + gax_options.set_backoff_policy(Self::default_backoff_policy()); + } + gax_options + } + + fn default_backoff_policy() -> Arc { + Arc::new(ExponentialBackoffBuilder::default().clamp()) + } + /// Returns the metadata of the result set. /// /// # Example @@ -290,6 +315,15 @@ impl ResultSet { // Clear the buffer and restart the stream using the last // resume_token that we have seen. self.partial_result_sets_buffer.clear(); + + // Apply backoff delay if policy is present + if let Some(policy) = self.gax_options.backoff_policy() { + let state = + RetryState::new(self.safe_to_retry).set_attempt_count(self.retry_count as u32); + let delay = policy.on_failure(&state); + sleep(delay).await; + } + self.restart_stream().await?; return Ok(()); } @@ -461,7 +495,7 @@ impl ResultSet { let stream = self .client .spanner - .execute_streaming_sql(req.clone(), crate::RequestOptions::default()) + .execute_streaming_sql(req.clone(), self.gax_options.clone()) .send() .await?; self.stream = Some(stream); @@ -474,7 +508,7 @@ impl ResultSet { let stream = self .client .spanner - .streaming_read(req.clone(), crate::RequestOptions::default()) + .streaming_read(req.clone(), self.gax_options.clone()) .send() .await?; self.stream = Some(stream); @@ -483,13 +517,17 @@ impl ResultSet { Ok(()) } - // TODO(#5185): Make the retry policy configurable. fn should_retry(&self, e: &crate::Error) -> bool { - if self.retry_count >= 10 { - return false; + if let Some(policy) = self.gax_options.retry_policy() { + let state = + RetryState::new(self.safe_to_retry).set_attempt_count(self.retry_count as u32); + + if let Some(status) = e.status() { + let gax_error = GaxError::service(status.clone()); + return policy.on_error(&state, gax_error).is_continue(); + } } - e.status() - .is_some_and(|status| status.code == Code::Unavailable) + false } /// Converts the [`ResultSet`] into a [`Stream`]. @@ -586,7 +624,10 @@ impl ResultSet { pub(crate) mod tests { use super::*; use crate::client::Spanner; - use gaxi::grpc::tonic::Response; + use crate::client::Statement; + use crate::key::KeySet; + use crate::read::ReadRequest; + use gaxi::grpc::tonic::{Code as GrpcCode, Response}; use google_cloud_auth::credentials::anonymous::Builder as Anonymous; use prost_types::Value; use spanner_grpc_mock::MockSpanner; @@ -596,6 +637,20 @@ pub(crate) mod tests { PartialResultSet, ResultSetMetadata, Session, StructType, }; use spanner_grpc_mock::start; + use std::sync::Arc; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::time::Duration; + + /// A backoff policy that always returns the same duration and that contains a counter that + /// can be used to verify that it was called. + #[derive(Debug)] + pub(crate) struct ConstantBackoff(pub(crate) Duration, pub(crate) Arc); + impl google_cloud_gax::backoff_policy::BackoffPolicy for ConstantBackoff { + fn on_failure(&self, _state: &RetryState) -> Duration { + self.1.fetch_add(1, Ordering::SeqCst); + self.0 + } + } pub(crate) fn string_val(s: &str) -> Value { Value { @@ -863,6 +918,27 @@ pub(crate) mod tests { Ok(()) } + #[tokio::test] + async fn test_result_set_default_policies_applied() -> anyhow::Result<()> { + let rs = run_mock_query(vec![PartialResultSet { + metadata: metadata(2), + last: true, + ..Default::default() + }]) + .await; + + assert!( + rs.gax_options.retry_policy().is_some(), + "Default retry policy should be applied" + ); + assert!( + rs.gax_options.backoff_policy().is_some(), + "Default backoff policy should be applied" + ); + + Ok(()) + } + #[tokio::test] async fn test_result_set_retry_read_stream() -> anyhow::Result<()> { use gaxi::grpc::tonic::Response; @@ -922,17 +998,118 @@ pub(crate) mod tests { let tx = db_client.single_use().build(); let read_req = crate::read::ReadRequest::builder("table", vec!["Id", "Value"]) .with_keys(crate::key::KeySet::all()) + .with_backoff_policy(ConstantBackoff( + Duration::from_nanos(1), + Arc::new(AtomicUsize::new(0)), + )) + .build(); + let mut rs: ResultSet = tx.execute_read(read_req).await?; + + let row1 = rs.next().await.expect("Stream ended unexpectedly")?; + assert_eq!(row1.raw_values()[0].0, string_val("row1")); + + let row2 = rs.next().await.expect("Stream ended unexpectedly")?; + assert_eq!(row2.raw_values()[0].0, string_val("row2")); + + assert!(rs.next().await.is_none()); + + Ok(()) + } + + #[tokio::test] + async fn test_result_set_custom_retry_policy() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use gaxi::grpc::tonic::Status; + + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::start; + + use google_cloud_gax::retry_policy::Aip194Strict; + use google_cloud_gax::retry_policy::RetryPolicyExt; + + // Extend the default retry policy to also retry on ResourceExhausted. + let retry_policy = Aip194Strict.continue_on_too_many_requests(); + + let mut mock = MockSpanner::new(); + let mut seq = mockall::Sequence::new(); + + // Fail with RESOURCE_EXHAUSTED on first call + mock.expect_streaming_read() + .times(1) + .in_sequence(&mut seq) + .returning(|_request| { + let stream = adapt([ + Ok(PartialResultSet { + metadata: metadata(2), + values: vec![string_val("row1"), string_val("b")], + resume_token: b"token1".to_vec(), + ..Default::default() + }), + Err(Status::new(GrpcCode::ResourceExhausted, "Quota exceeded")), + ]); + Ok(Response::from(stream)) + }); + + // Succeed on second call + mock.expect_streaming_read() + .times(1) + .in_sequence(&mut seq) + .returning(|_request| { + let stream = adapt([Ok(PartialResultSet { + values: vec![string_val("row2"), string_val("d")], + resume_token: b"token2".to_vec(), + last: true, + ..Default::default() + })]); + Ok(Response::from(stream)) + }); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + let tx = db_client.single_use().build(); + + let backoff_count = Arc::new(AtomicUsize::new(0)); + let read_req = ReadRequest::builder("table", vec!["Id", "Value"]) + .with_keys(KeySet::all()) + .with_retry_policy(retry_policy) + .with_backoff_policy(ConstantBackoff( + Duration::from_nanos(1), + backoff_count.clone(), + )) .build(); + let mut rs: ResultSet = tx.execute_read(read_req).await?; let row1 = rs.next().await.expect("Stream ended unexpectedly")?; assert_eq!(row1.raw_values()[0].0, string_val("row1")); + // This next() call should trigger the retry because the previous stream ended with error! let row2 = rs.next().await.expect("Stream ended unexpectedly")?; assert_eq!(row2.raw_values()[0].0, string_val("row2")); assert!(rs.next().await.is_none()); + assert_eq!( + backoff_count.load(Ordering::SeqCst), + 1, + "Backoff policy should have been called once" + ); + Ok(()) } @@ -1364,7 +1541,13 @@ pub(crate) mod tests { let db_client = client.database_client("db").build().await?; let tx = db_client.single_use().build(); - let mut rs = tx.execute_query("SELECT 1").await?; + let stmt = Statement::builder("SELECT 1") + .with_backoff_policy(ConstantBackoff( + Duration::from_nanos(1), + Arc::new(AtomicUsize::new(0)), + )) + .build(); + let mut rs = tx.execute_query(stmt).await?; let row1 = rs.next().await.expect("Stream ended unexpectedly")?; assert_eq!(row1.raw_values()[0].0, string_val("row1")); @@ -1571,7 +1754,13 @@ pub(crate) mod tests { let db_client = client.database_client("db").build().await?; let tx = db_client.single_use().build(); - let mut rs = tx.execute_query("SELECT 1").await?; + let stmt = Statement::builder("SELECT 1") + .with_backoff_policy(ConstantBackoff( + Duration::from_nanos(1), + Arc::new(AtomicUsize::new(0)), + )) + .build(); + let mut rs = tx.execute_query(stmt).await?; let row1 = rs.next().await.expect("Expected row1")?; assert_eq!(row1.raw_values()[0].0, string_val("row1_retry")); @@ -1646,7 +1835,13 @@ pub(crate) mod tests { let db_client = client.database_client("db").build().await?; let tx = db_client.single_use().build(); - let mut rs = tx.execute_query("SELECT 1").await?; + let stmt = Statement::builder("SELECT 1") + .with_backoff_policy(ConstantBackoff( + Duration::from_nanos(1), + Arc::new(AtomicUsize::new(0)), + )) + .build(); + let mut rs = tx.execute_query(stmt).await?; // Set max buffer size to 3 (so 2 messages is under the limit) rs.set_max_buffered_partial_result_sets(3); @@ -1695,7 +1890,14 @@ pub(crate) mod tests { let db_client = client.database_client("db").build().await?; let tx = db_client.single_use().build(); - let mut rs = tx.execute_query("SELECT 1").await?; + let backoff_count = Arc::new(AtomicUsize::new(0)); + let stmt = Statement::builder("SELECT 1") + .with_backoff_policy(ConstantBackoff( + Duration::from_nanos(1), + backoff_count.clone(), + )) + .build(); + let mut rs = tx.execute_query(stmt).await?; let res = rs.next().await; assert!(res.is_some(), "Expected an error but got None"); @@ -1708,6 +1910,12 @@ pub(crate) mod tests { err_str ); + assert_eq!( + backoff_count.load(Ordering::SeqCst), + 10, + "Backoff policy should have been called 10 times" + ); + Ok(()) } From bfe7ace7a85e6a72856a39712a32ef2eef43d8d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 24 Apr 2026 10:52:27 +0200 Subject: [PATCH 3/6] feat(spanner): support timeout, retry- and backoff policy per statement Adds support for setting a timeout, retry policy, and backoff policy per statement and read. This allows an application to override the defaults for these options on a per statement-basis. Note: This change adds the API for all types of statements. However, the retry and backoff policies are not respected for the streaming RPCs ExecuteStreamingSql and StreamingRead yet. This will be added in a follow-up pull request. --- src/spanner/src/batch_dml.rs | 57 +++++ src/spanner/src/client.rs | 263 +++++++++++++++++++++- src/spanner/src/read.rs | 60 ++++- src/spanner/src/read_only_transaction.rs | 30 ++- src/spanner/src/read_write_transaction.rs | 7 +- src/spanner/src/statement.rs | 60 ++++- 6 files changed, 455 insertions(+), 22 deletions(-) diff --git a/src/spanner/src/batch_dml.rs b/src/spanner/src/batch_dml.rs index 33f1a6bd80..2e57be576b 100644 --- a/src/spanner/src/batch_dml.rs +++ b/src/spanner/src/batch_dml.rs @@ -16,14 +16,19 @@ use crate::client::Statement; use crate::error::{BatchUpdateError, internal_error}; use crate::model::result_set_stats::RowCount; use crate::model::{ExecuteBatchDmlResponse, RequestOptions}; +use google_cloud_gax::backoff_policy::BackoffPolicyArg; use google_cloud_gax::error::rpc::Code; use google_cloud_gax::error::rpc::Status as RpcStatus; +use google_cloud_gax::options::RequestOptions as GaxRequestOptions; +use google_cloud_gax::retry_policy::RetryPolicyArg; +use std::time::Duration; /// A builder for [BatchDml]. #[derive(Clone, Default, Debug)] pub struct BatchDmlBuilder { statements: Vec, request_options: Option, + gax_options: GaxRequestOptions, } impl BatchDmlBuilder { @@ -59,11 +64,30 @@ impl BatchDmlBuilder { self } + /// Sets the per-attempt timeout for this batch DML request. + pub fn with_attempt_timeout(mut self, timeout: Duration) -> Self { + self.gax_options.set_attempt_timeout(timeout); + self + } + + /// Sets the retry policy for this batch DML request. + pub fn with_retry_policy(mut self, policy: impl Into) -> Self { + self.gax_options.set_retry_policy(policy); + self + } + + /// Sets the backoff policy for this batch DML request. + pub fn with_backoff_policy(mut self, policy: impl Into) -> Self { + self.gax_options.set_backoff_policy(policy); + self + } + /// Builds and returns the finalized BatchDml object. pub fn build(self) -> BatchDml { BatchDml { statements: self.statements, request_options: self.request_options, + gax_options: self.gax_options, } } } @@ -73,6 +97,7 @@ impl BatchDmlBuilder { pub struct BatchDml { pub(crate) statements: Vec, pub(crate) request_options: Option, + pub(crate) gax_options: GaxRequestOptions, } impl BatchDml { @@ -148,6 +173,38 @@ mod tests { ); } + #[test] + fn builder_with_gax_options() { + use google_cloud_gax::backoff_policy::BackoffPolicy; + use google_cloud_gax::retry_policy::Aip194Strict; + use google_cloud_gax::retry_state::RetryState; + use std::time::Duration; + + #[derive(Debug)] + struct DummyBackoff; + impl BackoffPolicy for DummyBackoff { + fn on_failure(&self, _state: &RetryState) -> Duration { + Duration::ZERO + } + } + + let stmt = Statement::builder("UPDATE t SET c = 1 WHERE id = 1").build(); + + let batch = BatchDml::builder() + .add_statement(stmt) + .with_attempt_timeout(Duration::from_secs(5)) + .with_retry_policy(Aip194Strict) + .with_backoff_policy(DummyBackoff) + .build(); + + assert_eq!( + *batch.gax_options.attempt_timeout(), + Some(Duration::from_secs(5)) + ); + assert!(batch.gax_options.retry_policy().is_some()); + assert!(batch.gax_options.backoff_policy().is_some()); + } + #[test] fn builder_with_request_tag() { let stmt = Statement::builder("UPDATE t SET c = 1 WHERE id = 1").build(); diff --git a/src/spanner/src/client.rs b/src/spanner/src/client.rs index be48d288b9..52a4319ff2 100644 --- a/src/spanner/src/client.rs +++ b/src/spanner/src/client.rs @@ -250,7 +250,7 @@ mod tests { use super::*; use crate::model::CreateSessionRequest; use crate::result_set::tests::adapt; - use gaxi::grpc::tonic::{Response, Status}; + use gaxi::grpc::tonic::{Code as GrpcCode, Response, Status}; use google_cloud_auth::credentials::anonymous::Builder as Anonymous; use google_cloud_gax::error::rpc::Code; use spanner_grpc_mock::google::rpc as mock_rpc; @@ -258,6 +258,9 @@ mod tests { use spanner_grpc_mock::google::spanner::v1::Session; use spanner_grpc_mock::{MockSpanner, start}; use static_assertions::{assert_impl_all, assert_not_impl_any}; + use std::sync::Arc; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::time::Duration; #[test] fn auto_traits() { @@ -838,6 +841,264 @@ mod tests { Ok(()) } + #[tokio::test] + async fn timeout_respected() -> anyhow::Result<()> { + use crate::batch_dml::BatchDml; + use std::time::Duration; + + // 1. Setup Mock Server + let mut mock = MockSpanner::new(); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "projects/p/instances/i/databases/d/sessions/123".to_string(), + ..Default::default() + })) + }); + + mock.expect_begin_transaction().returning(|_| { + Ok(Response::new(mock_v1::Transaction { + id: vec![42], + ..Default::default() + })) + }); + + mock.expect_execute_streaming_sql().once().returning(|req| { + let metadata = req.metadata(); + let timeout = metadata.get("grpc-timeout"); + assert!( + timeout.is_some(), + "grpc-timeout header should be present for query" + ); + + let (_tx, rx) = tokio::sync::mpsc::channel(1); + Ok(Response::new(rx)) + }); + + mock.expect_streaming_read().once().returning(|req| { + let metadata = req.metadata(); + let timeout = metadata.get("grpc-timeout"); + assert!( + timeout.is_some(), + "grpc-timeout header should be present for read" + ); + + let (_tx, rx) = tokio::sync::mpsc::channel(1); + Ok(Response::new(rx)) + }); + + mock.expect_execute_sql().once().returning(|req| { + let metadata = req.metadata(); + let timeout = metadata.get("grpc-timeout"); + assert!( + timeout.is_some(), + "grpc-timeout header should be present for single DML" + ); + + Ok(Response::new(mock_v1::ResultSet { + stats: Some(mock_v1::ResultSetStats { + row_count: Some(mock_v1::result_set_stats::RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + })) + }); + + mock.expect_execute_batch_dml().once().returning(|req| { + let metadata = req.metadata(); + let timeout = metadata.get("grpc-timeout"); + assert!( + timeout.is_some(), + "grpc-timeout header should be present for batch dml" + ); + + Ok(Response::new(mock_v1::ExecuteBatchDmlResponse { + result_sets: vec![mock_v1::ResultSet { + stats: Some(mock_v1::ResultSetStats { + row_count: Some(mock_v1::result_set_stats::RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + }], + ..Default::default() + })) + }); + + mock.expect_commit().returning(|_| { + Ok(Response::new(mock_v1::CommitResponse { + commit_timestamp: Some(prost_types::Timestamp { + seconds: 1234, + nanos: 0, + }), + ..Default::default() + })) + }); + + // 2. Start mock server + let (address, _server) = start("0.0.0.0:0", mock).await?; + + // 3. Configure Client + let client = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db = client + .database_client("projects/p/instances/i/databases/d") + .build() + .await?; + let runner = db.read_write_transaction().build().await?; + + // 4. Run transaction + runner + .run(async |tx| { + // Query + let stmt = Statement::builder("SELECT 1") + .with_attempt_timeout(Duration::from_secs(10)) + .build(); + let _ = tx.execute_query(stmt).await?; + + // Read + let req = ReadRequest::builder("Table", vec!["Col"]) + .with_keys(crate::key::KeySet::all()) + .with_attempt_timeout(Duration::from_secs(5)) + .build(); + let _ = tx.execute_read(req).await?; + + // Single DML + let dml = Statement::builder("UPDATE t SET c = 1") + .with_attempt_timeout(Duration::from_secs(7)) + .build(); + let _ = tx.execute_update(dml).await?; + + // Batch DML + let batch = BatchDml::builder() + .add_statement("UPDATE t SET c = 2") + .with_attempt_timeout(Duration::from_secs(8)) + .build(); + let _ = tx.execute_batch_update(batch).await?; + + Ok(()) + }) + .await?; + + Ok(()) + } + + #[tokio::test] + async fn retry_policy_respected() -> anyhow::Result<()> { + use google_cloud_gax::backoff_policy::BackoffPolicy; + use google_cloud_gax::retry_policy::{Aip194Strict, RetryPolicyExt}; + use google_cloud_gax::retry_state::RetryState; + + // Extend the default retry policy to also retry on ResourceExhausted. + let retry_policy = Aip194Strict.continue_on_too_many_requests(); + + // Custom Backoff Policy with a counter that allows us to verify that + // it was used. + #[derive(Debug)] + struct ConstantBackoff(Duration, Arc); + + impl BackoffPolicy for ConstantBackoff { + fn on_failure(&self, _state: &RetryState) -> Duration { + self.1.fetch_add(1, Ordering::SeqCst); + self.0 + } + } + + // 1. Setup Mock Server + let mut mock = MockSpanner::new(); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "projects/p/instances/i/databases/d/sessions/123".to_string(), + ..Default::default() + })) + }); + + mock.expect_begin_transaction().returning(|_| { + Ok(Response::new(mock_v1::Transaction { + id: vec![42], + ..Default::default() + })) + }); + + // Mock ExecuteSql to first return RESOURCE_EXHAUSTED and then succeed. + let mut seq = mockall::Sequence::new(); + + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(|_| Err(Status::new(GrpcCode::ResourceExhausted, "quota exceeded"))); + + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(mock_v1::ResultSet { + stats: Some(mock_v1::ResultSetStats { + row_count: Some(mock_v1::result_set_stats::RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + })) + }); + + mock.expect_commit().returning(|_| { + Ok(Response::new(mock_v1::CommitResponse { + commit_timestamp: Some(prost_types::Timestamp { + seconds: 1234, + nanos: 0, + }), + ..Default::default() + })) + }); + + // 2. Start mock server + let (address, _server) = start("0.0.0.0:0", mock).await?; + + // 3. Configure Client + let client = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db = client + .database_client("projects/p/instances/i/databases/d") + .build() + .await?; + let runner = db.read_write_transaction().build().await?; + + // 4. Call execute_update with custom retry and backoff + let backoff_count = Arc::new(AtomicUsize::new(0)); + let stmt = Statement::builder("UPDATE t SET c = 1") + .with_retry_policy(retry_policy) + .with_backoff_policy(ConstantBackoff( + Duration::from_nanos(1), + backoff_count.clone(), + )) + .build(); + + let result = runner + .run(async |tx| { + let count = tx.execute_update(stmt.clone()).await?; + Ok(count) + }) + .await?; + + // 5. Verify success after retry and that backoff was called + assert_eq!(result, 1); + assert_eq!( + backoff_count.load(Ordering::SeqCst), + 1, + "Backoff policy should have been called once" + ); + + Ok(()) + } + #[test] fn test_parse_emulator_endpoint() { assert_eq!( diff --git a/src/spanner/src/read.rs b/src/spanner/src/read.rs index 4e086372b8..5ee380e7ea 100644 --- a/src/spanner/src/read.rs +++ b/src/spanner/src/read.rs @@ -14,6 +14,10 @@ use crate::key::KeySet; use crate::model::DirectedReadOptions; +use google_cloud_gax::backoff_policy::BackoffPolicyArg; +use google_cloud_gax::options::RequestOptions as GaxRequestOptions; +use google_cloud_gax::retry_policy::RetryPolicyArg; +use std::time::Duration; /// Represents an incomplete read operation that requires specifying keys. /// @@ -63,6 +67,7 @@ impl ReadRequestBuilder { limit: None, request_options: None, directed_read_options: None, + gax_options: GaxRequestOptions::default(), } } @@ -90,12 +95,13 @@ impl ReadRequestBuilder { limit: None, request_options: None, directed_read_options: None, + gax_options: GaxRequestOptions::default(), } } } /// A fully configured read request that is ready to be built. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub struct ConfiguredReadRequestBuilder { table: String, index: Option, @@ -104,6 +110,7 @@ pub struct ConfiguredReadRequestBuilder { limit: Option, request_options: Option, directed_read_options: Option, + gax_options: GaxRequestOptions, } impl ConfiguredReadRequestBuilder { @@ -163,6 +170,24 @@ impl ConfiguredReadRequestBuilder { self } + /// Sets the per-attempt timeout for this read request. + pub fn with_attempt_timeout(mut self, timeout: Duration) -> Self { + self.gax_options.set_attempt_timeout(timeout); + self + } + + /// Sets the retry policy for this read request. + pub fn with_retry_policy(mut self, policy: impl Into) -> Self { + self.gax_options.set_retry_policy(policy); + self + } + + /// Sets the backoff policy for this read request. + pub fn with_backoff_policy(mut self, policy: impl Into) -> Self { + self.gax_options.set_backoff_policy(policy); + self + } + /// Builds the configured `ReadRequest`. pub fn build(self) -> ReadRequest { ReadRequest { @@ -173,6 +198,7 @@ impl ConfiguredReadRequestBuilder { limit: self.limit, request_options: self.request_options, directed_read_options: self.directed_read_options, + gax_options: self.gax_options, } } } @@ -181,7 +207,7 @@ impl ConfiguredReadRequestBuilder { /// /// Contains the table, optional index, keys, and columns. /// Allows configuring optional parameters on the read operation, such as a limit. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub struct ReadRequest { pub(crate) table: String, pub(crate) index: Option, @@ -190,6 +216,7 @@ pub struct ReadRequest { pub(crate) limit: Option, pub(crate) request_options: Option, pub(crate) directed_read_options: Option, + pub(crate) gax_options: GaxRequestOptions, } impl ReadRequest { @@ -238,9 +265,9 @@ mod tests { #[test] fn auto_traits() { - static_assertions::assert_impl_all!(ReadRequestBuilder: Send, Sync, Clone, std::fmt::Debug, PartialEq); - static_assertions::assert_impl_all!(ConfiguredReadRequestBuilder: Send, Sync, Clone, std::fmt::Debug, PartialEq); - static_assertions::assert_impl_all!(ReadRequest: Send, Sync, Clone, std::fmt::Debug, PartialEq); + static_assertions::assert_impl_all!(ReadRequestBuilder: Send, Sync, Clone, std::fmt::Debug); + static_assertions::assert_impl_all!(ConfiguredReadRequestBuilder: Send, Sync, Clone, std::fmt::Debug); + static_assertions::assert_impl_all!(ReadRequest: Send, Sync, Clone, std::fmt::Debug); } #[test] @@ -301,4 +328,27 @@ mod tests { .build(); assert_eq!(req.directed_read_options, Some(dro)); } + + #[test] + fn with_gax_options() -> anyhow::Result<()> { + use google_cloud_gax::exponential_backoff::ExponentialBackoff; + use google_cloud_gax::retry_policy::NeverRetry; + use std::time::Duration; + + let req = ReadRequest::builder("MyTable", vec!["col1"]) + .with_keys(KeySet::all()) + .with_attempt_timeout(Duration::from_secs(10)) + .with_retry_policy(NeverRetry) + .with_backoff_policy(ExponentialBackoff::default()) + .build(); + + assert_eq!( + req.gax_options.attempt_timeout(), + &Some(Duration::from_secs(10)) + ); + assert!(req.gax_options.retry_policy().is_some()); + assert!(req.gax_options.backoff_policy().is_some()); + + Ok(()) + } } diff --git a/src/spanner/src/read_only_transaction.rs b/src/spanner/src/read_only_transaction.rs index b4e0ce8c58..a7f3a1356d 100644 --- a/src/spanner/src/read_only_transaction.rs +++ b/src/spanner/src/read_only_transaction.rs @@ -579,12 +579,11 @@ impl ReadContext { /// Helper macro to execute a streaming SQL or streaming read RPC with retry logic. macro_rules! execute_stream_with_retry { - ($self:expr, $request:ident, $rpc_method:ident, $operation_variant:path) => {{ + ($self:expr, $request:ident, $gax_options:ident, $rpc_method:ident, $operation_variant:path) => {{ let stream = match $self .client .spanner - // TODO(#4972): make request options configurable - .$rpc_method($request.clone(), crate::RequestOptions::default()) + .$rpc_method($request.clone(), $gax_options.clone()) .send() .await { @@ -595,8 +594,7 @@ macro_rules! execute_stream_with_retry { $self .client .spanner - // TODO(#4972): make request options configurable - .$rpc_method($request.clone(), crate::RequestOptions::default()) + .$rpc_method($request.clone(), $gax_options.clone()) .send() .await? } else { @@ -621,28 +619,42 @@ impl ReadContext { &self, statement: T, ) -> crate::Result { + let statement = statement.into(); + let gax_options = statement.gax_options().clone(); let mut request = statement - .into() .into_request() .set_session(self.session_name.clone()) .set_transaction(self.transaction_selector.selector()); request.request_options = self.amend_request_options(request.request_options); - execute_stream_with_retry!(self, request, execute_streaming_sql, StreamOperation::Query) + execute_stream_with_retry!( + self, + request, + gax_options, + execute_streaming_sql, + StreamOperation::Query + ) } pub(crate) async fn execute_read>( &self, read: T, ) -> crate::Result { + let read = read.into(); + let gax_options = read.gax_options.clone(); let mut request = read - .into() .into_request() .set_session(self.session_name.clone()) .set_transaction(self.transaction_selector.selector()); request.request_options = self.amend_request_options(request.request_options); - execute_stream_with_retry!(self, request, streaming_read, StreamOperation::Read) + execute_stream_with_retry!( + self, + request, + gax_options, + streaming_read, + StreamOperation::Read + ) } } diff --git a/src/spanner/src/read_write_transaction.rs b/src/spanner/src/read_write_transaction.rs index fd3c7b28cc..3e7bc9105d 100644 --- a/src/spanner/src/read_write_transaction.rs +++ b/src/spanner/src/read_write_transaction.rs @@ -164,8 +164,9 @@ impl ReadWriteTransaction { /// Executes an update using this transaction. pub async fn execute_update>(&self, statement: T) -> crate::Result { let seqno = self.seqno.fetch_add(1, Ordering::SeqCst); + let statement = statement.into(); + let gax_options = statement.gax_options().clone(); let mut request = statement - .into() .into_request() .set_session(self.context.session_name.clone()) .set_transaction(self.context.transaction_selector.selector()) @@ -176,7 +177,7 @@ impl ReadWriteTransaction { .context .client .spanner - .execute_sql(request, RequestOptions::default()) + .execute_sql(request, gax_options) .await?; self.context .precommit_token_tracker @@ -280,7 +281,7 @@ impl ReadWriteTransaction { .context .client .spanner - .execute_batch_dml(request, RequestOptions::default()) + .execute_batch_dml(request, batch.gax_options) .await; match response_result { diff --git a/src/spanner/src/statement.rs b/src/spanner/src/statement.rs index f3d882b5b5..fa12ad459c 100644 --- a/src/spanner/src/statement.rs +++ b/src/spanner/src/statement.rs @@ -18,7 +18,11 @@ use crate::model::execute_sql_request::QueryOptions; use crate::to_value::ToValue; use crate::types::Type; use crate::value::Value; +use google_cloud_gax::backoff_policy::BackoffPolicyArg; +use google_cloud_gax::options::RequestOptions as GaxRequestOptions; +use google_cloud_gax::retry_policy::RetryPolicyArg; use std::collections::BTreeMap; +use std::time::Duration; /// A builder for [Statement]. /// @@ -29,7 +33,7 @@ use std::collections::BTreeMap; /// .add_param("id", &42) /// .build(); /// ``` -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub struct StatementBuilder { sql: String, params: BTreeMap, @@ -38,6 +42,7 @@ pub struct StatementBuilder { directed_read_options: Option, query_options: Option, query_mode: Option, + gax_options: GaxRequestOptions, } impl StatementBuilder { @@ -50,6 +55,7 @@ impl StatementBuilder { directed_read_options: None, query_options: None, query_mode: None, + gax_options: GaxRequestOptions::default(), } } @@ -149,6 +155,24 @@ impl StatementBuilder { self } + /// Sets the per-attempt timeout for this statement. + pub fn with_attempt_timeout(mut self, timeout: Duration) -> Self { + self.gax_options.set_attempt_timeout(timeout); + self + } + + /// Sets the retry policy for this statement. + pub fn with_retry_policy(mut self, policy: impl Into) -> Self { + self.gax_options.set_retry_policy(policy); + self + } + + /// Sets the backoff policy for this statement. + pub fn with_backoff_policy(mut self, policy: impl Into) -> Self { + self.gax_options.set_backoff_policy(policy); + self + } + /// Builds and returns the finalized Statement object. pub fn build(self) -> Statement { Statement { @@ -159,6 +183,7 @@ impl StatementBuilder { directed_read_options: self.directed_read_options, query_options: self.query_options, query_mode: self.query_mode, + gax_options: self.gax_options, } } } @@ -186,7 +211,7 @@ impl StatementBuilder { /// # Ok(()) /// # } /// ``` -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub struct Statement { pub sql: String, pub(crate) params: BTreeMap, @@ -195,6 +220,7 @@ pub struct Statement { pub(crate) directed_read_options: Option, pub(crate) query_options: Option, pub(crate) query_mode: Option, + gax_options: GaxRequestOptions, } impl Statement { @@ -203,6 +229,10 @@ impl Statement { StatementBuilder::new(sql) } + pub(crate) fn gax_options(&self) -> &GaxRequestOptions { + &self.gax_options + } + /// Sets the query mode to use for this statement. /// /// # Example @@ -308,8 +338,8 @@ mod tests { #[test] fn test_auto_traits() { - static_assertions::assert_impl_all!(Statement: Clone, std::fmt::Debug, PartialEq, Send, Sync); - static_assertions::assert_impl_all!(StatementBuilder: Clone, std::fmt::Debug, PartialEq, Send, Sync); + static_assertions::assert_impl_all!(Statement: Clone, std::fmt::Debug, Send, Sync); + static_assertions::assert_impl_all!(StatementBuilder: Clone, std::fmt::Debug, Send, Sync); } #[test] @@ -476,4 +506,26 @@ mod tests { assert_eq!(req.query_mode, QueryMode::Profile); Ok(()) } + + #[test] + fn with_gax_options() -> anyhow::Result<()> { + use google_cloud_gax::exponential_backoff::ExponentialBackoff; + use google_cloud_gax::retry_policy::NeverRetry; + use std::time::Duration; + + let stmt = Statement::builder("SELECT * FROM users") + .with_attempt_timeout(Duration::from_secs(10)) + .with_retry_policy(NeverRetry) + .with_backoff_policy(ExponentialBackoff::default()) + .build(); + + assert_eq!( + stmt.gax_options.attempt_timeout(), + &Some(Duration::from_secs(10)) + ); + assert!(stmt.gax_options.retry_policy().is_some()); + assert!(stmt.gax_options.backoff_policy().is_some()); + + Ok(()) + } } From da27cab9c17be74b115c5c2c7bb740ec081edeb0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 24 Apr 2026 12:35:38 +0200 Subject: [PATCH 4/6] feat(spanner): support timeout, retry- and backoff policy for partitions Adds support for setting an attempt timeout and a retry- and backoff policy for a partition in a BatchReadOnlyTransaction. These settings are not serialized, which means that any host that wants to apply any of these settings to its execution of a partition must set these explicitly. --- .../src/batch_read_only_transaction.rs | 118 +++++++++++++++++- 1 file changed, 112 insertions(+), 6 deletions(-) diff --git a/src/spanner/src/batch_read_only_transaction.rs b/src/spanner/src/batch_read_only_transaction.rs index 95911d4302..5124a04326 100644 --- a/src/spanner/src/batch_read_only_transaction.rs +++ b/src/spanner/src/batch_read_only_transaction.rs @@ -21,8 +21,11 @@ use crate::read_only_transaction::{ use crate::result_set::{ResultSet, StreamOperation}; use crate::statement::Statement; use crate::timestamp_bound::TimestampBound; +use google_cloud_gax::backoff_policy::BackoffPolicyArg; use google_cloud_gax::options::RequestOptions as GaxRequestOptions; +use google_cloud_gax::retry_policy::RetryPolicyArg; use serde::{Deserialize, Serialize}; +use std::time::Duration; /// A builder for [BatchReadOnlyTransaction]. /// @@ -173,6 +176,7 @@ impl BatchReadOnlyTransaction { Partition { inner: PartitionedOperation::Query(req), + gax_options: GaxRequestOptions::default(), } }) .collect()) @@ -230,6 +234,7 @@ impl BatchReadOnlyTransaction { Partition { inner: PartitionedOperation::Read(req), + gax_options: GaxRequestOptions::default(), } }) .collect()) @@ -242,6 +247,8 @@ impl BatchReadOnlyTransaction { #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Partition { pub(crate) inner: PartitionedOperation, + #[serde(skip)] + pub(crate) gax_options: GaxRequestOptions, } impl Partition { @@ -271,6 +278,30 @@ impl Partition { self } + /// Sets the per-attempt timeout for this partition execution. + /// + /// **Note:** This field is **not serialized**. Each host that executes a partition must set its own attempt timeout. + pub fn with_attempt_timeout(mut self, timeout: Duration) -> Self { + self.gax_options.set_attempt_timeout(timeout); + self + } + + /// Sets the retry policy for this partition execution. + /// + /// **Note:** This field is **not serialized**. Each host that executes a partition must set its own retry policy. + pub fn with_retry_policy(mut self, policy: impl Into) -> Self { + self.gax_options.set_retry_policy(policy); + self + } + + /// Sets the backoff policy for this partition execution. + /// + /// **Note:** This field is **not serialized**. Each host that executes a partition must set its own backoff policy. + pub fn with_backoff_policy(mut self, policy: impl Into) -> Self { + self.gax_options.set_backoff_policy(policy); + self + } + /// Executes this partition and returns a [ResultSet] that /// contains the rows that belong to this partition. /// @@ -321,18 +352,23 @@ impl Partition { /// the database that the partitions belong to. pub async fn execute(&self, client: &DatabaseClient) -> crate::Result { match &self.inner { - PartitionedOperation::Query(req) => Self::execute_query(client, req).await, - PartitionedOperation::Read(req) => Self::execute_read(client, req).await, + PartitionedOperation::Query(req) => { + Self::execute_query(client, req, self.gax_options.clone()).await + } + PartitionedOperation::Read(req) => { + Self::execute_read(client, req, self.gax_options.clone()).await + } } } async fn execute_query( client: &DatabaseClient, req: &crate::model::ExecuteSqlRequest, + gax_options: GaxRequestOptions, ) -> crate::Result { let stream = client .spanner - .execute_streaming_sql(req.clone(), crate::RequestOptions::default()) + .execute_streaming_sql(req.clone(), gax_options.clone()) .send() .await?; @@ -346,17 +382,18 @@ impl Partition { client.clone(), req.session.clone(), StreamOperation::Query(req.clone()), - GaxRequestOptions::default(), + gax_options, )) } async fn execute_read( client: &DatabaseClient, req: &crate::model::ReadRequest, + gax_options: GaxRequestOptions, ) -> crate::Result { let stream = client .spanner - .streaming_read(req.clone(), crate::RequestOptions::default()) + .streaming_read(req.clone(), gax_options.clone()) .send() .await?; @@ -370,7 +407,7 @@ impl Partition { client.clone(), req.session.clone(), StreamOperation::Read(req.clone()), - GaxRequestOptions::default(), + gax_options, )) } } @@ -404,6 +441,69 @@ pub(crate) mod tests { assert_impl_all!(Partition: Send, Sync, Debug); } + #[test] + fn serialize_partition_skips_gax_options() -> anyhow::Result<()> { + use std::time::Duration; + + let req = crate::model::ExecuteSqlRequest::new() + .set_sql("SELECT 1") + .set_partition_token(b"token".to_vec()); + + let mut gax_options = GaxRequestOptions::default(); + gax_options.set_attempt_timeout(Duration::from_secs(5)); + let partition = Partition { + inner: PartitionedOperation::Query(req), + gax_options, + }; + + let serialized = serde_json::to_string(&partition)?; + let deserialized: Partition = serde_json::from_str(&serialized)?; + + // Verify that gax_options was NOT preserved (it uses default, which is None timeout) + assert_eq!(*deserialized.gax_options.attempt_timeout(), None); + + Ok(()) + } + + #[tokio::test] + async fn partition_execute_respects_options() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use std::time::Duration; + + let mut mock = create_session_mock(); + + mock.expect_execute_streaming_sql().once().returning(|req| { + let timeout = req.metadata().get("grpc-timeout"); + assert!(timeout.is_some(), "Missing grpc-timeout header"); + assert_eq!(timeout.unwrap(), "5000000u"); // 5 seconds in micros + + let (_, rx) = tokio::sync::mpsc::channel(1); + Ok(Response::from(rx)) + }); + + let (db_client, _server) = setup_db_client(mock).await; + + let req = crate::model::ExecuteSqlRequest::new() + .set_session("projects/p/instances/i/databases/d/sessions/123") + .set_transaction(crate::model::TransactionSelector { + selector: Some(Selector::Id(b"tx_id_1".to_vec().into())), + ..Default::default() + }) + .set_sql("SELECT 1") + .set_partition_token(b"token".to_vec()); + + let partition = Partition { + inner: PartitionedOperation::Query(req), + gax_options: GaxRequestOptions::default(), + }; + + let partition = partition.with_attempt_timeout(Duration::from_secs(5)); + + let _result_set = partition.execute(&db_client).await?; + + Ok(()) + } + #[test] fn serialize_partition_query() -> anyhow::Result<()> { let req = crate::model::ExecuteSqlRequest::new() @@ -419,6 +519,7 @@ pub(crate) mod tests { let partition = Partition { inner: PartitionedOperation::Query(req), + gax_options: GaxRequestOptions::default(), }; let serialized = serde_json::to_string(&partition)?; @@ -451,6 +552,7 @@ pub(crate) mod tests { let partition = Partition { inner: PartitionedOperation::Read(req), + gax_options: GaxRequestOptions::default(), }; let serialized = serde_json::to_string(&partition)?; @@ -501,6 +603,7 @@ pub(crate) mod tests { let partition = Partition { inner: PartitionedOperation::Query(req), + gax_options: GaxRequestOptions::default(), }; let _result_set = partition.execute(&db_client).await?; @@ -543,6 +646,7 @@ pub(crate) mod tests { let partition = Partition { inner: PartitionedOperation::Read(req), + gax_options: GaxRequestOptions::default(), }; let _result_set = partition.execute(&db_client).await?; @@ -703,6 +807,7 @@ pub(crate) mod tests { let partition = Partition { inner: PartitionedOperation::Query(req), + gax_options: GaxRequestOptions::default(), }; let _result_set = partition.with_data_boost(true).execute(&db_client).await?; @@ -735,6 +840,7 @@ pub(crate) mod tests { let partition = Partition { inner: PartitionedOperation::Read(req), + gax_options: GaxRequestOptions::default(), }; let _result_set = partition.with_data_boost(true).execute(&db_client).await?; From ce6622d310112410a01c3d9cf9a29336b7c66d23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 24 Apr 2026 15:11:10 +0200 Subject: [PATCH 5/6] feat(spanner): support transaction timeout Adds support for setting a total timeout for a transaction. The timeout is applied to each RPC in the transaction, and 'ticks down' as time passes. If the transaction does not finish within the total timeout, it will fail with a DEADLINE_EXCEEDED error. Note that the timeout is NOT applied to the Rollback RPC, to allow a TransactionRunner or an application to rollback the transaction after a timeout error. --- src/spanner/src/client.rs | 246 +++++++++++++++++++++- src/spanner/src/read_write_transaction.rs | 121 +++++++++-- src/spanner/src/statement.rs | 6 + src/spanner/src/transaction_runner.rs | 33 ++- 4 files changed, 383 insertions(+), 23 deletions(-) diff --git a/src/spanner/src/client.rs b/src/spanner/src/client.rs index 52a4319ff2..abead8f1b1 100644 --- a/src/spanner/src/client.rs +++ b/src/spanner/src/client.rs @@ -250,16 +250,21 @@ mod tests { use super::*; use crate::model::CreateSessionRequest; use crate::result_set::tests::adapt; + use gaxi::grpc::tonic::MetadataMap; use gaxi::grpc::tonic::{Code as GrpcCode, Response, Status}; use google_cloud_auth::credentials::anonymous::Builder as Anonymous; use google_cloud_gax::error::rpc::Code; use spanner_grpc_mock::google::rpc as mock_rpc; use spanner_grpc_mock::google::spanner::v1 as mock_v1; + use spanner_grpc_mock::google::spanner::v1::CommitResponse; + use spanner_grpc_mock::google::spanner::v1::ResultSet; + use spanner_grpc_mock::google::spanner::v1::ResultSetStats; use spanner_grpc_mock::google::spanner::v1::Session; + use spanner_grpc_mock::google::spanner::v1::result_set_stats::RowCount; use spanner_grpc_mock::{MockSpanner, start}; use static_assertions::{assert_impl_all, assert_not_impl_any}; use std::sync::Arc; - use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; use std::time::Duration; #[test] @@ -1099,6 +1104,245 @@ mod tests { Ok(()) } + fn parse_timeout(metadata: &MetadataMap) -> u64 { + let timeout = metadata + .get("grpc-timeout") + .expect("grpc-timeout header should be present"); + let timeout_str = timeout + .to_str() + .expect("grpc-timeout should be a valid string"); + if timeout_str.ends_with('u') { + timeout_str + .trim_end_matches('u') + .parse() + .expect("valid u64") + } else if timeout_str.ends_with('m') { + timeout_str + .trim_end_matches('m') + .parse::() + .expect("valid u64") + * 1000 + } else if timeout_str.ends_with('n') { + timeout_str + .trim_end_matches('n') + .parse::() + .expect("valid u64") + / 1000 + } else { + panic!("Unknown timeout unit in {}", timeout_str); + } + } + + #[tokio::test] + async fn transaction_timeout_respected() -> anyhow::Result<()> { + use spanner_grpc_mock::google::spanner::v1::Transaction; + + // 1. Setup Mock Server + let mut mock = MockSpanner::new(); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "projects/p/instances/i/databases/d/sessions/123".to_string(), + ..Default::default() + })) + }); + + mock.expect_begin_transaction().returning(|_| { + Ok(Response::new(Transaction { + id: vec![1, 2, 3], + ..Default::default() + })) + }); + + mock.expect_commit().once().returning(|_| { + Ok(Response::new(CommitResponse { + commit_timestamp: Some(prost_types::Timestamp { + seconds: 12345, + nanos: 0, + }), + ..Default::default() + })) + }); + + // Mock execute_sql to check timeout header + mock.expect_execute_sql().once().returning(|req| { + let timeout_val = parse_timeout(req.metadata()); + assert!( + timeout_val <= 100000, + "Expected timeout to be <= 100ms, got {}", + timeout_val + ); + + let res = ResultSet { + stats: Some(ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + }; + Ok(Response::new(res)) + }); + + // 2. Initialize Client + let (address, _server) = start("127.0.0.1:0", mock).await?; + let client = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + let db = client + .database_client("projects/p/instances/i/databases/d") + .build() + .await?; + + // 3. Setup Transaction Runner with 100ms timeout + let runner = db + .read_write_transaction() + .with_transaction_timeout(Duration::from_millis(100)) + .build() + .await?; + + // 4. Run transaction and expect success + let result = runner + .run(async |tx| { + let stmt = Statement::builder("SELECT 1").build(); + tx.execute_update(stmt).await?; + Ok(()) + }) + .await; + + result.expect("Transaction should have succeeded"); + + Ok(()) + } + + #[tokio::test] + async fn transaction_timeout_ticks_down() -> anyhow::Result<()> { + use spanner_grpc_mock::google::spanner::v1::Transaction; + + let mut mock = MockSpanner::new(); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "projects/p/instances/i/databases/d/sessions/123".to_string(), + ..Default::default() + })) + }); + + let mut seq = mockall::Sequence::new(); + + mock.expect_begin_transaction() + .once() + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(Transaction { + id: vec![1], + ..Default::default() + })) + }); + + let previous_timeout = Arc::new(AtomicU64::new(0)); + let prev_clone1 = previous_timeout.clone(); + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let timeout_val = parse_timeout(req.metadata()); + assert!( + timeout_val <= 500000, + "Expected timeout to be <= 500ms, got {}", + timeout_val + ); + prev_clone1.store(timeout_val, Ordering::SeqCst); + Err(Status::new(GrpcCode::Aborted, "Aborted")) + }); + + // Second attempt: Checks that timeout is <= previous + mock.expect_begin_transaction() + .once() + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(Transaction { + id: vec![2], + ..Default::default() + })) + }); + + let prev_clone2 = previous_timeout.clone(); + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let timeout_val = parse_timeout(req.metadata()); + let prev = prev_clone2.load(Ordering::SeqCst); + assert!( + timeout_val <= prev, + "Timeout should tick down between attempts or be equal, got {} and {}", + timeout_val, + prev + ); + prev_clone2.store(timeout_val, Ordering::SeqCst); // store for next check + + let res = ResultSet { + stats: Some(ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + }; + Ok(Response::new(res)) + }); + + let prev_clone3 = previous_timeout.clone(); + mock.expect_commit().once().returning(move |req| { + let timeout_val = parse_timeout(req.metadata()); + let prev = prev_clone3.load(Ordering::SeqCst); + assert!( + timeout_val < prev, + "Timeout should be smaller for commit, got {} and {}", + timeout_val, + prev + ); + + Ok(Response::new(CommitResponse { + commit_timestamp: Some(prost_types::Timestamp { + seconds: 12345, + nanos: 0, + }), + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + let client = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + let db = client + .database_client("projects/p/instances/i/databases/d") + .build() + .await?; + + let runner = db + .read_write_transaction() + .with_transaction_timeout(Duration::from_millis(500)) + .build() + .await?; + + let result = runner + .run(async |tx| { + let stmt = Statement::builder("SELECT 1").build(); + tx.execute_update(stmt).await?; + Ok(()) + }) + .await; + + result.expect("Transaction should have succeeded"); + + Ok(()) + } + #[test] fn test_parse_emulator_endpoint() { assert_eq!( diff --git a/src/spanner/src/read_write_transaction.rs b/src/spanner/src/read_write_transaction.rs index 3e7bc9105d..4eb6b73e77 100644 --- a/src/spanner/src/read_write_transaction.rs +++ b/src/spanner/src/read_write_transaction.rs @@ -33,8 +33,15 @@ use crate::precommit::PrecommitTokenTracker; use crate::read_only_transaction::ReadContext; use crate::result_set::ResultSet; use crate::statement::Statement; +use google_cloud_gax::error::Error as GaxError; +use google_cloud_gax::options::RequestOptions as GaxRequestOptions; +use google_cloud_gax::retry_policy::RetryPolicy; +use google_cloud_gax::retry_result::RetryResult; +use google_cloud_gax::retry_state::RetryState; use std::sync::Arc; use std::sync::atomic::{AtomicI64, Ordering}; +use std::time::Duration as StdDuration; +use tokio::time::Instant; use wkt::Duration; /// A builder for [ReadWriteTransaction]. @@ -99,7 +106,10 @@ impl ReadWriteTransactionBuilder { self } - pub(crate) async fn begin_transaction(&self) -> crate::Result { + pub(crate) async fn begin_transaction( + &self, + deadline: Option, + ) -> crate::Result { let session_name = self.session_name.clone(); let mut request = BeginTransactionRequest::default() .set_session(session_name.clone()) @@ -111,10 +121,16 @@ impl ReadWriteTransactionBuilder { } // TODO(#4972): make request options configurable + let mut options = RequestOptions::default(); + if let Some(d) = deadline { + let remaining = d.saturating_duration_since(Instant::now()); + options.set_attempt_timeout(remaining); + } + let response = self .client .spanner - .begin_transaction(request, RequestOptions::default()) + .begin_transaction(request, options) .await?; let transaction_selector = @@ -132,6 +148,7 @@ impl ReadWriteTransactionBuilder { }, seqno: Arc::new(AtomicI64::new(1)), max_commit_delay: self.max_commit_delay, + deadline, }) } } @@ -142,6 +159,7 @@ pub struct ReadWriteTransaction { pub(crate) context: ReadContext, seqno: Arc, max_commit_delay: Option, + pub(crate) deadline: Option, } impl ReadWriteTransaction { @@ -150,7 +168,14 @@ impl ReadWriteTransaction { &self, statement: T, ) -> crate::Result { - self.context.execute_query(statement).await + if self.deadline.is_none() { + return self.context.execute_query(statement).await; + } + let stmt = statement.into(); + let mut gax_options = stmt.gax_options().clone(); + self.apply_transaction_timeout(&mut gax_options); + let stmt = stmt.with_gax_options(gax_options); + self.context.execute_query(stmt).await } /// Reads rows from the database using key lookups and scans, as a simple key/value style alternative to execute_query. @@ -158,14 +183,22 @@ impl ReadWriteTransaction { &self, read: T, ) -> crate::Result { - self.context.execute_read(read).await + if self.deadline.is_none() { + return self.context.execute_read(read).await; + } + let mut req = read.into(); + self.apply_transaction_timeout(&mut req.gax_options); + self.context.execute_read(req).await } /// Executes an update using this transaction. pub async fn execute_update>(&self, statement: T) -> crate::Result { - let seqno = self.seqno.fetch_add(1, Ordering::SeqCst); let statement = statement.into(); - let gax_options = statement.gax_options().clone(); + let mut gax_options = statement.gax_options().clone(); + if self.deadline.is_some() { + self.apply_transaction_timeout(&mut gax_options); + } + let seqno = self.seqno.fetch_add(1, Ordering::SeqCst); let mut request = statement .into_request() .set_session(self.context.session_name.clone()) @@ -260,6 +293,10 @@ impl ReadWriteTransaction { /// # } /// ``` pub async fn execute_batch_update(&self, batch: BatchDml) -> crate::Result> { + let mut batch = batch; + if self.deadline.is_some() { + self.apply_transaction_timeout(&mut batch.gax_options); + } let seqno = self.seqno.fetch_add(1, Ordering::SeqCst); let statements: Vec = batch @@ -313,11 +350,15 @@ impl ReadWriteTransaction { .set_or_clear_request_options(self.context.amend_request_options(None)) .set_or_clear_max_commit_delay(self.max_commit_delay); + // TODO(#4972): make request options configurable + let mut gax_options = GaxRequestOptions::default(); + self.apply_transaction_timeout(&mut gax_options); + let response = self .context .client .spanner - .commit(request, RequestOptions::default()) + .commit(request, gax_options) .await?; let response = @@ -328,10 +369,14 @@ impl ReadWriteTransaction { .set_precommit_token(*new_precommit_token) .set_or_clear_request_options(self.context.amend_request_options(None)); + // TODO(#4972): make request options configurable + let mut gax_options = GaxRequestOptions::default(); + self.apply_transaction_timeout(&mut gax_options); + self.context .client .spanner - .commit(retry_commit_req, RequestOptions::default()) + .commit(retry_commit_req, gax_options) .await? } else { response @@ -359,6 +404,40 @@ impl ReadWriteTransaction { Ok(()) } + + fn apply_transaction_timeout(&self, options: &mut GaxRequestOptions) { + if let Some(deadline) = self.deadline { + let inner_policy = options + .retry_policy() + .clone() + .unwrap_or_else(|| Arc::new(google_cloud_gax::retry_policy::Aip194Strict)); + let bounded_policy = TransactionBoundedRetryPolicy { + inner: inner_policy, + deadline, + }; + options.set_retry_policy(bounded_policy); + } + } +} + +/// A retry policy that wraps another policy and bounds the total execution time +/// by a specific transaction deadline. +/// +/// This policy delegates `on_error` to the inner policy but overrides `remaining_time` +/// to ensure that it never exceeds the time left until the transaction deadline. +#[derive(Debug)] +struct TransactionBoundedRetryPolicy { + inner: Arc, + deadline: Instant, +} + +impl RetryPolicy for TransactionBoundedRetryPolicy { + fn on_error(&self, state: &RetryState, error: GaxError) -> RetryResult { + self.inner.on_error(state, error) + } + fn remaining_time(&self, _state: &RetryState) -> Option { + Some(self.deadline.saturating_duration_since(Instant::now())) + } } #[cfg(test)] @@ -460,7 +539,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); @@ -527,7 +606,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); let count = tx @@ -564,7 +643,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); @@ -609,7 +688,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); @@ -668,7 +747,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client) - .begin_transaction() + .begin_transaction(None) .await?; let batch = BatchDml::builder() @@ -713,7 +792,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client) - .begin_transaction() + .begin_transaction(None) .await?; let batch = BatchDml::builder() @@ -771,7 +850,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); @@ -821,7 +900,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); @@ -873,7 +952,7 @@ mod tests { let _tx = ReadWriteTransactionBuilder::new(db_client.clone()) .with_isolation_level(IsolationLevel::Serializable) .with_read_lock_mode(ReadLockMode::Pessimistic) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); } @@ -897,7 +976,7 @@ mod tests { let _tx = ReadWriteTransactionBuilder::new(db_client.clone()) .with_exclude_txn_from_change_streams(true) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); } @@ -957,7 +1036,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); @@ -1031,7 +1110,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); @@ -1072,7 +1151,7 @@ mod tests { let tx = ReadWriteTransactionBuilder::new(db_client.clone()) .with_max_commit_delay(Duration::new(0, 200_000_000).unwrap()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); diff --git a/src/spanner/src/statement.rs b/src/spanner/src/statement.rs index fa12ad459c..90be21d7fe 100644 --- a/src/spanner/src/statement.rs +++ b/src/spanner/src/statement.rs @@ -233,6 +233,12 @@ impl Statement { &self.gax_options } + /// Returns a new `Statement` with the given `GaxRequestOptions`. + pub(crate) fn with_gax_options(mut self, options: GaxRequestOptions) -> Self { + self.gax_options = options; + self + } + /// Sets the query mode to use for this statement. /// /// # Example diff --git a/src/spanner/src/transaction_runner.rs b/src/spanner/src/transaction_runner.rs index b96b0da549..820cc5ee37 100644 --- a/src/spanner/src/transaction_runner.rs +++ b/src/spanner/src/transaction_runner.rs @@ -19,6 +19,7 @@ use crate::read_write_transaction::{ReadWriteTransaction, ReadWriteTransactionBu use crate::transaction_retry_policy::{ BasicTransactionRetryPolicy, TransactionRetryPolicy, backoff_if_aborted, is_aborted, }; +use std::time::Duration as StdDuration; use wkt::Duration; /// A builder for a [TransactionRunner] for a read/write transaction. @@ -45,6 +46,7 @@ use wkt::Duration; pub struct TransactionRunnerBuilder { builder: ReadWriteTransactionBuilder, retry_policy: Box, + timeout: Option, } impl TransactionRunnerBuilder { @@ -52,9 +54,35 @@ impl TransactionRunnerBuilder { Self { builder: ReadWriteTransactionBuilder::new(client), retry_policy: Box::new(BasicTransactionRetryPolicy::default()), + timeout: None, } } + /// Sets the timeout for the entire transaction. + /// + /// # Example + /// ``` + /// # use google_cloud_spanner::client::Spanner; + /// # use std::time::Duration; + /// # async fn run(client: Spanner) -> Result<(), google_cloud_spanner::Error> { + /// # let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?; + /// let runner = db_client.read_write_transaction() + /// .with_transaction_timeout(Duration::from_secs(5)) + /// .build() + /// .await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// This timeout applies to the total time spent executing the transaction, including + /// all statements and automatic retries. Each individual RPC within the transaction + /// is automatically assigned a deadline derived from the remaining time of this + /// overall timeout. + pub fn with_transaction_timeout(mut self, timeout: StdDuration) -> Self { + self.timeout = Some(timeout); + self + } + /// Sets the isolation level for the transaction. /// /// # Example @@ -228,6 +256,7 @@ impl TransactionRunnerBuilder { Ok(TransactionRunner { builder: self.builder, retry_policy: self.retry_policy, + timeout: self.timeout, }) } } @@ -236,6 +265,7 @@ impl TransactionRunnerBuilder { pub struct TransactionRunner { builder: ReadWriteTransactionBuilder, retry_policy: Box, + timeout: Option, } impl TransactionRunner { @@ -271,13 +301,14 @@ impl TransactionRunner { let start_time = tokio::time::Instant::now(); let mut attempts: u32 = 0; let backoff = crate::transaction_retry_policy::default_retry_backoff(); + let deadline = self.timeout.map(|t| start_time + t); loop { attempts += 1; let mut current_tx_id = None; let attempt_result = async { - let transaction = self.builder.begin_transaction().await?; + let transaction = self.builder.begin_transaction(deadline).await?; current_tx_id = transaction.transaction_id().ok(); let result = match work(transaction.clone()).await { From dc297ec17e3eeb4ce6b392c26945c6f2a8112a96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Sat, 25 Apr 2026 06:57:02 +0200 Subject: [PATCH 6/6] chore(spanner): use mockall for test policies --- src/spanner/src/client.rs | 6 +- src/spanner/src/result_set.rs | 102 +++++++++++++++++----------------- 2 files changed, 54 insertions(+), 54 deletions(-) diff --git a/src/spanner/src/client.rs b/src/spanner/src/client.rs index 5969681590..ac65953e7e 100644 --- a/src/spanner/src/client.rs +++ b/src/spanner/src/client.rs @@ -252,7 +252,9 @@ mod tests { use crate::result_set::tests::adapt; use gaxi::grpc::tonic::{Code as GrpcCode, Response, Status}; use google_cloud_auth::credentials::anonymous::Builder as Anonymous; + use google_cloud_gax::backoff_policy::BackoffPolicy; use google_cloud_gax::error::rpc::Code; + use google_cloud_gax::retry_state::RetryState; use google_cloud_test_macros::tokio_test_no_panics; use spanner_grpc_mock::google::rpc as mock_rpc; use spanner_grpc_mock::google::spanner::v1 as mock_v1; @@ -264,8 +266,8 @@ mod tests { mockall::mock! { #[derive(Debug)] BackoffPolicy {} - impl google_cloud_gax::backoff_policy::BackoffPolicy for BackoffPolicy { - fn on_failure(&self, state: &google_cloud_gax::retry_state::RetryState) -> std::time::Duration; + impl BackoffPolicy for BackoffPolicy { + fn on_failure(&self, state: &RetryState) -> Duration; } } diff --git a/src/spanner/src/result_set.rs b/src/spanner/src/result_set.rs index fcacf408d5..4564c1c40e 100644 --- a/src/spanner/src/result_set.rs +++ b/src/spanner/src/result_set.rs @@ -629,6 +629,9 @@ pub(crate) mod tests { use crate::read::ReadRequest; use gaxi::grpc::tonic::{Code as GrpcCode, Response}; use google_cloud_auth::credentials::anonymous::Builder as Anonymous; + use google_cloud_gax::backoff_policy::BackoffPolicy; + use google_cloud_gax::retry_state::RetryState; + use google_cloud_test_macros::tokio_test_no_panics; use prost_types::Value; use spanner_grpc_mock::MockSpanner; use spanner_grpc_mock::google::spanner::v1 as spanner_v1; @@ -637,18 +640,13 @@ pub(crate) mod tests { PartialResultSet, ResultSetMetadata, Session, StructType, }; use spanner_grpc_mock::start; - use std::sync::Arc; - use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; - /// A backoff policy that always returns the same duration and that contains a counter that - /// can be used to verify that it was called. - #[derive(Debug)] - pub(crate) struct ConstantBackoff(pub(crate) Duration, pub(crate) Arc); - impl google_cloud_gax::backoff_policy::BackoffPolicy for ConstantBackoff { - fn on_failure(&self, _state: &RetryState) -> Duration { - self.1.fetch_add(1, Ordering::SeqCst); - self.0 + mockall::mock! { + #[derive(Debug)] + BackoffPolicy {} + impl BackoffPolicy for BackoffPolicy { + fn on_failure(&self, state: &RetryState) -> Duration; } } @@ -996,12 +994,14 @@ pub(crate) mod tests { let db_client = client.database_client("db").build().await?; let tx = db_client.single_use().build(); + let mut mock_backoff = MockBackoffPolicy::new(); + mock_backoff + .expect_on_failure() + .returning(|_| Duration::from_nanos(1)); + let read_req = crate::read::ReadRequest::builder("table", vec!["Id", "Value"]) .with_keys(crate::key::KeySet::all()) - .with_backoff_policy(ConstantBackoff( - Duration::from_nanos(1), - Arc::new(AtomicUsize::new(0)), - )) + .with_backoff_policy(mock_backoff) .build(); let mut rs: ResultSet = tx.execute_read(read_req).await?; @@ -1083,14 +1083,16 @@ pub(crate) mod tests { let db_client = client.database_client("db").build().await?; let tx = db_client.single_use().build(); - let backoff_count = Arc::new(AtomicUsize::new(0)); + let mut mock_backoff = MockBackoffPolicy::new(); + mock_backoff + .expect_on_failure() + .times(1) + .returning(|_| Duration::from_nanos(1)); + let read_req = ReadRequest::builder("table", vec!["Id", "Value"]) .with_keys(KeySet::all()) .with_retry_policy(retry_policy) - .with_backoff_policy(ConstantBackoff( - Duration::from_nanos(1), - backoff_count.clone(), - )) + .with_backoff_policy(mock_backoff) .build(); let mut rs: ResultSet = tx.execute_read(read_req).await?; @@ -1104,12 +1106,6 @@ pub(crate) mod tests { assert!(rs.next().await.is_none()); - assert_eq!( - backoff_count.load(Ordering::SeqCst), - 1, - "Backoff policy should have been called once" - ); - Ok(()) } @@ -1541,11 +1537,13 @@ pub(crate) mod tests { let db_client = client.database_client("db").build().await?; let tx = db_client.single_use().build(); + let mut mock_backoff = MockBackoffPolicy::new(); + mock_backoff + .expect_on_failure() + .returning(|_| Duration::from_nanos(1)); + let stmt = Statement::builder("SELECT 1") - .with_backoff_policy(ConstantBackoff( - Duration::from_nanos(1), - Arc::new(AtomicUsize::new(0)), - )) + .with_backoff_policy(mock_backoff) .build(); let mut rs = tx.execute_query(stmt).await?; @@ -1754,11 +1752,13 @@ pub(crate) mod tests { let db_client = client.database_client("db").build().await?; let tx = db_client.single_use().build(); + let mut mock_backoff = MockBackoffPolicy::new(); + mock_backoff + .expect_on_failure() + .returning(|_| Duration::from_nanos(1)); + let stmt = Statement::builder("SELECT 1") - .with_backoff_policy(ConstantBackoff( - Duration::from_nanos(1), - Arc::new(AtomicUsize::new(0)), - )) + .with_backoff_policy(mock_backoff) .build(); let mut rs = tx.execute_query(stmt).await?; @@ -1768,7 +1768,7 @@ pub(crate) mod tests { Ok(()) } - #[tokio::test] + #[tokio_test_no_panics] async fn test_result_set_retry_under_limit_no_resume_token() -> anyhow::Result<()> { use gaxi::grpc::tonic::Response; use gaxi::grpc::tonic::Status; @@ -1835,11 +1835,13 @@ pub(crate) mod tests { let db_client = client.database_client("db").build().await?; let tx = db_client.single_use().build(); + let mut mock_backoff = MockBackoffPolicy::new(); + mock_backoff + .expect_on_failure() + .returning(|_| Duration::from_nanos(1)); + let stmt = Statement::builder("SELECT 1") - .with_backoff_policy(ConstantBackoff( - Duration::from_nanos(1), - Arc::new(AtomicUsize::new(0)), - )) + .with_backoff_policy(mock_backoff) .build(); let mut rs = tx.execute_query(stmt).await?; @@ -1890,12 +1892,14 @@ pub(crate) mod tests { let db_client = client.database_client("db").build().await?; let tx = db_client.single_use().build(); - let backoff_count = Arc::new(AtomicUsize::new(0)); + let mut mock_backoff = MockBackoffPolicy::new(); + mock_backoff + .expect_on_failure() + .times(10) + .returning(|_| Duration::from_nanos(1)); + let stmt = Statement::builder("SELECT 1") - .with_backoff_policy(ConstantBackoff( - Duration::from_nanos(1), - backoff_count.clone(), - )) + .with_backoff_policy(mock_backoff) .build(); let mut rs = tx.execute_query(stmt).await?; @@ -1910,16 +1914,10 @@ pub(crate) mod tests { err_str ); - assert_eq!( - backoff_count.load(Ordering::SeqCst), - 10, - "Backoff policy should have been called 10 times" - ); - Ok(()) } - #[tokio::test] + #[tokio_test_no_panics] async fn result_set_inline_begin_stream_error_fallback() -> anyhow::Result<()> { use gaxi::grpc::tonic::Response; use gaxi::grpc::tonic::Status; @@ -2014,7 +2012,7 @@ pub(crate) mod tests { Ok(()) } - #[tokio::test] + #[tokio_test_no_panics] async fn result_set_retry_inline_begin_transient_error() -> anyhow::Result<()> { use gaxi::grpc::tonic::Response; use gaxi::grpc::tonic::Status; @@ -2099,7 +2097,7 @@ pub(crate) mod tests { Ok(()) } - #[tokio::test] + #[tokio_test_no_panics] async fn result_set_retry_inline_begin_id_recovered() -> anyhow::Result<()> { use gaxi::grpc::tonic::Response; use gaxi::grpc::tonic::Status;