Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 114 additions & 4 deletions src/spanner/src/batch_read_only_transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ use crate::read_only_transaction::{
use crate::result_set::{ResultSet, StreamOperation};
use crate::statement::Statement;
use crate::timestamp_bound::TimestampBound;
use google_cloud_gax::backoff_policy::BackoffPolicyArg;
use google_cloud_gax::options::RequestOptions as GaxRequestOptions;
use google_cloud_gax::retry_policy::RetryPolicyArg;
use serde::{Deserialize, Serialize};
use std::time::Duration;

/// A builder for [BatchReadOnlyTransaction].
///
Expand Down Expand Up @@ -172,6 +176,7 @@ impl BatchReadOnlyTransaction {

Partition {
inner: PartitionedOperation::Query(req),
gax_options: GaxRequestOptions::default(),
}
})
.collect())
Expand Down Expand Up @@ -229,6 +234,7 @@ impl BatchReadOnlyTransaction {

Partition {
inner: PartitionedOperation::Read(req),
gax_options: GaxRequestOptions::default(),
}
})
.collect())
Expand All @@ -241,6 +247,8 @@ impl BatchReadOnlyTransaction {
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Partition {
pub(crate) inner: PartitionedOperation,
#[serde(skip)]
pub(crate) gax_options: GaxRequestOptions,
}

impl Partition {
Expand Down Expand Up @@ -270,6 +278,30 @@ impl Partition {
self
}

/// Sets the per-attempt timeout for this partition execution.
///
/// **Note:** This field is **not serialized**. Each host that executes a partition must set its own attempt timeout.
pub fn with_attempt_timeout(mut self, timeout: Duration) -> Self {
self.gax_options.set_attempt_timeout(timeout);
self
}

/// Sets the retry policy for this partition execution.
///
/// **Note:** This field is **not serialized**. Each host that executes a partition must set its own retry policy.
pub fn with_retry_policy(mut self, policy: impl Into<RetryPolicyArg>) -> Self {
self.gax_options.set_retry_policy(policy);
self
}

/// Sets the backoff policy for this partition execution.
///
/// **Note:** This field is **not serialized**. Each host that executes a partition must set its own backoff policy.
pub fn with_backoff_policy(mut self, policy: impl Into<BackoffPolicyArg>) -> Self {
self.gax_options.set_backoff_policy(policy);
self
}

/// Executes this partition and returns a [ResultSet] that
/// contains the rows that belong to this partition.
///
Expand Down Expand Up @@ -320,18 +352,23 @@ impl Partition {
/// the database that the partitions belong to.
pub async fn execute(&self, client: &DatabaseClient) -> crate::Result<ResultSet> {
match &self.inner {
PartitionedOperation::Query(req) => Self::execute_query(client, req).await,
PartitionedOperation::Read(req) => Self::execute_read(client, req).await,
PartitionedOperation::Query(req) => {
Self::execute_query(client, req, self.gax_options.clone()).await
}
PartitionedOperation::Read(req) => {
Self::execute_read(client, req, self.gax_options.clone()).await
}
}
}

async fn execute_query(
client: &DatabaseClient,
req: &crate::model::ExecuteSqlRequest,
gax_options: GaxRequestOptions,
) -> crate::Result<ResultSet> {
let stream = client
.spanner
.execute_streaming_sql(req.clone(), crate::RequestOptions::default())
.execute_streaming_sql(req.clone(), gax_options.clone())
.send()
.await?;

Expand All @@ -345,16 +382,18 @@ impl Partition {
client.clone(),
req.session.clone(),
StreamOperation::Query(req.clone()),
gax_options,
))
}

async fn execute_read(
client: &DatabaseClient,
req: &crate::model::ReadRequest,
gax_options: GaxRequestOptions,
) -> crate::Result<ResultSet> {
let stream = client
.spanner
.streaming_read(req.clone(), crate::RequestOptions::default())
.streaming_read(req.clone(), gax_options.clone())
.send()
.await?;

Expand All @@ -368,6 +407,7 @@ impl Partition {
client.clone(),
req.session.clone(),
StreamOperation::Read(req.clone()),
gax_options,
))
}
}
Expand All @@ -387,6 +427,7 @@ pub(crate) mod tests {
use crate::model::{ExecuteSqlRequest, ReadRequest as GrpcReadRequest, TransactionSelector};
use crate::read_only_transaction::tests::{create_session_mock, setup_db_client};
use gaxi::grpc::tonic::Response;
use google_cloud_test_macros::tokio_test_no_panics;
use prost_types::Timestamp;
use spanner_grpc_mock::google::spanner::v1::{
Partition as MockPartition, PartitionResponse, Transaction,
Expand All @@ -401,6 +442,69 @@ pub(crate) mod tests {
assert_impl_all!(Partition: Send, Sync, Debug);
}

#[test]
fn serialize_partition_skips_gax_options() -> anyhow::Result<()> {
use std::time::Duration;

let req = crate::model::ExecuteSqlRequest::new()
.set_sql("SELECT 1")
.set_partition_token(b"token".to_vec());

let mut gax_options = GaxRequestOptions::default();
gax_options.set_attempt_timeout(Duration::from_secs(5));
let partition = Partition {
inner: PartitionedOperation::Query(req),
gax_options,
};

let serialized = serde_json::to_string(&partition)?;
let deserialized: Partition = serde_json::from_str(&serialized)?;

// Verify that gax_options was NOT preserved (it uses default, which is None timeout)
assert_eq!(*deserialized.gax_options.attempt_timeout(), None);

Ok(())
}

#[tokio_test_no_panics]
async fn partition_execute_respects_options() -> anyhow::Result<()> {
use gaxi::grpc::tonic::Response;
use std::time::Duration;

let mut mock = create_session_mock();

mock.expect_execute_streaming_sql().once().returning(|req| {
let timeout = req.metadata().get("grpc-timeout");
assert!(timeout.is_some(), "Missing grpc-timeout header");
assert_eq!(timeout.unwrap(), "5000000u"); // 5 seconds in micros

let (_, rx) = tokio::sync::mpsc::channel(1);
Ok(Response::from(rx))
});

let (db_client, _server) = setup_db_client(mock).await;

let req = crate::model::ExecuteSqlRequest::new()
.set_session("projects/p/instances/i/databases/d/sessions/123")
.set_transaction(crate::model::TransactionSelector {
selector: Some(Selector::Id(b"tx_id_1".to_vec().into())),
..Default::default()
})
.set_sql("SELECT 1")
.set_partition_token(b"token".to_vec());

let partition = Partition {
inner: PartitionedOperation::Query(req),
gax_options: GaxRequestOptions::default(),
};

let partition = partition.with_attempt_timeout(Duration::from_secs(5));

let _result_set = partition.execute(&db_client).await?;

Ok(())
}

#[test]
fn serialize_partition_query() -> anyhow::Result<()> {
let req = crate::model::ExecuteSqlRequest::new()
Expand All @@ -416,6 +520,7 @@ pub(crate) mod tests {

let partition = Partition {
inner: PartitionedOperation::Query(req),
gax_options: GaxRequestOptions::default(),
};

let serialized = serde_json::to_string(&partition)?;
Expand Down Expand Up @@ -448,6 +553,7 @@ pub(crate) mod tests {

let partition = Partition {
inner: PartitionedOperation::Read(req),
gax_options: GaxRequestOptions::default(),
};

let serialized = serde_json::to_string(&partition)?;
Expand Down Expand Up @@ -498,6 +604,7 @@ pub(crate) mod tests {

let partition = Partition {
inner: PartitionedOperation::Query(req),
gax_options: GaxRequestOptions::default(),
};

let _result_set = partition.execute(&db_client).await?;
Expand Down Expand Up @@ -540,6 +647,7 @@ pub(crate) mod tests {

let partition = Partition {
inner: PartitionedOperation::Read(req),
gax_options: GaxRequestOptions::default(),
};

let _result_set = partition.execute(&db_client).await?;
Expand Down Expand Up @@ -700,6 +808,7 @@ pub(crate) mod tests {

let partition = Partition {
inner: PartitionedOperation::Query(req),
gax_options: GaxRequestOptions::default(),
};

let _result_set = partition.with_data_boost(true).execute(&db_client).await?;
Expand Down Expand Up @@ -732,6 +841,7 @@ pub(crate) mod tests {

let partition = Partition {
inner: PartitionedOperation::Read(req),
gax_options: GaxRequestOptions::default(),
};

let _result_set = partition.with_data_boost(true).execute(&db_client).await?;
Expand Down
6 changes: 4 additions & 2 deletions src/spanner/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,9 @@ mod tests {
use crate::result_set::tests::adapt;
use gaxi::grpc::tonic::{Code as GrpcCode, Response, Status};
use google_cloud_auth::credentials::anonymous::Builder as Anonymous;
use google_cloud_gax::backoff_policy::BackoffPolicy;
use google_cloud_gax::error::rpc::Code;
use google_cloud_gax::retry_state::RetryState;
use google_cloud_test_macros::tokio_test_no_panics;
use spanner_grpc_mock::google::rpc as mock_rpc;
use spanner_grpc_mock::google::spanner::v1 as mock_v1;
Expand All @@ -264,8 +266,8 @@ mod tests {
mockall::mock! {
#[derive(Debug)]
BackoffPolicy {}
impl google_cloud_gax::backoff_policy::BackoffPolicy for BackoffPolicy {
fn on_failure(&self, state: &google_cloud_gax::retry_state::RetryState) -> std::time::Duration;
impl BackoffPolicy for BackoffPolicy {
fn on_failure(&self, state: &RetryState) -> Duration;
}
}

Expand Down
1 change: 1 addition & 0 deletions src/spanner/src/read_only_transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,7 @@ macro_rules! execute_stream_with_retry {
$self.client.clone(),
$self.session_name.clone(),
$operation_variant($request),
$gax_options,
))
}};
}
Expand Down
Loading
Loading