diff --git a/src/spanner/src/batch_read_only_transaction.rs b/src/spanner/src/batch_read_only_transaction.rs index 788e4b22ef..12713f9507 100644 --- a/src/spanner/src/batch_read_only_transaction.rs +++ b/src/spanner/src/batch_read_only_transaction.rs @@ -21,7 +21,11 @@ use crate::read_only_transaction::{ use crate::result_set::{ResultSet, StreamOperation}; use crate::statement::Statement; use crate::timestamp_bound::TimestampBound; +use google_cloud_gax::backoff_policy::BackoffPolicyArg; +use google_cloud_gax::options::RequestOptions as GaxRequestOptions; +use google_cloud_gax::retry_policy::RetryPolicyArg; use serde::{Deserialize, Serialize}; +use std::time::Duration; /// A builder for [BatchReadOnlyTransaction]. /// @@ -172,6 +176,7 @@ impl BatchReadOnlyTransaction { Partition { inner: PartitionedOperation::Query(req), + gax_options: GaxRequestOptions::default(), } }) .collect()) @@ -229,6 +234,7 @@ impl BatchReadOnlyTransaction { Partition { inner: PartitionedOperation::Read(req), + gax_options: GaxRequestOptions::default(), } }) .collect()) @@ -241,6 +247,8 @@ impl BatchReadOnlyTransaction { #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Partition { pub(crate) inner: PartitionedOperation, + #[serde(skip)] + pub(crate) gax_options: GaxRequestOptions, } impl Partition { @@ -270,6 +278,30 @@ impl Partition { self } + /// Sets the per-attempt timeout for this partition execution. + /// + /// **Note:** This field is **not serialized**. Each host that executes a partition must set its own attempt timeout. + pub fn with_attempt_timeout(mut self, timeout: Duration) -> Self { + self.gax_options.set_attempt_timeout(timeout); + self + } + + /// Sets the retry policy for this partition execution. + /// + /// **Note:** This field is **not serialized**. Each host that executes a partition must set its own retry policy. + pub fn with_retry_policy(mut self, policy: impl Into) -> Self { + self.gax_options.set_retry_policy(policy); + self + } + + /// Sets the backoff policy for this partition execution. + /// + /// **Note:** This field is **not serialized**. Each host that executes a partition must set its own backoff policy. + pub fn with_backoff_policy(mut self, policy: impl Into) -> Self { + self.gax_options.set_backoff_policy(policy); + self + } + /// Executes this partition and returns a [ResultSet] that /// contains the rows that belong to this partition. /// @@ -320,18 +352,23 @@ impl Partition { /// the database that the partitions belong to. pub async fn execute(&self, client: &DatabaseClient) -> crate::Result { match &self.inner { - PartitionedOperation::Query(req) => Self::execute_query(client, req).await, - PartitionedOperation::Read(req) => Self::execute_read(client, req).await, + PartitionedOperation::Query(req) => { + Self::execute_query(client, req, self.gax_options.clone()).await + } + PartitionedOperation::Read(req) => { + Self::execute_read(client, req, self.gax_options.clone()).await + } } } async fn execute_query( client: &DatabaseClient, req: &crate::model::ExecuteSqlRequest, + gax_options: GaxRequestOptions, ) -> crate::Result { let stream = client .spanner - .execute_streaming_sql(req.clone(), crate::RequestOptions::default()) + .execute_streaming_sql(req.clone(), gax_options.clone()) .send() .await?; @@ -345,16 +382,18 @@ impl Partition { client.clone(), req.session.clone(), StreamOperation::Query(req.clone()), + gax_options, )) } async fn execute_read( client: &DatabaseClient, req: &crate::model::ReadRequest, + gax_options: GaxRequestOptions, ) -> crate::Result { let stream = client .spanner - .streaming_read(req.clone(), crate::RequestOptions::default()) + .streaming_read(req.clone(), gax_options.clone()) .send() .await?; @@ -368,6 +407,7 @@ impl Partition { client.clone(), req.session.clone(), StreamOperation::Read(req.clone()), + gax_options, )) } } @@ -387,6 +427,7 @@ pub(crate) mod tests { use crate::model::{ExecuteSqlRequest, ReadRequest as GrpcReadRequest, TransactionSelector}; use crate::read_only_transaction::tests::{create_session_mock, setup_db_client}; use gaxi::grpc::tonic::Response; + use google_cloud_test_macros::tokio_test_no_panics; use prost_types::Timestamp; use spanner_grpc_mock::google::spanner::v1::{ Partition as MockPartition, PartitionResponse, Transaction, @@ -401,6 +442,69 @@ pub(crate) mod tests { assert_impl_all!(Partition: Send, Sync, Debug); } + #[test] + fn serialize_partition_skips_gax_options() -> anyhow::Result<()> { + use std::time::Duration; + + let req = crate::model::ExecuteSqlRequest::new() + .set_sql("SELECT 1") + .set_partition_token(b"token".to_vec()); + + let mut gax_options = GaxRequestOptions::default(); + gax_options.set_attempt_timeout(Duration::from_secs(5)); + let partition = Partition { + inner: PartitionedOperation::Query(req), + gax_options, + }; + + let serialized = serde_json::to_string(&partition)?; + let deserialized: Partition = serde_json::from_str(&serialized)?; + + // Verify that gax_options was NOT preserved (it uses default, which is None timeout) + assert_eq!(*deserialized.gax_options.attempt_timeout(), None); + + Ok(()) + } + + #[tokio_test_no_panics] + async fn partition_execute_respects_options() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use std::time::Duration; + + let mut mock = create_session_mock(); + + mock.expect_execute_streaming_sql().once().returning(|req| { + let timeout = req.metadata().get("grpc-timeout"); + assert!(timeout.is_some(), "Missing grpc-timeout header"); + assert_eq!(timeout.unwrap(), "5000000u"); // 5 seconds in micros + + let (_, rx) = tokio::sync::mpsc::channel(1); + Ok(Response::from(rx)) + }); + + let (db_client, _server) = setup_db_client(mock).await; + + let req = crate::model::ExecuteSqlRequest::new() + .set_session("projects/p/instances/i/databases/d/sessions/123") + .set_transaction(crate::model::TransactionSelector { + selector: Some(Selector::Id(b"tx_id_1".to_vec().into())), + ..Default::default() + }) + .set_sql("SELECT 1") + .set_partition_token(b"token".to_vec()); + + let partition = Partition { + inner: PartitionedOperation::Query(req), + gax_options: GaxRequestOptions::default(), + }; + + let partition = partition.with_attempt_timeout(Duration::from_secs(5)); + + let _result_set = partition.execute(&db_client).await?; + + Ok(()) + } + #[test] fn serialize_partition_query() -> anyhow::Result<()> { let req = crate::model::ExecuteSqlRequest::new() @@ -416,6 +520,7 @@ pub(crate) mod tests { let partition = Partition { inner: PartitionedOperation::Query(req), + gax_options: GaxRequestOptions::default(), }; let serialized = serde_json::to_string(&partition)?; @@ -448,6 +553,7 @@ pub(crate) mod tests { let partition = Partition { inner: PartitionedOperation::Read(req), + gax_options: GaxRequestOptions::default(), }; let serialized = serde_json::to_string(&partition)?; @@ -498,6 +604,7 @@ pub(crate) mod tests { let partition = Partition { inner: PartitionedOperation::Query(req), + gax_options: GaxRequestOptions::default(), }; let _result_set = partition.execute(&db_client).await?; @@ -540,6 +647,7 @@ pub(crate) mod tests { let partition = Partition { inner: PartitionedOperation::Read(req), + gax_options: GaxRequestOptions::default(), }; let _result_set = partition.execute(&db_client).await?; @@ -700,6 +808,7 @@ pub(crate) mod tests { let partition = Partition { inner: PartitionedOperation::Query(req), + gax_options: GaxRequestOptions::default(), }; let _result_set = partition.with_data_boost(true).execute(&db_client).await?; @@ -732,6 +841,7 @@ pub(crate) mod tests { let partition = Partition { inner: PartitionedOperation::Read(req), + gax_options: GaxRequestOptions::default(), }; let _result_set = partition.with_data_boost(true).execute(&db_client).await?; diff --git a/src/spanner/src/client.rs b/src/spanner/src/client.rs index 5969681590..ac65953e7e 100644 --- a/src/spanner/src/client.rs +++ b/src/spanner/src/client.rs @@ -252,7 +252,9 @@ mod tests { use crate::result_set::tests::adapt; use gaxi::grpc::tonic::{Code as GrpcCode, Response, Status}; use google_cloud_auth::credentials::anonymous::Builder as Anonymous; + use google_cloud_gax::backoff_policy::BackoffPolicy; use google_cloud_gax::error::rpc::Code; + use google_cloud_gax::retry_state::RetryState; use google_cloud_test_macros::tokio_test_no_panics; use spanner_grpc_mock::google::rpc as mock_rpc; use spanner_grpc_mock::google::spanner::v1 as mock_v1; @@ -264,8 +266,8 @@ mod tests { mockall::mock! { #[derive(Debug)] BackoffPolicy {} - impl google_cloud_gax::backoff_policy::BackoffPolicy for BackoffPolicy { - fn on_failure(&self, state: &google_cloud_gax::retry_state::RetryState) -> std::time::Duration; + impl BackoffPolicy for BackoffPolicy { + fn on_failure(&self, state: &RetryState) -> Duration; } } diff --git a/src/spanner/src/read_only_transaction.rs b/src/spanner/src/read_only_transaction.rs index a7f3a1356d..04e877e90f 100644 --- a/src/spanner/src/read_only_transaction.rs +++ b/src/spanner/src/read_only_transaction.rs @@ -610,6 +610,7 @@ macro_rules! execute_stream_with_retry { $self.client.clone(), $self.session_name.clone(), $operation_variant($request), + $gax_options, )) }}; } diff --git a/src/spanner/src/result_set.rs b/src/spanner/src/result_set.rs index 8832fb6209..4564c1c40e 100644 --- a/src/spanner/src/result_set.rs +++ b/src/spanner/src/result_set.rs @@ -23,9 +23,16 @@ use crate::row::Row; use crate::server_streaming::stream::PartialResultSetStream; use bytes::Bytes; use gaxi::prost::FromProto; -use google_cloud_gax::error::rpc::Code; +use google_cloud_gax::backoff_policy::BackoffPolicy; +use google_cloud_gax::error::Error as GaxError; +use google_cloud_gax::exponential_backoff::ExponentialBackoffBuilder; +use google_cloud_gax::options::RequestOptions as GaxRequestOptions; +use google_cloud_gax::retry_policy::{Aip194Strict, RetryPolicyExt}; +use google_cloud_gax::retry_state::RetryState; use std::collections::VecDeque; use std::mem::take; +use std::sync::Arc; +use tokio::time::sleep; #[cfg(feature = "unstable-stream")] use futures::Stream; @@ -64,6 +71,7 @@ pub struct ResultSet { max_buffered_partial_result_sets: usize, retry_count: usize, transaction_selector: Option, + gax_options: GaxRequestOptions, } /// Errors that can occur when interacting with a [`ResultSet`]. @@ -95,7 +103,9 @@ impl ResultSet { client: DatabaseClient, session_name: String, operation: StreamOperation, + gax_options: GaxRequestOptions, ) -> Self { + let gax_options = Self::apply_defaults(gax_options); Self { stream: Some(stream), buffered_values: Vec::new(), @@ -114,9 +124,24 @@ impl ResultSet { retry_count: 0, transaction_selector, stats: None, + gax_options, } } + fn apply_defaults(mut gax_options: GaxRequestOptions) -> GaxRequestOptions { + if gax_options.retry_policy().is_none() { + gax_options.set_retry_policy(Aip194Strict.with_attempt_limit(10)); + } + if gax_options.backoff_policy().is_none() { + gax_options.set_backoff_policy(Self::default_backoff_policy()); + } + gax_options + } + + fn default_backoff_policy() -> Arc { + Arc::new(ExponentialBackoffBuilder::default().clamp()) + } + /// Returns the metadata of the result set. /// /// # Example @@ -290,6 +315,15 @@ impl ResultSet { // Clear the buffer and restart the stream using the last // resume_token that we have seen. self.partial_result_sets_buffer.clear(); + + // Apply backoff delay if policy is present + if let Some(policy) = self.gax_options.backoff_policy() { + let state = + RetryState::new(self.safe_to_retry).set_attempt_count(self.retry_count as u32); + let delay = policy.on_failure(&state); + sleep(delay).await; + } + self.restart_stream().await?; return Ok(()); } @@ -461,7 +495,7 @@ impl ResultSet { let stream = self .client .spanner - .execute_streaming_sql(req.clone(), crate::RequestOptions::default()) + .execute_streaming_sql(req.clone(), self.gax_options.clone()) .send() .await?; self.stream = Some(stream); @@ -474,7 +508,7 @@ impl ResultSet { let stream = self .client .spanner - .streaming_read(req.clone(), crate::RequestOptions::default()) + .streaming_read(req.clone(), self.gax_options.clone()) .send() .await?; self.stream = Some(stream); @@ -483,13 +517,17 @@ impl ResultSet { Ok(()) } - // TODO(#5185): Make the retry policy configurable. fn should_retry(&self, e: &crate::Error) -> bool { - if self.retry_count >= 10 { - return false; + if let Some(policy) = self.gax_options.retry_policy() { + let state = + RetryState::new(self.safe_to_retry).set_attempt_count(self.retry_count as u32); + + if let Some(status) = e.status() { + let gax_error = GaxError::service(status.clone()); + return policy.on_error(&state, gax_error).is_continue(); + } } - e.status() - .is_some_and(|status| status.code == Code::Unavailable) + false } /// Converts the [`ResultSet`] into a [`Stream`]. @@ -586,8 +624,14 @@ impl ResultSet { pub(crate) mod tests { use super::*; use crate::client::Spanner; - use gaxi::grpc::tonic::Response; + use crate::client::Statement; + use crate::key::KeySet; + use crate::read::ReadRequest; + use gaxi::grpc::tonic::{Code as GrpcCode, Response}; use google_cloud_auth::credentials::anonymous::Builder as Anonymous; + use google_cloud_gax::backoff_policy::BackoffPolicy; + use google_cloud_gax::retry_state::RetryState; + use google_cloud_test_macros::tokio_test_no_panics; use prost_types::Value; use spanner_grpc_mock::MockSpanner; use spanner_grpc_mock::google::spanner::v1 as spanner_v1; @@ -596,6 +640,15 @@ pub(crate) mod tests { PartialResultSet, ResultSetMetadata, Session, StructType, }; use spanner_grpc_mock::start; + use std::time::Duration; + + mockall::mock! { + #[derive(Debug)] + BackoffPolicy {} + impl BackoffPolicy for BackoffPolicy { + fn on_failure(&self, state: &RetryState) -> Duration; + } + } pub(crate) fn string_val(s: &str) -> Value { Value { @@ -863,6 +916,27 @@ pub(crate) mod tests { Ok(()) } + #[tokio::test] + async fn test_result_set_default_policies_applied() -> anyhow::Result<()> { + let rs = run_mock_query(vec![PartialResultSet { + metadata: metadata(2), + last: true, + ..Default::default() + }]) + .await; + + assert!( + rs.gax_options.retry_policy().is_some(), + "Default retry policy should be applied" + ); + assert!( + rs.gax_options.backoff_policy().is_some(), + "Default backoff policy should be applied" + ); + + Ok(()) + } + #[tokio::test] async fn test_result_set_retry_read_stream() -> anyhow::Result<()> { use gaxi::grpc::tonic::Response; @@ -920,8 +994,14 @@ pub(crate) mod tests { let db_client = client.database_client("db").build().await?; let tx = db_client.single_use().build(); + let mut mock_backoff = MockBackoffPolicy::new(); + mock_backoff + .expect_on_failure() + .returning(|_| Duration::from_nanos(1)); + let read_req = crate::read::ReadRequest::builder("table", vec!["Id", "Value"]) .with_keys(crate::key::KeySet::all()) + .with_backoff_policy(mock_backoff) .build(); let mut rs: ResultSet = tx.execute_read(read_req).await?; @@ -936,6 +1016,99 @@ pub(crate) mod tests { Ok(()) } + #[tokio::test] + async fn test_result_set_custom_retry_policy() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use gaxi::grpc::tonic::Status; + + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::start; + + use google_cloud_gax::retry_policy::Aip194Strict; + use google_cloud_gax::retry_policy::RetryPolicyExt; + + // Extend the default retry policy to also retry on ResourceExhausted. + let retry_policy = Aip194Strict.continue_on_too_many_requests(); + + let mut mock = MockSpanner::new(); + let mut seq = mockall::Sequence::new(); + + // Fail with RESOURCE_EXHAUSTED on first call + mock.expect_streaming_read() + .times(1) + .in_sequence(&mut seq) + .returning(|_request| { + let stream = adapt([ + Ok(PartialResultSet { + metadata: metadata(2), + values: vec![string_val("row1"), string_val("b")], + resume_token: b"token1".to_vec(), + ..Default::default() + }), + Err(Status::new(GrpcCode::ResourceExhausted, "Quota exceeded")), + ]); + Ok(Response::from(stream)) + }); + + // Succeed on second call + mock.expect_streaming_read() + .times(1) + .in_sequence(&mut seq) + .returning(|_request| { + let stream = adapt([Ok(PartialResultSet { + values: vec![string_val("row2"), string_val("d")], + resume_token: b"token2".to_vec(), + last: true, + ..Default::default() + })]); + Ok(Response::from(stream)) + }); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + let tx = db_client.single_use().build(); + + let mut mock_backoff = MockBackoffPolicy::new(); + mock_backoff + .expect_on_failure() + .times(1) + .returning(|_| Duration::from_nanos(1)); + + let read_req = ReadRequest::builder("table", vec!["Id", "Value"]) + .with_keys(KeySet::all()) + .with_retry_policy(retry_policy) + .with_backoff_policy(mock_backoff) + .build(); + + let mut rs: ResultSet = tx.execute_read(read_req).await?; + + let row1 = rs.next().await.expect("Stream ended unexpectedly")?; + assert_eq!(row1.raw_values()[0].0, string_val("row1")); + + // This next() call should trigger the retry because the previous stream ended with error! + let row2 = rs.next().await.expect("Stream ended unexpectedly")?; + assert_eq!(row2.raw_values()[0].0, string_val("row2")); + + assert!(rs.next().await.is_none()); + + Ok(()) + } + #[tokio::test] async fn test_result_set_one_row() { let mut rs = run_mock_query(vec![PartialResultSet { @@ -1364,7 +1537,15 @@ pub(crate) mod tests { let db_client = client.database_client("db").build().await?; let tx = db_client.single_use().build(); - let mut rs = tx.execute_query("SELECT 1").await?; + let mut mock_backoff = MockBackoffPolicy::new(); + mock_backoff + .expect_on_failure() + .returning(|_| Duration::from_nanos(1)); + + let stmt = Statement::builder("SELECT 1") + .with_backoff_policy(mock_backoff) + .build(); + let mut rs = tx.execute_query(stmt).await?; let row1 = rs.next().await.expect("Stream ended unexpectedly")?; assert_eq!(row1.raw_values()[0].0, string_val("row1")); @@ -1571,7 +1752,15 @@ pub(crate) mod tests { let db_client = client.database_client("db").build().await?; let tx = db_client.single_use().build(); - let mut rs = tx.execute_query("SELECT 1").await?; + let mut mock_backoff = MockBackoffPolicy::new(); + mock_backoff + .expect_on_failure() + .returning(|_| Duration::from_nanos(1)); + + let stmt = Statement::builder("SELECT 1") + .with_backoff_policy(mock_backoff) + .build(); + let mut rs = tx.execute_query(stmt).await?; let row1 = rs.next().await.expect("Expected row1")?; assert_eq!(row1.raw_values()[0].0, string_val("row1_retry")); @@ -1579,7 +1768,7 @@ pub(crate) mod tests { Ok(()) } - #[tokio::test] + #[tokio_test_no_panics] async fn test_result_set_retry_under_limit_no_resume_token() -> anyhow::Result<()> { use gaxi::grpc::tonic::Response; use gaxi::grpc::tonic::Status; @@ -1646,7 +1835,15 @@ pub(crate) mod tests { let db_client = client.database_client("db").build().await?; let tx = db_client.single_use().build(); - let mut rs = tx.execute_query("SELECT 1").await?; + let mut mock_backoff = MockBackoffPolicy::new(); + mock_backoff + .expect_on_failure() + .returning(|_| Duration::from_nanos(1)); + + let stmt = Statement::builder("SELECT 1") + .with_backoff_policy(mock_backoff) + .build(); + let mut rs = tx.execute_query(stmt).await?; // Set max buffer size to 3 (so 2 messages is under the limit) rs.set_max_buffered_partial_result_sets(3); @@ -1695,7 +1892,16 @@ pub(crate) mod tests { let db_client = client.database_client("db").build().await?; let tx = db_client.single_use().build(); - let mut rs = tx.execute_query("SELECT 1").await?; + let mut mock_backoff = MockBackoffPolicy::new(); + mock_backoff + .expect_on_failure() + .times(10) + .returning(|_| Duration::from_nanos(1)); + + let stmt = Statement::builder("SELECT 1") + .with_backoff_policy(mock_backoff) + .build(); + let mut rs = tx.execute_query(stmt).await?; let res = rs.next().await; assert!(res.is_some(), "Expected an error but got None"); @@ -1711,7 +1917,7 @@ pub(crate) mod tests { Ok(()) } - #[tokio::test] + #[tokio_test_no_panics] async fn result_set_inline_begin_stream_error_fallback() -> anyhow::Result<()> { use gaxi::grpc::tonic::Response; use gaxi::grpc::tonic::Status; @@ -1806,7 +2012,7 @@ pub(crate) mod tests { Ok(()) } - #[tokio::test] + #[tokio_test_no_panics] async fn result_set_retry_inline_begin_transient_error() -> anyhow::Result<()> { use gaxi::grpc::tonic::Response; use gaxi::grpc::tonic::Status; @@ -1891,7 +2097,7 @@ pub(crate) mod tests { Ok(()) } - #[tokio::test] + #[tokio_test_no_panics] async fn result_set_retry_inline_begin_id_recovered() -> anyhow::Result<()> { use gaxi::grpc::tonic::Response; use gaxi::grpc::tonic::Status;