Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 113 additions & 4 deletions src/spanner/src/batch_read_only_transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ use crate::read_only_transaction::{
use crate::result_set::{ResultSet, StreamOperation};
use crate::statement::Statement;
use crate::timestamp_bound::TimestampBound;
use google_cloud_gax::backoff_policy::BackoffPolicyArg;
use google_cloud_gax::options::RequestOptions as GaxRequestOptions;
use google_cloud_gax::retry_policy::RetryPolicyArg;
use serde::{Deserialize, Serialize};
use std::time::Duration;

/// A builder for [BatchReadOnlyTransaction].
///
Expand Down Expand Up @@ -172,6 +176,7 @@ impl BatchReadOnlyTransaction {

Partition {
inner: PartitionedOperation::Query(req),
gax_options: GaxRequestOptions::default(),
}
})
.collect())
Expand Down Expand Up @@ -229,6 +234,7 @@ impl BatchReadOnlyTransaction {

Partition {
inner: PartitionedOperation::Read(req),
gax_options: GaxRequestOptions::default(),
}
})
.collect())
Expand All @@ -241,6 +247,8 @@ impl BatchReadOnlyTransaction {
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Partition {
    /// The partitioned operation (query or read) to execute, including the
    /// partition token identifying the rows this partition covers.
    pub(crate) inner: PartitionedOperation,
    /// Per-execution request options (attempt timeout, retry and backoff
    /// policies). Deliberately skipped during serialization: a serialized
    /// partition may be shipped to another host, and each executing host must
    /// configure its own options via the `with_*` builder methods.
    #[serde(skip)]
    pub(crate) gax_options: GaxRequestOptions,
}

impl Partition {
Expand Down Expand Up @@ -270,6 +278,30 @@ impl Partition {
self
}

/// Sets the per-attempt timeout used when this partition is executed.
///
/// **Note:** The options are **not serialized**, so a host receiving a
/// serialized partition must set its own attempt timeout.
pub fn with_attempt_timeout(self, timeout: Duration) -> Self {
    let mut updated = self;
    updated.gax_options.set_attempt_timeout(timeout);
    updated
}

/// Sets the retry policy used when this partition is executed.
///
/// **Note:** The options are **not serialized**, so a host receiving a
/// serialized partition must set its own retry policy.
pub fn with_retry_policy(self, policy: impl Into<RetryPolicyArg>) -> Self {
    let mut updated = self;
    updated.gax_options.set_retry_policy(policy);
    updated
}

/// Sets the backoff policy used when this partition is executed.
///
/// **Note:** The options are **not serialized**, so a host receiving a
/// serialized partition must set its own backoff policy.
pub fn with_backoff_policy(self, policy: impl Into<BackoffPolicyArg>) -> Self {
    let mut updated = self;
    updated.gax_options.set_backoff_policy(policy);
    updated
}

/// Executes this partition and returns a [ResultSet] that
/// contains the rows that belong to this partition.
///
Expand Down Expand Up @@ -320,18 +352,23 @@ impl Partition {
/// the database that the partitions belong to.
pub async fn execute(&self, client: &DatabaseClient) -> crate::Result<ResultSet> {
    // Only one arm runs, so cloning the options once up front is equivalent
    // to cloning inside the matched arm.
    let options = self.gax_options.clone();
    match &self.inner {
        PartitionedOperation::Query(req) => Self::execute_query(client, req, options).await,
        PartitionedOperation::Read(req) => Self::execute_read(client, req, options).await,
    }
}

async fn execute_query(
client: &DatabaseClient,
req: &crate::model::ExecuteSqlRequest,
gax_options: GaxRequestOptions,
) -> crate::Result<ResultSet> {
let stream = client
.spanner
.execute_streaming_sql(req.clone(), crate::RequestOptions::default())
.execute_streaming_sql(req.clone(), gax_options.clone())
.send()
.await?;

Expand All @@ -345,16 +382,18 @@ impl Partition {
client.clone(),
req.session.clone(),
StreamOperation::Query(req.clone()),
gax_options,
))
}

async fn execute_read(
client: &DatabaseClient,
req: &crate::model::ReadRequest,
gax_options: GaxRequestOptions,
) -> crate::Result<ResultSet> {
let stream = client
.spanner
.streaming_read(req.clone(), crate::RequestOptions::default())
.streaming_read(req.clone(), gax_options.clone())
.send()
.await?;

Expand All @@ -368,6 +407,7 @@ impl Partition {
client.clone(),
req.session.clone(),
StreamOperation::Read(req.clone()),
gax_options,
))
}
}
Expand Down Expand Up @@ -401,6 +441,69 @@ pub(crate) mod tests {
assert_impl_all!(Partition: Send, Sync, Debug);
}

#[test]
fn serialize_partition_skips_gax_options() -> anyhow::Result<()> {
    use std::time::Duration;

    // Build a partition whose options carry a non-default attempt timeout.
    let mut gax_options = GaxRequestOptions::default();
    gax_options.set_attempt_timeout(Duration::from_secs(5));

    let partition = Partition {
        inner: PartitionedOperation::Query(
            crate::model::ExecuteSqlRequest::new()
                .set_sql("SELECT 1")
                .set_partition_token(b"token".to_vec()),
        ),
        gax_options,
    };

    // Round-trip through JSON: the options field must not survive, so the
    // deserialized partition falls back to the default (no attempt timeout).
    let round_tripped: Partition = serde_json::from_str(&serde_json::to_string(&partition)?)?;
    assert_eq!(*round_tripped.gax_options.attempt_timeout(), None);

    Ok(())
}

#[tokio::test]
async fn partition_execute_respects_options() -> anyhow::Result<()> {
    use gaxi::grpc::tonic::Response;
    use std::time::Duration;

    let mut mock = create_session_mock();

    // The attempt timeout must surface on the wire as a grpc-timeout header.
    mock.expect_execute_streaming_sql().once().returning(|req| {
        let timeout = req.metadata().get("grpc-timeout");
        assert!(timeout.is_some(), "Missing grpc-timeout header");
        assert_eq!(timeout.unwrap(), "5000000u"); // 5 seconds in micros

        // Dropping the sender immediately yields an empty response stream.
        let (_tx, rx) = tokio::sync::mpsc::channel(1);
        Ok(Response::from(rx))
    });

    let (db_client, _server) = setup_db_client(mock).await;

    let req = crate::model::ExecuteSqlRequest::new()
        .set_session("projects/p/instances/i/databases/d/sessions/123")
        .set_transaction(crate::model::TransactionSelector {
            selector: Some(Selector::Id(b"tx_id_1".to_vec().into())),
            ..Default::default()
        })
        .set_sql("SELECT 1")
        .set_partition_token(b"token".to_vec());

    // Construct the partition and attach the timeout via the builder method.
    let partition = Partition {
        inner: PartitionedOperation::Query(req),
        gax_options: GaxRequestOptions::default(),
    }
    .with_attempt_timeout(Duration::from_secs(5));

    let _result_set = partition.execute(&db_client).await?;

    Ok(())
}

#[test]
fn serialize_partition_query() -> anyhow::Result<()> {
let req = crate::model::ExecuteSqlRequest::new()
Expand All @@ -416,6 +519,7 @@ pub(crate) mod tests {

let partition = Partition {
inner: PartitionedOperation::Query(req),
gax_options: GaxRequestOptions::default(),
};

let serialized = serde_json::to_string(&partition)?;
Expand Down Expand Up @@ -448,6 +552,7 @@ pub(crate) mod tests {

let partition = Partition {
inner: PartitionedOperation::Read(req),
gax_options: GaxRequestOptions::default(),
};

let serialized = serde_json::to_string(&partition)?;
Expand Down Expand Up @@ -498,6 +603,7 @@ pub(crate) mod tests {

let partition = Partition {
inner: PartitionedOperation::Query(req),
gax_options: GaxRequestOptions::default(),
};

let _result_set = partition.execute(&db_client).await?;
Expand Down Expand Up @@ -540,6 +646,7 @@ pub(crate) mod tests {

let partition = Partition {
inner: PartitionedOperation::Read(req),
gax_options: GaxRequestOptions::default(),
};

let _result_set = partition.execute(&db_client).await?;
Expand Down Expand Up @@ -700,6 +807,7 @@ pub(crate) mod tests {

let partition = Partition {
inner: PartitionedOperation::Query(req),
gax_options: GaxRequestOptions::default(),
};

let _result_set = partition.with_data_boost(true).execute(&db_client).await?;
Expand Down Expand Up @@ -732,6 +840,7 @@ pub(crate) mod tests {

let partition = Partition {
inner: PartitionedOperation::Read(req),
gax_options: GaxRequestOptions::default(),
};

let _result_set = partition.with_data_boost(true).execute(&db_client).await?;
Expand Down
Loading
Loading