diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 39e29d4d41..724bff7069 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -46,12 +46,10 @@ class _BatchBase(_SessionWrapper): :param session: the session used to perform the commit """ - transaction_tag = None - _read_only = False - def __init__(self, session): super(_BatchBase, self).__init__(session) self._mutations = [] + self.transaction_tag: str = None def _check_state(self): """Helper for :meth:`commit` et al. diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 93d9c1a31c..f75f95b257 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -770,6 +770,23 @@ def execute_pdml(): return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)() + def session(self, labels=None, database_role=None): + """Deprecated. Factory to create a session for this database. + + :type labels: dict (str -> str) or None + :param labels: (Optional) user-assigned labels for the session. + + :type database_role: str + :param database_role: (Optional) user-assigned database_role for the session. + + :rtype: :class:`~google.cloud.spanner_v1.session.Session` + :returns: a session bound to this database. + """ + # If role is specified in param, then that role is used + # instead. + role = database_role or self._database_role + return Session(self, labels=labels, database_role=role) + def snapshot(self, **kw): """Return an object which wraps a snapshot. diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 22bbe0e103..39afdac4cc 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -16,14 +16,17 @@ import functools import threading +from typing import Union + from google.protobuf.struct_pb2 import Struct -from google.cloud.spanner_v1 import ExecuteSqlRequest -from google.cloud.spanner_v1 import ReadRequest -from google.cloud.spanner_v1 import TransactionOptions -from google.cloud.spanner_v1 import TransactionSelector +from google.cloud.spanner_v1 import ExecuteSqlRequest, ResultSet, PartialResultSet from google.cloud.spanner_v1 import PartitionOptions from google.cloud.spanner_v1 import PartitionQueryRequest from google.cloud.spanner_v1 import PartitionReadRequest +from google.cloud.spanner_v1 import ReadRequest +from google.cloud.spanner_v1 import TransactionOptions +from google.cloud.spanner_v1 import TransactionSelector +from google.cloud.spanner_v1 import Transaction as Transaction from google.api_core.exceptions import InternalServerError from google.api_core.exceptions import ServiceUnavailable @@ -44,6 +47,7 @@ from google.cloud.spanner_v1 import RequestOptions from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture +from google.cloud.spanner_v1.types import MultiplexedSessionPrecommitToken _STREAM_RESUMPTION_INTERNAL_ERROR_MESSAGES = ( "RST_STREAM", @@ -81,8 +85,8 @@ def _restart_on_unavailable( if both transaction_selector and transaction are passed, then transaction is given priority. """ - resume_token = b"" - item_buffer = [] + resume_token: bytes = b"" + item_buffer: list[ResultSet] = [] if transaction is not None: transaction_selector = transaction._make_txn_selector() @@ -105,20 +109,19 @@ def _restart_on_unavailable( metadata=metadata, ), MetricsCapture(): iterator = method(request=request, metadata=metadata) + + item: ResultSet for item in iterator: item_buffer.append(item) - # Setting the transaction id because the transaction begin was inlined for first rpc. - if ( - transaction is not None - and transaction._transaction_id is None - and item.metadata is not None - and item.metadata.transaction is not None - and item.metadata.transaction.id is not None - ): - transaction._transaction_id = item.metadata.transaction.id + + # Update the snapshot using the response. + if transaction is not None: + transaction._update_for_result_set_pb(item) + if item.resume_token: resume_token = item.resume_token break + except ServiceUnavailable: del item_buffer[:] with trace_call( @@ -201,12 +204,31 @@ class _SnapshotBase(_SessionWrapper): :param session: the session used to perform the commit """ - _multi_use = False _read_only: bool = True - _transaction_id = None - _read_request_count = 0 - _execute_sql_count = 0 - _lock = threading.Lock() + _multi_use: bool = False + + def __init__(self, session): + super().__init__(session) + + # Counts for execute SQL requests and total read requests (including + # execute SQL requests). Used to provide sequence numbers for + # :class:`google.cloud.spanner_v1.types.ExecuteSqlRequest` and to + # verify that single-use transactions are not used more than once, + # respectively. + self._execute_sql_request_count: int = 0 + self._total_read_request_count: int = 0 + + # Identifier for the transaction. + self._transaction_id: bytes = None + + # Precommit tokens are returned for transactions with + # multiplexed sessions. The precommit token with the + # highest sequence number is included in the commit request. + self._precommit_token: MultiplexedSessionPrecommitToken = None + + # Operations within a transaction can be performed using multiple + # threads, so we need to use a lock when updating the transaction. + self._lock: threading.Lock = threading.Lock() def _make_txn_selector(self): """Helper for :meth:`read` / :meth:`execute_sql`. @@ -317,11 +339,11 @@ def read( for reuse of single-use snapshots, or if a transaction ID is already pending for multiple-use snapshots. """ - if self._read_request_count > 0: + if self._total_read_request_count > 0: if not self._multi_use: raise ValueError("Cannot re-use single-use snapshot.") if self._transaction_id is None and self._read_only: - raise ValueError("Transaction ID pending.") + raise ValueError("Transaction has not begun") database = self._session._database api = database.spanner_api @@ -360,7 +382,7 @@ def read( directed_read_options=directed_read_options, ) - streaming_read = functools.partial( + streaming_read_method = functools.partial( api.streaming_read, request=request, metadata=metadata, @@ -368,26 +390,28 @@ def read( timeout=timeout, ) - trace_attributes = {"table_id": table, "columns": columns} - observability_options = getattr(database, "observability_options", None) - - get_streamed_result_set_args = { - "method": streaming_read, - "request": request, - "metadata": metadata, - "trace_attributes": trace_attributes, - "column_info": column_info, - "observability_options": observability_options, - "lazy_decode": lazy_decode, - } + # If this request begins the transaction, we need to lock + # the transaction until the transaction ID is updated. + is_inline_begin = False if self._transaction_id is None: - # lock is added to handle the inline begin for first rpc - with self._lock: - return self._get_streamed_result_set(**get_streamed_result_set_args) + is_inline_begin = True + self._lock.acquire() - else: - return self._get_streamed_result_set(**get_streamed_result_set_args) + streamed_result_set = self._get_streamed_result_set( + method=streaming_read_method, + request=request, + metadata=metadata, + trace_attributes={"table_id": table, "columns": columns}, + column_info=column_info, + observability_options=getattr(database, "observability_options", None), + lazy_decode=lazy_decode, + ) + + if is_inline_begin: + self._lock.release() + + return streamed_result_set def execute_sql( self, @@ -506,11 +530,11 @@ def execute_sql( for reuse of single-use snapshots, or if a transaction ID is already pending for multiple-use snapshots. """ - if self._read_request_count > 0: + if self._total_read_request_count > 0: if not self._multi_use: raise ValueError("Cannot re-use single-use snapshot.") if self._transaction_id is None and self._read_only: - raise ValueError("Transaction ID pending.") + raise ValueError("Transaction has not begun") if params is not None: params_pb = Struct( @@ -555,7 +579,7 @@ def execute_sql( param_types=param_types, query_mode=query_mode, partition_token=partition, - seqno=self._execute_sql_count, + seqno=self._execute_sql_request_count, query_options=query_options, request_options=request_options, last_statement=last_statement, @@ -570,25 +594,28 @@ def execute_sql( timeout=timeout, ) - trace_attributes = {"db.statement": sql} - observability_options = getattr(database, "observability_options", None) - - get_streamed_result_set_args = { - "method": execute_streaming_sql_method, - "request": request, - "metadata": metadata, - "trace_attributes": trace_attributes, - "column_info": column_info, - "observability_options": observability_options, - "lazy_decode": lazy_decode, - } + # If this request begins the transaction, we need to lock + # the transaction until the transaction ID is updated. + is_inline_begin = False if self._transaction_id is None: - # lock is added to handle the inline begin for first rpc - with self._lock: - return self._get_streamed_result_set(**get_streamed_result_set_args) - else: - return self._get_streamed_result_set(**get_streamed_result_set_args) + is_inline_begin = True + self._lock.acquire() + + streamed_result_set = self._get_streamed_result_set( + method=execute_streaming_sql_method, + request=request, + metadata=metadata, + trace_attributes={"db.statement": sql}, + column_info=column_info, + observability_options=getattr(database, "observability_options", None), + lazy_decode=lazy_decode, + ) + + if is_inline_begin: + self._lock.release() + + return streamed_result_set def partition_read( self, @@ -836,10 +863,10 @@ def _get_streamed_result_set( observability_options=observability_options, ) - self._read_request_count += 1 - if is_execute_sql_request: - self._execute_sql_count += 1 + self._execute_sql_request_count += 1 + + self._total_read_request_count += 1 streamed_result_set_args = { "response_iterator": iterator, @@ -847,11 +874,58 @@ def _get_streamed_result_set( "lazy_decode": lazy_decode, } - if self._multi_use: - streamed_result_set_args["source"] = self - return StreamedResultSet(**streamed_result_set_args) + def _update_for_result_set_pb( + self, result_set_pb: Union[ResultSet, PartialResultSet] + ) -> None: + """Updates the snapshot for the given result set. + + :type result_set_pb: :class:`~google.cloud.spanner_v1.ResultSet` or + :class:`~google.cloud.spanner_v1.PartialResultSet` + :param result_set_pb: The result set to update the snapshot with. + """ + + if result_set_pb.metadata and result_set_pb.metadata.transaction: + self._update_for_transaction_pb(result_set_pb.metadata.transaction) + + if result_set_pb.precommit_token: + self._update_for_precommit_token_pb(result_set_pb.precommit_token) + + def _update_for_transaction_pb(self, transaction_pb: Transaction) -> None: + """Updates the snapshot for the given transaction. + + :type transaction_pb: :class:`~google.cloud.spanner_v1.Transaction` + :param transaction_pb: The transaction to update the snapshot with. + """ + + # The transaction ID should only be updated when the transaction is + # begun: either explicitly with a begin transaction request, or implicitly + # with read, execute SQL, batch update, or execute update requests. The + # caller is responsible for locking until the transaction ID is updated. + if self._transaction_id is None and transaction_pb.id: + self._transaction_id = transaction_pb.id + + if transaction_pb.precommit_token: + self._update_for_precommit_token_pb(transaction_pb.precommit_token) + + def _update_for_precommit_token_pb( + self, precommit_token_pb: MultiplexedSessionPrecommitToken + ) -> None: + """Updates the snapshot for the given multiplexed session precommit token. + + :type precommit_token_pb: :class:`~google.cloud.spanner_v1.MultiplexedSessionPrecommitToken` + :param precommit_token_pb: The multiplexed session precommit token to update the snapshot with. + """ + + # Because multiple threads can be used to perform operations within a + # transaction, we need to use a lock when updating the precommit token. + with self._lock: + if self._precommit_token is None or ( + precommit_token_pb.seq_num > self._precommit_token.seq_num + ): + self._precommit_token = precommit_token_pb + class Snapshot(_SnapshotBase): """Allow a set of reads / SQL statements with shared staleness. @@ -918,6 +992,7 @@ def __init__( self._min_read_timestamp = min_read_timestamp self._max_staleness = max_staleness self._exact_staleness = exact_staleness + self._multi_use = multi_use self._transaction_id = transaction_id @@ -953,7 +1028,7 @@ def _make_txn_selector(self): else: return TransactionSelector(single_use=options) - def begin(self): + def begin(self) -> bytes: """Begin a read-only transaction on the database. :rtype: bytes @@ -968,7 +1043,7 @@ def begin(self): if self._transaction_id is not None: raise ValueError("Read-only transaction already begun") - if self._read_request_count > 0: + if self._total_read_request_count > 0: raise ValueError("Read-only transaction already pending") database = self._session._database @@ -991,10 +1066,22 @@ def begin(self): options=txn_selector.begin, metadata=metadata, ) - response = _retry( + transaction_pb: Transaction = _retry( method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, ) - self._transaction_id = response.id - self._transaction_read_timestamp = response.read_timestamp + + self._update_for_transaction_pb(transaction_pb) return self._transaction_id + + def _update_for_transaction_pb(self, transaction_pb: Transaction) -> None: + """Updates the snapshot for the given transaction. + + :type transaction_pb: :class:`~google.cloud.spanner_v1.Transaction` + :param transaction_pb: The transaction to update the snapshot with. + """ + + super(Snapshot, self)._update_for_transaction_pb(transaction_pb) + + if transaction_pb.read_timestamp is not None: + self._transaction_read_timestamp = transaction_pb.read_timestamp diff --git a/google/cloud/spanner_v1/streamed.py b/google/cloud/spanner_v1/streamed.py index 5de843e103..a39d30b2ed 100644 --- a/google/cloud/spanner_v1/streamed.py +++ b/google/cloud/spanner_v1/streamed.py @@ -13,7 +13,6 @@ # limitations under the License. """Wrapper for streaming results.""" - from google.cloud import exceptions from google.protobuf.struct_pb2 import ListValue from google.protobuf.struct_pb2 import Value @@ -34,7 +33,7 @@ class StreamedResultSet(object): instances. :type source: :class:`~google.cloud.spanner_v1.snapshot.Snapshot` - :param source: Snapshot from which the result set was fetched. + :param source: Deprecated. Snapshot from which the result set was fetched. """ def __init__( @@ -50,7 +49,6 @@ def __init__( self._stats = None # Until set from last PRS self._current_row = [] # Accumulated values for incomplete row self._pending_chunk = None # Incomplete value - self._source = source # Source snapshot self._column_info = column_info # Column information self._field_decoders = None self._lazy_decode = lazy_decode # Return protobuf values @@ -141,11 +139,7 @@ def _consume_next(self): response_pb = PartialResultSet.pb(response) if self._metadata is None: # first response - metadata = self._metadata = response_pb.metadata - - source = self._source - if source is not None and source._transaction_id is None: - source._transaction_id = metadata.transaction.id + self._metadata = response_pb.metadata if response_pb.HasField("stats"): # last response self._stats = response.stats diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 2f52aaa144..c72b940897 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -14,7 +14,6 @@ """Spanner read-write transaction support.""" import functools -import threading from google.protobuf.struct_pb2 import Struct from typing import Optional @@ -27,6 +26,12 @@ _check_rst_stream_error, _merge_Transaction_Options, ) + +from google.cloud.spanner_v1 import ( + Transaction as TransactionProto, + ExecuteBatchDmlResponse, + ResultSet, +) from google.cloud.spanner_v1 import CommitRequest from google.cloud.spanner_v1 import ExecuteBatchDmlRequest from google.cloud.spanner_v1 import ExecuteSqlRequest @@ -52,23 +57,24 @@ class Transaction(_SnapshotBase, _BatchBase): :raises ValueError: if session has an existing transaction """ - committed = None - """Timestamp at which the transaction was successfully committed.""" - rolled_back = False - commit_stats = None - _multi_use = True - _execute_sql_count = 0 - _lock = threading.Lock() - _read_only = False exclude_txn_from_change_streams = False isolation_level = TransactionOptions.IsolationLevel.ISOLATION_LEVEL_UNSPECIFIED + # Override defaults from _SnapshotBase. + _multi_use = True + _read_only = False + def __init__(self, session): if session._transaction is not None: raise ValueError("Session has existing transaction.") super(Transaction, self).__init__(session) + self.committed = None + """Timestamp at which the transaction was successfully committed.""" + self.rolled_back = False + self.commit_stats = None + def _check_state(self): """Helper for :meth:`commit` et al. @@ -141,7 +147,7 @@ def _execute_request( return response - def begin(self): + def begin(self) -> bytes: """Begin a transaction on the database. :rtype: bytes @@ -188,19 +194,20 @@ def begin(self): metadata=metadata, ) - def beforeNextRetry(nthRetry, delayInSeconds): + def before_next_retry(nth_retry, delay_in_seconds): add_span_event( span, "Transaction Begin Attempt Failed. Retrying", - {"attempt": nthRetry, "sleep_seconds": delayInSeconds}, + {"attempt": nth_retry, "sleep_seconds": delay_in_seconds}, ) - response = _retry( + transaction_pb: TransactionProto = _retry( method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, - beforeNextRetry=beforeNextRetry, + beforeNextRetry=before_next_retry, ) - self._transaction_id = response.id + + self._update_for_transaction_pb(transaction_pb) return self._transaction_id def rollback(self): @@ -302,6 +309,7 @@ def commit( return_commit_stats=return_commit_stats, max_commit_delay=max_commit_delay, request_options=request_options, + precommit_token=self._precommit_token, ) add_span_event(span, "Starting Commit") @@ -312,21 +320,25 @@ def commit( metadata=metadata, ) - def beforeNextRetry(nthRetry, delayInSeconds): + def before_next_retry(nth_retry, delay_in_seconds): add_span_event( span, "Transaction Commit Attempt Failed. Retrying", - {"attempt": nthRetry, "sleep_seconds": delayInSeconds}, + {"attempt": nth_retry, "sleep_seconds": delay_in_seconds}, ) response = _retry( method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, - beforeNextRetry=beforeNextRetry, + beforeNextRetry=before_next_retry, ) add_span_event(span, "Commit Done") + # TODO multiplexed + # Retry commit if the response contains a MultiplexedSessionRetry entry. + # Will require refactoring the commit method to the _BatchBase class. + self.committed = response.commit_timestamp if return_commit_stats: self.commit_stats = response.commit_stats @@ -436,9 +448,9 @@ def execute_update( ) api = database.spanner_api - seqno, self._execute_sql_count = ( - self._execute_sql_count, - self._execute_sql_count + 1, + execute_sql_request_count, self._execute_sql_request_count = ( + self._execute_sql_request_count, + self._execute_sql_request_count + 1, ) # Query-level options have higher precedence than client-level and @@ -457,14 +469,24 @@ def execute_update( trace_attributes = {"db.statement": dml} + # If this request begins the transaction, we need to lock + # the transaction until the transaction ID is updated. + is_inline_begin = False + + if self._transaction_id is None: + is_inline_begin = True + self._lock.acquire() + + transaction_selector_pb = self._make_txn_selector() request = ExecuteSqlRequest( session=self._session.name, + transaction=transaction_selector_pb, sql=dml, params=params_pb, param_types=param_types, query_mode=query_mode, query_options=query_options, - seqno=seqno, + seqno=execute_sql_request_count, request_options=request_options, last_statement=last_statement, ) @@ -477,38 +499,22 @@ def execute_update( timeout=timeout, ) - if self._transaction_id is None: - # lock is added to handle the inline begin for first rpc - with self._lock: - response = self._execute_request( - method, - request, - metadata, - f"CloudSpanner.{type(self).__name__}.execute_update", - self._session, - trace_attributes, - observability_options=observability_options, - ) - # Setting the transaction id because the transaction begin was inlined for first rpc. - if ( - self._transaction_id is None - and response is not None - and response.metadata is not None - and response.metadata.transaction is not None - ): - self._transaction_id = response.metadata.transaction.id - else: - response = self._execute_request( - method, - request, - metadata, - f"CloudSpanner.{type(self).__name__}.execute_update", - self._session, - trace_attributes, - observability_options=observability_options, - ) + result_set_pb: ResultSet = self._execute_request( + method, + request, + metadata, + f"CloudSpanner.{type(self).__name__}.execute_update", + self._session, + trace_attributes, + observability_options=observability_options, + ) + + self._update_for_result_set_pb(result_set_pb) - return response.stats.row_count_exact + if is_inline_begin: + self._lock.release() + + return result_set_pb.stats.row_count_exact def batch_update( self, @@ -588,9 +594,9 @@ def batch_update( api = database.spanner_api observability_options = getattr(database, "observability_options", None) - seqno, self._execute_sql_count = ( - self._execute_sql_count, - self._execute_sql_count + 1, + execute_sql_request_count, self._execute_sql_request_count = ( + self._execute_sql_request_count, + self._execute_sql_request_count + 1, ) if request_options is None: @@ -603,10 +609,21 @@ def batch_update( # Get just the queries from the DML statement batch "db.statement": ";".join([statement.sql for statement in parsed]) } + + # If this request begins the transaction, we need to lock + # the transaction until the transaction ID is updated. + is_inline_begin = False + + if self._transaction_id is None: + is_inline_begin = True + self._lock.acquire() + + transaction_selector_pb = self._make_txn_selector() request = ExecuteBatchDmlRequest( session=self._session.name, + transaction=transaction_selector_pb, statements=parsed, - seqno=seqno, + seqno=execute_sql_request_count, request_options=request_options, last_statements=last_statement, ) @@ -619,43 +636,41 @@ def batch_update( timeout=timeout, ) - if self._transaction_id is None: - # lock is added to handle the inline begin for first rpc - with self._lock: - response = self._execute_request( - method, - request, - metadata, - "CloudSpanner.DMLTransaction", - self._session, - trace_attributes, - observability_options=observability_options, - ) - # Setting the transaction id because the transaction begin was inlined for first rpc. - for result_set in response.result_sets: - if ( - self._transaction_id is None - and result_set.metadata is not None - and result_set.metadata.transaction is not None - ): - self._transaction_id = result_set.metadata.transaction.id - break - else: - response = self._execute_request( - method, - request, - metadata, - "CloudSpanner.DMLTransaction", - self._session, - trace_attributes, - observability_options=observability_options, - ) + response_pb: ExecuteBatchDmlResponse = self._execute_request( + method, + request, + metadata, + "CloudSpanner.DMLTransaction", + self._session, + trace_attributes, + observability_options=observability_options, + ) + + self._update_for_execute_batch_dml_response_pb(response_pb) + + if is_inline_begin: + self._lock.release() row_counts = [ - result_set.stats.row_count_exact for result_set in response.result_sets + result_set.stats.row_count_exact for result_set in response_pb.result_sets ] - return response.status, row_counts + return response_pb.status, row_counts + + def _update_for_execute_batch_dml_response_pb( + self, response_pb: ExecuteBatchDmlResponse + ) -> None: + """Update the transaction for the given execute batch DML response. + + :type response_pb: :class:`~google.cloud.spanner_v1.types.ExecuteBatchDmlResponse` + :param response_pb: The execute batch DML response to update the transaction with. + """ + if response_pb.precommit_token: + self._update_for_precommit_token_pb(response_pb.precommit_token) + + # Only the first result set contains the result set metadata. + if len(response_pb.result_sets) > 0: + self._update_for_result_set_pb(response_pb.result_sets[0]) def __enter__(self): """Begin ``with`` block.""" diff --git a/tests/_builders.py b/tests/_builders.py index c07d003c19..37325a872c 100644 --- a/tests/_builders.py +++ b/tests/_builders.py @@ -12,60 +12,162 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Default identifiers. -PROJECT_ID = "project-id" -INSTANCE_ID = "instance-id" -DATABASE_ID = "database-id" -SESSION_ID = "session-id" -TRANSACTION_ID = b"transaction-id" +# Default identifiers and names. These are used to populate required +# attributes, but tests should not depend on them. If a test requires +# a specific identifier or name, it should set it explicitly. +_PROJECT_ID = "default-project-id" +_INSTANCE_ID = "default-instance-id" +_DATABASE_ID = "default-database-id" +_SESSION_ID = "default-session-id" +_TRANSACTION_ID = b"default-transaction-id" +_PRECOMMIT_TOKEN = b"default-precommit-token" +_PRECOMMIT_SEQ_NUM = -1 + +_PROJECT_NAME = "projects/" + _PROJECT_ID +_INSTANCE_NAME = _PROJECT_NAME + "/instances/" + _INSTANCE_ID +_DATABASE_NAME = _INSTANCE_NAME + "/databases/" + _DATABASE_ID +_SESSION_NAME = _DATABASE_NAME + "/sessions/" + _SESSION_ID + +# Protocol buffers +# ---------------- + + +def build_precommit_token_pb(**kwargs): + """Builds and returns a precommit token protocol buffer using the given arguments. + If a required argument is not provided, a default value will be used.""" + from google.cloud.spanner_v1.types import MultiplexedSessionPrecommitToken + + return MultiplexedSessionPrecommitToken(**kwargs) + + +def build_session_pb(**kwargs): + """Builds and returns a session protocol buffer using the given arguments. + If a required argument is not provided, a default value will be used.""" + from google.cloud.spanner_v1.types import Session + + if "name" not in kwargs: + kwargs["name"] = _SESSION_NAME + + return Session(**kwargs) + + +def build_result_set_pb(**kwargs): + """Builds and returns a result set protocol buffer using the given arguments. + If a required argument is not provided, a default value will be used.""" + from google.cloud.spanner_v1.types import ResultSet -# Default names. -INSTANCE_NAME = "projects/" + PROJECT_ID + "/instances/" + INSTANCE_ID -DATABASE_NAME = INSTANCE_NAME + "/databases/" + DATABASE_ID -SESSION_NAME = DATABASE_NAME + "/sessions/" + SESSION_ID + if "metadata" not in kwargs or isinstance(kwargs["metadata"], dict): + metadata_args = kwargs.pop("metadata", {}) + kwargs["metadata"] = build_result_set_metadata_pb(**metadata_args) + if "precommit_token" not in kwargs or isinstance(kwargs["precommit_token"], dict): + precommit_token_args = kwargs.pop("precommit_token", {}) + kwargs["precommit_token"] = build_precommit_token_pb(**precommit_token_args) + + return ResultSet(**kwargs) + + +def build_result_set_metadata_pb(**kwargs): + """Builds and returns a result set metadata protocol buffer using the given arguments. + If a required argument is not provided, a default value will be used.""" + from google.cloud.spanner_v1.types import ResultSetMetadata + + if "transaction" not in kwargs or isinstance(kwargs["transaction"], dict): + transaction_args = kwargs.pop("transaction", {}) + kwargs["transaction"] = build_transaction_pb(**transaction_args) + + return ResultSetMetadata(**kwargs) + + +def build_transaction_pb(**kwargs): + """Builds and returns a transaction protocol buffer using the given arguments. + If a required argument is not provided, a default value will be used.""" + from google.cloud.spanner_v1.types import Transaction -def build_database(**database_kwargs): + if "id" not in kwargs: + kwargs["id"] = _TRANSACTION_ID + + return Transaction(**kwargs) + + +# Client objects +# -------------- + + +def build_database(**kwargs): """Builds and returns a database for testing using the given arguments. If a required argument is not provided, a default value will be used.""" - from google.cloud.spanner_v1 import Session as SessionProto from google.cloud.spanner_v1 import SpannerClient - from google.cloud.spanner_v1 import Transaction as TransactionProto from google.cloud.spanner_v1.database import Database from mock.mock import create_autospec - instance = _build_instance() - database = Database(database_id=DATABASE_ID, instance=instance) + if "database_id" not in kwargs: + kwargs["database_id"] = _DATABASE_ID + + if "instance" not in kwargs or isinstance(kwargs["instance"], dict): + instance_args = kwargs.pop("instance", {}) + kwargs["instance"] = _build_instance(**instance_args) + + database = Database(**kwargs) # Mock API calls. Callers can override this to test specific behaviours. api = database._spanner_api = create_autospec(SpannerClient, instance=True) - api.create_session.return_value = SessionProto(name=SESSION_NAME) - api.begin_transaction.return_value = TransactionProto(id=TRANSACTION_ID) + api.create_session.return_value = build_session_pb() + api.begin_transaction.return_value = build_transaction_pb() return database -def build_session(**session_kwargs): +def build_session(**kwargs): """Builds and returns a session for testing using the given arguments. If a required argument is not provided, a default value will be used.""" from google.cloud.spanner_v1.session import Session - if not session_kwargs.get("database"): - session_kwargs["database"] = build_database() + if "database" not in kwargs or isinstance(kwargs["database"], dict): + database_args = kwargs.pop("database", {}) + kwargs["database"] = build_database(**database_args) + + return Session(**kwargs) + - return Session(**session_kwargs) +def build_transaction(**kwargs): + """Builds and returns a transaction for testing using the given arguments. + If a required argument is not provided, a default value will be used.""" + + session = ( + build_session(**kwargs.pop("session", {})) + if "session" not in kwargs + else kwargs["session"] + ) + + # Session must be created before building transaction. + if session.session_id is None: + session.create() + return session.transaction() -def _build_client(): - """Builds and returns a client for testing.""" + +def _build_client(**kwargs): + """Builds and returns a client for testing using the given arguments. + If a required argument is not provided, a default value will be used.""" from google.cloud.spanner_v1 import Client - return Client(project=PROJECT_ID) + if "project" not in kwargs: + kwargs["project"] = _PROJECT_ID + return Client(**kwargs) -def _build_instance(**instance_kwargs): - """Builds and returns an instance for testing.""" + +def _build_instance(**kwargs): + """Builds and returns an instance for testing using the given arguments. + If a required argument is not provided, a default value will be used.""" from google.cloud.spanner_v1.instance import Instance - client = _build_client() - return Instance(instance_id=INSTANCE_ID, client=client) + if "instance_id" not in kwargs: + kwargs["instance_id"] = _INSTANCE_ID + + if "client" not in kwargs or isinstance(kwargs["client"], dict): + client_args = kwargs.pop("client", {}) + kwargs["client"] = _build_client(**client_args) + + return Instance(**kwargs) diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 56d6223b3e..14d3a3948d 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -31,10 +31,15 @@ DirectedReadOptions, RequestOptions, ) +from google.cloud.spanner_v1.session import Session from google.cloud.spanner_v1.session_options import TransactionType from google.cloud.spanner_v1.snapshot import Snapshot -from tests._builders import build_database, build_session, TRANSACTION_ID, SESSION_ID -from tests._helpers import enable_multiplexed_sessions +from tests._builders import ( + build_database, + build_session, + build_session_pb, + build_transaction_pb, +) DML_WO_PARAM = """ DELETE FROM citizens @@ -82,6 +87,7 @@ class _BaseTest(unittest.TestCase): SESSION_NAME = DATABASE_NAME + "/sessions/" + SESSION_ID BACKUP_ID = "backup_id" BACKUP_NAME = INSTANCE_NAME + "/backups/" + BACKUP_ID + TRANSACTION_ID = b"transaction-id" TRANSACTION_TAG = "transaction-tag" DATABASE_ROLE = "dummy-role" @@ -1273,8 +1279,6 @@ def test_execute_partitioned_dml_w_exclude_txn_from_change_streams(self): ) def test_execute_partitioned_dml_not_implemented_error_multiplexed(self): - enable_multiplexed_sessions() - database = build_database() database.spanner_api.begin_transaction.side_effect = NotImplementedError( "Transaction type partitioned_dml not supported with multiplexed sessions" @@ -1286,6 +1290,25 @@ def test_execute_partitioned_dml_not_implemented_error_multiplexed(self): session_options = database.session_options self.assertFalse(session_options.use_multiplexed(TransactionType.PARTITIONED)) + def test_session_factory_defaults(self): + database = build_database() + session = database.session() + + self.assertIsInstance(session, Session) + self.assertIs(session.session_id, None) + self.assertIs(session._database, database) + self.assertEqual(session.labels, {}) + + def test_session_factory_w_labels(self): + database = build_database() + labels = {"foo": "bar"} + session = database.session(labels=labels) + + self.assertIsInstance(session, Session) + self.assertIs(session.session_id, None) + self.assertIs(session._database, database) + self.assertEqual(session.labels, labels) + def test_snapshot_defaults(self): from google.cloud.spanner_v1.database import SnapshotCheckout @@ -2151,77 +2174,96 @@ def test_from_dict(self): batch_txn = BatchSnapshot.from_dict( database, { - "transaction_id": TRANSACTION_ID, - "session_id": SESSION_ID, + "transaction_id": self.TRANSACTION_ID, + "session_id": self.SESSION_ID, }, ) self.assertIs(batch_txn._database, database) - self.assertIs(batch_txn._transaction_id, TRANSACTION_ID) - self.assertIs(batch_txn._session_id, SESSION_ID) + self.assertIs(batch_txn._transaction_id, self.TRANSACTION_ID) + self.assertIs(batch_txn._session_id, self.SESSION_ID) session = batch_txn._get_session() - self.assertEqual(session._session_id, SESSION_ID) + self.assertEqual(session._session_id, self.SESSION_ID) snapshot = batch_txn._get_snapshot() - self.assertEqual(snapshot._transaction_id, TRANSACTION_ID) + self.assertEqual(snapshot._transaction_id, self.TRANSACTION_ID) database.spanner_api.begin_transaction.assert_not_called() def test_to_dict(self): database = build_database() batch_txn = BatchSnapshot(database) + api = database.spanner_api + api.create_session.return_value = build_session_pb(name=self.SESSION_NAME) + api.begin_transaction.return_value = build_transaction_pb( + id=self.TRANSACTION_ID + ) + self.assertEqual( batch_txn.to_dict(), { - "transaction_id": TRANSACTION_ID, - "session_id": SESSION_ID, + "transaction_id": self.TRANSACTION_ID, + "session_id": self.SESSION_ID, }, ) def test__get_snapshot_already(self): database = build_database() batch_txn = BatchSnapshot(database) - snapshot_1 = batch_txn._get_snapshot() + snapshot_1 = batch_txn._get_snapshot() snapshot_2 = batch_txn._get_snapshot() + self.assertEqual(snapshot_1, snapshot_2) database.spanner_api.begin_transaction.assert_called_once() def test__get_snapshot_new_wo_staleness(self): database = build_database() batch_txn = BatchSnapshot(database) + + begin_transaction = database.spanner_api.begin_transaction + begin_transaction.return_value = build_transaction_pb(id=self.TRANSACTION_ID) + snapshot = batch_txn._get_snapshot() self.assertIsNone(snapshot._read_timestamp) self.assertIsNone(snapshot._exact_staleness) self.assertTrue(snapshot._multi_use) - self.assertEqual(TRANSACTION_ID, snapshot._transaction_id) - database.spanner_api.begin_transaction.assert_called_once() + self.assertEqual(self.TRANSACTION_ID, snapshot._transaction_id) + begin_transaction.assert_called_once() def test__get_snapshot_w_read_timestamp(self): database = build_database() timestamp = self._make_timestamp() batch_txn = BatchSnapshot(database, read_timestamp=timestamp) + + begin_transaction = database.spanner_api.begin_transaction + begin_transaction.return_value = build_transaction_pb(id=self.TRANSACTION_ID) + snapshot = batch_txn._get_snapshot() self.assertEqual(timestamp, snapshot._read_timestamp) self.assertIsNone(snapshot._exact_staleness) self.assertTrue(snapshot._multi_use) - self.assertEqual(TRANSACTION_ID, snapshot._transaction_id) - database.spanner_api.begin_transaction.assert_called_once() + self.assertEqual(snapshot._transaction_id, self.TRANSACTION_ID) + begin_transaction.assert_called_once() def test__get_snapshot_w_exact_staleness(self): database = build_database() duration = self._make_duration() batch_txn = BatchSnapshot(database, exact_staleness=duration) + + begin_transaction = database.spanner_api.begin_transaction + begin_transaction.return_value = build_transaction_pb(id=self.TRANSACTION_ID) + snapshot = batch_txn._get_snapshot() self.assertIsNone(snapshot._read_timestamp) self.assertEqual(duration, snapshot._exact_staleness) self.assertTrue(snapshot._multi_use) - self.assertEqual(TRANSACTION_ID, snapshot._transaction_id) - database.spanner_api.begin_transaction.assert_called_once() + self.assertEqual(snapshot._transaction_id, self.TRANSACTION_ID) + begin_transaction.assert_called_once() def test_read(self): keyset = self._make_keyset() diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py index bb2695553b..7b25963727 100644 --- a/tests/unit/test_metrics.py +++ b/tests/unit/test_metrics.py @@ -24,7 +24,7 @@ from opentelemetry import metrics pytest.importorskip("opentelemetry") -# Skip if semconv attributes are not present, as tracing wont' be enabled either +# Skip if semconv attributes are not present, as tracing won't be enabled either # pytest.importorskip("opentelemetry.semconv.attributes.otel_attributes") diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 0e900acc37..b3afad99e4 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from google.api_core import gapic_v1 +from google.api_core.retry import Retry import mock from google.cloud.spanner_v1 import ( @@ -20,18 +22,20 @@ SpannerClient, KeySet, ) +from google.cloud.spanner_v1.database import Database from google.cloud.spanner_v1.session_options import TransactionType +from tests._builders import ( + build_precommit_token_pb, + build_result_set_metadata_pb, + build_session, +) from tests._helpers import ( OpenTelemetryBase, LIB_VERSION, StatusCode, HAS_OPENTELEMETRY_INSTALLED, - enrich_with_otel_scope, - enable_multiplexed_sessions, ) from google.cloud.spanner_v1.param_types import INT64 -from google.api_core.retry import Retry -from tests._builders import build_session TABLE_NAME = "citizens" COLUMNS = ["email", "first_name", "last_name", "age"] @@ -45,19 +49,8 @@ SELECT image_name FROM images WHERE @bytes IN image_data""" PARAMS_WITH_BYTES = {"bytes": b"FACEDACE"} RESUME_TOKEN = b"DEADBEEF" -TXN_ID = b"DEAFBEAD" SECONDS = 3 MICROS = 123456 -BASE_ATTRIBUTES = { - "db.type": "spanner", - "db.url": "spanner.googleapis.com", - "db.instance": "testing", - "net.host.name": "spanner.googleapis.com", - "gcp.client.service": "spanner", - "gcp.client.version": LIB_VERSION, - "gcp.client.repo": "googleapis/python-spanner", -} -enrich_with_otel_scope(BASE_ATTRIBUTES) DIRECTED_READ_OPTIONS = { "include_replicas": { @@ -80,6 +73,11 @@ }, } +TRANSACTION_ID = b"transaction-id" + +PRECOMMIT_TOKEN_1 = build_precommit_token_pb(precommit_token=b"1", seq_num=1) +PRECOMMIT_TOKEN_2 = build_precommit_token_pb(precommit_token=b"2", seq_num=2) + def _makeTimestamp(): import datetime @@ -96,9 +94,6 @@ def _getTargetClass(self): def _makeDerived(self, session): class _Derived(self._getTargetClass()): - _transaction_id = None - _multi_use = False - def _make_txn_selector(self): from google.cloud.spanner_v1 import ( TransactionOptions, @@ -146,7 +141,8 @@ def _make_item(self, value, resume_token=b"", metadata=None): value=value, resume_token=resume_token, metadata=metadata, - spec=["value", "resume_token", "metadata"], + precommit_token=build_precommit_token_pb(), + spec=["value", "resume_token", "metadata", "precommit_token"], ) def test_iteration_w_empty_raw(self): @@ -299,7 +295,7 @@ def test_iteration_w_raw_raising_retryable_internal_error(self): fail_after=True, error=InternalServerError( "Received unexpected EOS on DATA frame from server" - ) + ), ) after = _MockIterator(*LAST) request = mock.Mock(test="test", spec=["test", "resume_token"]) @@ -418,15 +414,9 @@ def test_iteration_w_raw_raising_unavailable_w_multiuse(self): def test_iteration_w_raw_raising_unavailable_after_token_w_multiuse(self): from google.api_core.exceptions import ServiceUnavailable + from google.cloud.spanner_v1 import ReadRequest - from google.cloud.spanner_v1 import ResultSetMetadata - from google.cloud.spanner_v1 import ( - Transaction as TransactionPB, - ReadRequest, - ) - - transaction_pb = TransactionPB(id=TXN_ID) - metadata_pb = ResultSetMetadata(transaction=transaction_pb) + metadata_pb = build_result_set_metadata_pb(transaction={"id": TRANSACTION_ID}) FIRST = ( self._make_item(0), self._make_item(1, resume_token=RESUME_TOKEN, metadata=metadata_pb), @@ -449,11 +439,13 @@ def test_iteration_w_raw_raising_unavailable_after_token_w_multiuse(self): self.assertEqual(list(resumable), list(FIRST + SECOND)) self.assertEqual(len(restart.mock_calls), 2) self.assertEqual(request.resume_token, RESUME_TOKEN) + + transaction_id_string = TRANSACTION_ID.decode("utf-8") transaction_id_selector_count = sum( [ 1 for args in restart.call_args_list - if 'id: "DEAFBEAD"' in args.kwargs.__str__() + if f'id: "{transaction_id_string}"' in args.kwargs.__str__() ] ) @@ -471,7 +463,7 @@ def test_iteration_w_raw_raising_retryable_internal_error_after_token(self): fail_after=True, error=InternalServerError( "Received unexpected EOS on DATA frame from server" - ) + ), ) after = _MockIterator(*SECOND) request = mock.Mock(test="test", spec=["test", "resume_token"]) @@ -521,7 +513,9 @@ def test_iteration_w_span_creation(self): derived, restart, request, name, _Session(_Database()), extra_atts ) self.assertEqual(list(resumable), []) - self.assertSpanAttributes(name, attributes=dict(BASE_ATTRIBUTES, test_att=1)) + self.assertSpanAttributes( + name, attributes=dict(_build_base_attributes(database), test_att=1) + ) def test_iteration_w_multiple_span_creation(self): from google.api_core.exceptions import ServiceUnavailable @@ -554,18 +548,12 @@ def test_iteration_w_multiple_span_creation(self): self.assertEqual(span.name, name) self.assertEqual( dict(span.attributes), - enrich_with_otel_scope(BASE_ATTRIBUTES), + _build_base_attributes(database), ) class Test_SnapshotBase(OpenTelemetryBase): - PROJECT_ID = "project-id" - INSTANCE_ID = "instance-id" - INSTANCE_NAME = "projects/" + PROJECT_ID + "/instances/" + INSTANCE_ID - DATABASE_ID = "database-id" - DATABASE_NAME = INSTANCE_NAME + "/databases/" + DATABASE_ID SESSION_ID = "session-id" - SESSION_NAME = DATABASE_NAME + "/sessions/" + SESSION_ID def _getTargetClass(self): from google.cloud.spanner_v1.snapshot import _SnapshotBase @@ -604,7 +592,8 @@ def test_ctor(self): session = _Session() base = self._make_one(session) self.assertIs(base._session, session) - self.assertEqual(base._execute_sql_count, 0) + self.assertEqual(base._execute_sql_request_count, 0) + self.assertEqual(base._total_read_request_count, 0) self.assertNoSpans() @@ -615,8 +604,6 @@ def test__make_txn_selector_virtual(self): base._make_txn_selector() def test_read_partitioned_not_implemented_for_multiplexed(self): - enable_multiplexed_sessions() - database = ( self._build_database_with_partitioned_not_implemented_for_multiplexed() ) @@ -647,7 +634,9 @@ def test_read_other_error(self): "CloudSpanner._Derived.read", status=StatusCode.ERROR, attributes=dict( - BASE_ATTRIBUTES, table_id=TABLE_NAME, columns=tuple(COLUMNS) + _build_base_attributes(database), + table_id=TABLE_NAME, + columns=tuple(COLUMNS), ), ) @@ -662,11 +651,11 @@ def _read_helper( request_options=None, directed_read_options=None, directed_read_options_at_client_level=None, + use_multiplexed=False, ): from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1 import ( PartialResultSet, - ResultSetMetadata, ResultSetStats, ) from google.cloud.spanner_v1 import ( @@ -678,6 +667,31 @@ def _read_helper( from google.cloud.spanner_v1 import TypeCode from google.cloud.spanner_v1._helpers import _make_value_pb + # [A] Build derived + # ----------------- + + session = build_session( + database={ + "instance": { + "client": { + "directed_read_options": directed_read_options_at_client_level + } + } + } + ) + + session._session_id = self.SESSION_ID + + derived = self._makeDerived(session) + derived._multi_use = multi_use + derived._total_read_request_count = count + + if not first: + derived._transaction_id = TRANSACTION_ID + + # [B] Build results + # ----------------- + VALUES = [["bharney", 31], ["phred", 32]] VALUE_PBS = [[_make_value_pb(item) for item in row] for row in VALUES] struct_type_pb = StructType( @@ -686,31 +700,46 @@ def _read_helper( StructType.Field(name="age", type_=Type(code=TypeCode.INT64)), ] ) - metadata_pb = ResultSetMetadata(row_type=struct_type_pb) + + # If the transaction had not already begun, the first result + # set will include metadata with information about the transaction. + metadata_pb = build_result_set_metadata_pb( + row_type=struct_type_pb, + transaction={"id": TRANSACTION_ID} if first else None, + ) + stats_pb = ResultSetStats( query_stats=Struct(fields={"rows_returned": _make_value_pb(2)}) ) - result_sets = [ - PartialResultSet(metadata=metadata_pb), - PartialResultSet(stats=stats_pb), - ] + + # Precommit tokens will be included in the result sets if the transaction is on + # a multiplexed session. Precommit tokens may be returned out of order. + partial_result_set_1_args = {"metadata": metadata_pb} + if use_multiplexed: + partial_result_set_1_args["precommit_token"] = PRECOMMIT_TOKEN_2 + partial_result_set_1 = PartialResultSet(**partial_result_set_1_args) + + partial_result_set_2_args = {"stats": stats_pb} + if use_multiplexed: + partial_result_set_2_args["precommit_token"] = PRECOMMIT_TOKEN_1 + partial_result_set_2 = PartialResultSet(**partial_result_set_2_args) + + result_sets = [partial_result_set_1, partial_result_set_2] + for i in range(len(result_sets)): result_sets[i].values.extend(VALUE_PBS[i]) + + database = session._database + api = database.spanner_api + api.streaming_read.return_value = _MockIterator(*result_sets) + + # [C] Execute read + # ---------------- + KEYS = [["bharney@example.com"], ["phred@example.com"]] keyset = KeySet(keys=KEYS) INDEX = "email-address-index" LIMIT = 20 - database = _Database( - directed_read_options=directed_read_options_at_client_level - ) - api = database.spanner_api = self._make_spanner_api() - api.streaming_read.return_value = _MockIterator(*result_sets) - session = _Session(database) - derived = self._makeDerived(session) - derived._multi_use = multi_use - derived._read_request_count = count - if not first: - derived._transaction_id = TXN_ID if request_options is None: request_options = RequestOptions() @@ -742,12 +771,10 @@ def _read_helper( directed_read_options=directed_read_options, ) - self.assertEqual(derived._read_request_count, count + 1) + # [D] Verify results + # ------------------ - if multi_use: - self.assertIs(result_set._source, derived) - else: - self.assertIsNone(result_set._source) + self.assertEqual(derived._total_read_request_count, count + 1) self.assertEqual(list(result_set), VALUES) self.assertEqual(result_set.metadata, metadata_pb) @@ -759,11 +786,17 @@ def _read_helper( if multi_use: if first: - expected_transaction = TransactionSelector(begin=txn_options) + expected_transaction_selector_pb = TransactionSelector( + begin=txn_options + ) else: - expected_transaction = TransactionSelector(id=TXN_ID) + expected_transaction_selector_pb = TransactionSelector( + id=TRANSACTION_ID + ) else: - expected_transaction = TransactionSelector(single_use=txn_options) + expected_transaction_selector_pb = TransactionSelector( + single_use=txn_options + ) if partition is not None: expected_limit = 0 @@ -781,11 +814,11 @@ def _read_helper( ) expected_request = ReadRequest( - session=self.SESSION_NAME, + session=session.name, table=TABLE_NAME, columns=COLUMNS, key_set=keyset._to_pb(), - transaction=expected_transaction, + transaction=expected_transaction_selector_pb, index=INDEX, limit=expected_limit, partition_token=partition, @@ -802,10 +835,18 @@ def _read_helper( self.assertSpanAttributes( "CloudSpanner._Derived.read", attributes=dict( - BASE_ATTRIBUTES, table_id=TABLE_NAME, columns=tuple(COLUMNS) + _build_base_attributes(database), + table_id=TABLE_NAME, + columns=tuple(COLUMNS), ), ) + if first: + self.assertEqual(derived._transaction_id, TRANSACTION_ID) + + if use_multiplexed: + self.assertEqual(derived._precommit_token, PRECOMMIT_TOKEN_2) + def test_read_wo_multi_use(self): self._read_helper(multi_use=False) @@ -882,9 +923,13 @@ def test_read_w_directed_read_options_override(self): directed_read_options_at_client_level=DIRECTED_READ_OPTIONS_FOR_CLIENT, ) - def test_execute_sql_partitioned_not_implemented_for_multiplexed(self): - enable_multiplexed_sessions() + def test_read_w_multi_use_w_first(self): + self._read_helper(first=True, multi_use=True) + def test_read_w_precommit_tokens(self): + self._read_helper(multi_use=True, use_multiplexed=True) + + def test_execute_sql_partitioned_not_implemented_for_multiplexed(self): database = ( self._build_database_with_partitioned_not_implemented_for_multiplexed() ) @@ -910,12 +955,14 @@ def test_execute_sql_other_error(self): with self.assertRaises(RuntimeError): list(derived.execute_sql(SQL_QUERY)) - self.assertEqual(derived._execute_sql_count, 1) + self.assertEqual(derived._execute_sql_request_count, 1) self.assertSpanAttributes( "CloudSpanner._Derived.execute_sql", status=StatusCode.ERROR, - attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY}), + attributes=dict( + _build_base_attributes(database), **{"db.statement": SQL_QUERY} + ), ) def _execute_sql_helper( @@ -931,11 +978,11 @@ def _execute_sql_helper( retry=gapic_v1.method.DEFAULT, directed_read_options=None, directed_read_options_at_client_level=None, + use_multiplexed=False, ): from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1 import ( PartialResultSet, - ResultSetMetadata, ResultSetStats, ) from google.cloud.spanner_v1 import ( @@ -950,6 +997,32 @@ def _execute_sql_helper( _merge_query_options, ) + # [A] Build derived + # ----------------- + + session = build_session( + database={ + "instance": { + "client": { + "directed_read_options": directed_read_options_at_client_level + } + } + } + ) + + session._session_id = self.SESSION_ID + + derived = self._makeDerived(session) + derived._multi_use = multi_use + derived._total_read_request_count = count + derived._execute_sql_request_count = sql_count + + if not first: + derived._transaction_id = TRANSACTION_ID + + # [B] Build results + # ----------------- + VALUES = [["bharney", "rhubbyl", 31], ["phred", "phlyntstone", 32]] VALUE_PBS = [[_make_value_pb(item) for item in row] for row in VALUES] MODE = 2 # PROFILE @@ -960,29 +1033,42 @@ def _execute_sql_helper( StructType.Field(name="age", type_=Type(code=TypeCode.INT64)), ] ) - metadata_pb = ResultSetMetadata(row_type=struct_type_pb) + + # If the transaction has not already begun, the first result set will + # include metadata with information about the newly-begun transaction. + metadata_pb = build_result_set_metadata_pb( + row_type=struct_type_pb, + transaction={"id": TRANSACTION_ID} if first else None, + ) + stats_pb = ResultSetStats( query_stats=Struct(fields={"rows_returned": _make_value_pb(2)}) ) - result_sets = [ - PartialResultSet(metadata=metadata_pb), - PartialResultSet(stats=stats_pb), - ] + + # Precommit tokens will be included in the result sets if the transaction is on + # a multiplexed session. Return the precommit tokens out of order to verify that + # the transaction tracks the one with the highest sequence number. + partial_result_set_1_args = {"metadata": metadata_pb} + if use_multiplexed: + partial_result_set_1_args["precommit_token"] = PRECOMMIT_TOKEN_2 + partial_result_set_1 = PartialResultSet(**partial_result_set_1_args) + + partial_result_set_2_args = {"stats": stats_pb} + if use_multiplexed: + partial_result_set_2_args["precommit_token"] = PRECOMMIT_TOKEN_1 + partial_result_set_2 = PartialResultSet(**partial_result_set_2_args) + + result_sets = [partial_result_set_1, partial_result_set_2] + for i in range(len(result_sets)): result_sets[i].values.extend(VALUE_PBS[i]) - iterator = _MockIterator(*result_sets) - database = _Database( - directed_read_options=directed_read_options_at_client_level - ) - api = database.spanner_api = self._make_spanner_api() - api.execute_streaming_sql.return_value = iterator - session = _Session(database) - derived = self._makeDerived(session) - derived._multi_use = multi_use - derived._read_request_count = count - derived._execute_sql_count = sql_count - if not first: - derived._transaction_id = TXN_ID + + database = session._database + api = database.spanner_api + api.execute_streaming_sql.return_value = _MockIterator(*result_sets) + + # [C] Execute SQL + # --------------- if request_options is None: request_options = RequestOptions() @@ -1002,12 +1088,11 @@ def _execute_sql_helper( directed_read_options=directed_read_options, ) - self.assertEqual(derived._read_request_count, count + 1) + # [D] Verify results + # ------------------ - if multi_use: - self.assertIs(result_set._source, derived) - else: - self.assertIsNone(result_set._source) + self.assertEqual(derived._total_read_request_count, count + 1) + self.assertEqual(derived._execute_sql_request_count, sql_count + 1) self.assertEqual(list(result_set), VALUES) self.assertEqual(result_set.metadata, metadata_pb) @@ -1019,11 +1104,17 @@ def _execute_sql_helper( if multi_use: if first: - expected_transaction = TransactionSelector(begin=txn_options) + expected_transaction_selector_pb = TransactionSelector( + begin=txn_options + ) else: - expected_transaction = TransactionSelector(id=TXN_ID) + expected_transaction_selector_pb = TransactionSelector( + id=TRANSACTION_ID + ) else: - expected_transaction = TransactionSelector(single_use=txn_options) + expected_transaction_selector_pb = TransactionSelector( + single_use=txn_options + ) expected_params = Struct( fields={key: _make_value_pb(value) for (key, value) in PARAMS.items()} @@ -1035,10 +1126,9 @@ def _execute_sql_helper( expected_query_options, query_options ) - if derived._read_only: - # Transaction tag is ignored for read only requests. - expected_request_options = request_options - expected_request_options.transaction_tag = None + # Transaction tag is ignored for read only requests. + expected_request_options = request_options + expected_request_options.transaction_tag = None expected_directed_read_options = ( directed_read_options @@ -1047,9 +1137,9 @@ def _execute_sql_helper( ) expected_request = ExecuteSqlRequest( - session=self.SESSION_NAME, + session=session.name, sql=SQL_QUERY_WITH_PARAM, - transaction=expected_transaction, + transaction=expected_transaction_selector_pb, params=expected_params, param_types=PARAM_TYPES, query_mode=MODE, @@ -1066,14 +1156,21 @@ def _execute_sql_helper( retry=retry, ) - self.assertEqual(derived._execute_sql_count, sql_count + 1) - self.assertSpanAttributes( "CloudSpanner._Derived.execute_sql", status=StatusCode.OK, - attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY_WITH_PARAM}), + attributes=dict( + _build_base_attributes(database), + **{"db.statement": SQL_QUERY_WITH_PARAM}, + ), ) + if first: + self.assertEqual(derived._transaction_id, TRANSACTION_ID) + + if use_multiplexed: + self.assertEqual(derived._precommit_token, PRECOMMIT_TOKEN_2) + def test_execute_sql_wo_multi_use(self): self._execute_sql_helper(multi_use=False) @@ -1162,6 +1259,9 @@ def test_execute_sql_w_directed_read_options_override(self): directed_read_options_at_client_level=DIRECTED_READ_OPTIONS_FOR_CLIENT, ) + def test_execute_sql_w_precommit_tokens(self): + self._execute_sql_helper(multi_use=True, use_multiplexed=True) + def _partition_read_helper( self, multi_use, @@ -1197,7 +1297,7 @@ def _partition_read_helper( derived = self._makeDerived(session) derived._multi_use = multi_use if w_txn: - derived._transaction_id = TXN_ID + derived._transaction_id = TRANSACTION_ID tokens = list( derived.partition_read( TABLE_NAME, @@ -1213,14 +1313,14 @@ def _partition_read_helper( self.assertEqual(tokens, [token_1, token_2]) - expected_txn_selector = TransactionSelector(id=TXN_ID) + expected_txn_selector = TransactionSelector(id=TRANSACTION_ID) expected_partition_options = PartitionOptions( partition_size_bytes=size, max_partitions=max_partitions ) expected_request = PartitionReadRequest( - session=self.SESSION_NAME, + session=session.name, table=TABLE_NAME, columns=COLUMNS, key_set=keyset._to_pb(), @@ -1239,7 +1339,7 @@ def _partition_read_helper( ) want_span_attributes = dict( - BASE_ATTRIBUTES, + _build_base_attributes(database), table_id=TABLE_NAME, columns=tuple(COLUMNS), ) @@ -1260,8 +1360,6 @@ def test_partition_read_wo_existing_transaction_raises(self): self._partition_read_helper(multi_use=True, w_txn=False) def test_partition_read_multiplexed_not_implemented_error(self): - enable_multiplexed_sessions() - database = ( self._build_database_with_partitioned_not_implemented_for_multiplexed() ) @@ -1270,7 +1368,7 @@ def test_partition_read_multiplexed_not_implemented_error(self): session.create() derived = self._makeDerived(session) derived._multi_use = True - derived._transaction_id = TXN_ID + derived._transaction_id = TRANSACTION_ID with self.assertRaises(NotImplementedError): list(derived.partition_read(TABLE_NAME, COLUMNS, KeySet(all_=True))) @@ -1287,7 +1385,7 @@ def test_partition_read_other_error(self): session = _Session(database) derived = self._makeDerived(session) derived._multi_use = True - derived._transaction_id = TXN_ID + derived._transaction_id = TRANSACTION_ID with self.assertRaises(RuntimeError): list(derived.partition_read(TABLE_NAME, COLUMNS, keyset)) @@ -1296,7 +1394,9 @@ def test_partition_read_other_error(self): "CloudSpanner._Derived.partition_read", status=StatusCode.ERROR, attributes=dict( - BASE_ATTRIBUTES, table_id=TABLE_NAME, columns=tuple(COLUMNS) + _build_base_attributes(database), + table_id=TABLE_NAME, + columns=tuple(COLUMNS), ), ) @@ -1327,7 +1427,7 @@ def test_partition_read_w_retry(self): session = _Session(database) derived = self._makeDerived(session) derived._multi_use = True - derived._transaction_id = TXN_ID + derived._transaction_id = TRANSACTION_ID list(derived.partition_read(TABLE_NAME, COLUMNS, keyset)) @@ -1390,7 +1490,7 @@ def _partition_query_helper( derived = self._makeDerived(session) derived._multi_use = multi_use if w_txn: - derived._transaction_id = TXN_ID + derived._transaction_id = TRANSACTION_ID tokens = list( derived.partition_query( @@ -1410,14 +1510,14 @@ def _partition_query_helper( fields={key: _make_value_pb(value) for (key, value) in PARAMS.items()} ) - expected_txn_selector = TransactionSelector(id=TXN_ID) + expected_txn_selector = TransactionSelector(id=TRANSACTION_ID) expected_partition_options = PartitionOptions( partition_size_bytes=size, max_partitions=max_partitions ) expected_request = PartitionQueryRequest( - session=self.SESSION_NAME, + session=session.name, sql=SQL_QUERY_WITH_PARAM, transaction=expected_txn_selector, params=expected_params, @@ -1437,12 +1537,13 @@ def _partition_query_helper( self.assertSpanAttributes( "CloudSpanner._Derived.partition_query", status=StatusCode.OK, - attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY_WITH_PARAM}), + attributes=dict( + _build_base_attributes(database), + **{"db.statement": SQL_QUERY_WITH_PARAM}, + ), ) def test_partition_query_partitioned_not_implemented_for_multiplexed(self): - enable_multiplexed_sessions() - database = ( self._build_database_with_partitioned_not_implemented_for_multiplexed() ) @@ -1451,7 +1552,7 @@ def test_partition_query_partitioned_not_implemented_for_multiplexed(self): session.create() derived = self._makeDerived(session) derived._multi_use = True - derived._transaction_id = TXN_ID + derived._transaction_id = TRANSACTION_ID with self.assertRaises(NotImplementedError): list(derived.partition_query(SQL_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES)) @@ -1467,7 +1568,7 @@ def test_partition_query_other_error(self): session = _Session(database) derived = self._makeDerived(session) derived._multi_use = True - derived._transaction_id = TXN_ID + derived._transaction_id = TRANSACTION_ID with self.assertRaises(RuntimeError): list(derived.partition_query(SQL_QUERY)) @@ -1475,7 +1576,9 @@ def test_partition_query_other_error(self): self.assertSpanAttributes( "CloudSpanner._Derived.partition_query", status=StatusCode.ERROR, - attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY}), + attributes=dict( + _build_base_attributes(database), **{"db.statement": SQL_QUERY} + ), ) def test_partition_query_single_use_raises(self): @@ -1672,9 +1775,9 @@ def test_ctor_w_multi_use_and_exact_staleness(self): def test__make_txn_selector_w_transaction_id(self): session = _Session() snapshot = self._make_one(session) - snapshot._transaction_id = TXN_ID + snapshot._transaction_id = TRANSACTION_ID selector = snapshot._make_txn_selector() - self.assertEqual(selector.id, TXN_ID) + self.assertEqual(selector.id, TRANSACTION_ID) def test__make_txn_selector_strong(self): session = _Session() @@ -1777,14 +1880,14 @@ def test_begin_wo_multi_use(self): def test_begin_w_read_request_count_gt_0(self): session = _Session() snapshot = self._make_one(session, multi_use=True) - snapshot._read_request_count = 1 + snapshot._total_read_request_count = 1 with self.assertRaises(ValueError): snapshot.begin() def test_begin_w_existing_txn_id(self): session = _Session() snapshot = self._make_one(session, multi_use=True) - snapshot._transaction_id = TXN_ID + snapshot._transaction_id = TRANSACTION_ID with self.assertRaises(ValueError): snapshot.begin() @@ -1810,7 +1913,7 @@ def test_begin_w_other_error(self): self.assertSpanAttributes( "CloudSpanner.Snapshot.begin", status=StatusCode.ERROR, - attributes=BASE_ATTRIBUTES, + attributes=_build_base_attributes(database), ) def test_begin_w_retry(self): @@ -1823,7 +1926,7 @@ def test_begin_w_retry(self): api = database.spanner_api = self._make_spanner_api() database.spanner_api.begin_transaction.side_effect = [ InternalServerError("Received unexpected EOS on DATA frame from server"), - TransactionPB(id=TXN_ID), + TransactionPB(id=TRANSACTION_ID), ] timestamp = _makeTimestamp() session = _Session(database) @@ -1839,7 +1942,7 @@ def test_begin_ok_exact_staleness(self): TransactionOptions, ) - transaction_pb = TransactionPB(id=TXN_ID) + transaction_pb = TransactionPB(id=TRANSACTION_ID) database = _Database() api = database.spanner_api = self._make_spanner_api() api.begin_transaction.return_value = transaction_pb @@ -1849,8 +1952,8 @@ def test_begin_ok_exact_staleness(self): txn_id = snapshot.begin() - self.assertEqual(txn_id, TXN_ID) - self.assertEqual(snapshot._transaction_id, TXN_ID) + self.assertEqual(txn_id, TRANSACTION_ID) + self.assertEqual(snapshot._transaction_id, TRANSACTION_ID) expected_duration = Duration(seconds=SECONDS, nanos=MICROS * 1000) expected_txn_options = TransactionOptions( @@ -1868,7 +1971,7 @@ def test_begin_ok_exact_staleness(self): self.assertSpanAttributes( "CloudSpanner.Snapshot.begin", status=StatusCode.OK, - attributes=BASE_ATTRIBUTES, + attributes=_build_base_attributes(database), ) def test_begin_ok_exact_strong(self): @@ -1877,7 +1980,7 @@ def test_begin_ok_exact_strong(self): TransactionOptions, ) - transaction_pb = TransactionPB(id=TXN_ID) + transaction_pb = TransactionPB(id=TRANSACTION_ID) database = _Database() api = database.spanner_api = self._make_spanner_api() api.begin_transaction.return_value = transaction_pb @@ -1886,8 +1989,8 @@ def test_begin_ok_exact_strong(self): txn_id = snapshot.begin() - self.assertEqual(txn_id, TXN_ID) - self.assertEqual(snapshot._transaction_id, TXN_ID) + self.assertEqual(txn_id, TRANSACTION_ID) + self.assertEqual(snapshot._transaction_id, TRANSACTION_ID) expected_txn_options = TransactionOptions( read_only=TransactionOptions.ReadOnly( @@ -1904,10 +2007,27 @@ def test_begin_ok_exact_strong(self): self.assertSpanAttributes( "CloudSpanner.Snapshot.begin", status=StatusCode.OK, - attributes=BASE_ATTRIBUTES, + attributes=_build_base_attributes(database), ) +def _build_base_attributes(database: Database) -> dict: + """Builds and returns the base attributes for the given database.""" + from tests._helpers import enrich_with_otel_scope + + return enrich_with_otel_scope( + { + "db.type": "spanner", + "db.url": "spanner.googleapis.com", + "db.instance": database.name, + "net.host.name": "spanner.googleapis.com", + "gcp.client.service": "spanner", + "gcp.client.version": LIB_VERSION, + "gcp.client.repo": "googleapis/python-spanner", + } + ) + + class _Client(object): def __init__(self): from google.cloud.spanner_v1 import ExecuteSqlRequest diff --git a/tests/unit/test_spanner.py b/tests/unit/test_spanner.py index 8bd95c7228..ddc867f1b0 100644 --- a/tests/unit/test_spanner.py +++ b/tests/unit/test_spanner.py @@ -46,6 +46,7 @@ from google.api_core import gapic_v1 +from tests._builders import build_result_set_metadata_pb from tests._helpers import OpenTelemetryBase TABLE_NAME = "citizens" @@ -150,7 +151,7 @@ def _execute_update_helper( transaction.transaction_tag = self.TRANSACTION_TAG transaction.exclude_txn_from_change_streams = exclude_txn_from_change_streams transaction.isolation_level = isolation_level - transaction._execute_sql_count = count + transaction._execute_sql_request_count = count row_count = transaction.execute_update( DML_QUERY_WITH_PARAM, @@ -244,8 +245,8 @@ def _execute_sql_helper( result_sets[i].values.extend(VALUE_PBS[i]) iterator = _MockIterator(*result_sets) api.execute_streaming_sql.return_value = iterator - transaction._execute_sql_count = sql_count - transaction._read_request_count = count + transaction._execute_sql_request_count = sql_count + transaction._total_read_request_count = count result_set = transaction.execute_sql( SQL_QUERY_WITH_PARAM, @@ -260,12 +261,12 @@ def _execute_sql_helper( directed_read_options=directed_read_options, ) - self.assertEqual(transaction._read_request_count, count + 1) + self.assertEqual(transaction._total_read_request_count, count + 1) self.assertEqual(list(result_set), VALUES) self.assertEqual(result_set.metadata, metadata_pb) self.assertEqual(result_set.stats, stats_pb) - self.assertEqual(transaction._execute_sql_count, sql_count + 1) + self.assertEqual(transaction._execute_sql_request_count, sql_count + 1) def _execute_sql_expected_request( self, @@ -350,7 +351,7 @@ def _read_helper( result_sets[i].values.extend(VALUE_PBS[i]) api.streaming_read.return_value = _MockIterator(*result_sets) - transaction._read_request_count = count + transaction._total_read_request_count = count if partition is not None: # 'limit' and 'partition' incompatible result_set = transaction.read( @@ -377,9 +378,7 @@ def _read_helper( directed_read_options=directed_read_options, ) - self.assertEqual(transaction._read_request_count, count + 1) - - self.assertIs(result_set._source, transaction) + self.assertEqual(transaction._total_read_request_count, count + 1) self.assertEqual(list(result_set), VALUES) self.assertEqual(result_set.metadata, metadata_pb) @@ -431,7 +430,6 @@ def _read_helper_expected_request( def _batch_update_helper( self, transaction, - database, api, error_after=None, count=0, @@ -449,8 +447,10 @@ def _batch_update_helper( else: expected_status = Status(code=200) expected_row_counts = [stats.row_count_exact for stats in stats_pbs] - transaction_pb = transaction_type.Transaction(id=self.TRANSACTION_ID) - metadata_pb = ResultSetMetadata(transaction=transaction_pb) + + metadata_pb = build_result_set_metadata_pb( + transaction={"id": self.TRANSACTION_ID} + ) result_sets_pb = [ ResultSet(stats=stats_pb, metadata=metadata_pb) for stats_pb in stats_pbs ] @@ -462,7 +462,7 @@ def _batch_update_helper( api.execute_batch_dml.return_value = response transaction.transaction_tag = self.TRANSACTION_TAG - transaction._execute_sql_count = count + transaction._execute_sql_request_count = count status, row_counts = transaction.batch_update( dml_statements, request_options=RequestOptions() @@ -470,7 +470,7 @@ def _batch_update_helper( self.assertEqual(status, expected_status) self.assertEqual(row_counts, expected_row_counts) - self.assertEqual(transaction._execute_sql_count, count + 1) + self.assertEqual(transaction._execute_sql_request_count, count + 1) def _batch_update_expected_request(self, begin=True, count=0): if begin is True: @@ -564,7 +564,7 @@ def test_transaction_should_include_begin_with_first_batch_update(self): session = _Session(database) api = database.spanner_api = self._make_spanner_api() transaction = self._make_one(session) - self._batch_update_helper(transaction=transaction, database=database, api=api) + self._batch_update_helper(transaction=transaction, api=api) api.execute_batch_dml.assert_called_once_with( request=self._batch_update_expected_request(), metadata=[ @@ -631,9 +631,7 @@ def test_transaction_should_use_transaction_id_if_error_with_first_batch_update( session = _Session(database) api = database.spanner_api = self._make_spanner_api() transaction = self._make_one(session) - self._batch_update_helper( - transaction=transaction, database=database, api=api, error_after=2 - ) + self._batch_update_helper(transaction=transaction, api=api, error_after=2) api.execute_batch_dml.assert_called_once_with( request=self._batch_update_expected_request(begin=True), metadata=[ @@ -776,7 +774,7 @@ def test_transaction_should_use_transaction_id_returned_by_first_read(self): timeout=TIMEOUT, ) - self._batch_update_helper(transaction=transaction, database=database, api=api) + self._batch_update_helper(transaction=transaction, api=api) api.execute_batch_dml.assert_called_once_with( request=self._batch_update_expected_request(begin=False), metadata=[ @@ -792,7 +790,7 @@ def test_transaction_should_use_transaction_id_returned_by_first_batch_update(se api = database.spanner_api = self._make_spanner_api() session = _Session(database) transaction = self._make_one(session) - self._batch_update_helper(transaction=transaction, database=database, api=api) + self._batch_update_helper(transaction=transaction, api=api) api.execute_batch_dml.assert_called_once_with( request=self._batch_update_expected_request(), metadata=[ @@ -841,7 +839,7 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ for thread in threads: thread.join() - self._batch_update_helper(transaction=transaction, database=database, api=api) + self._batch_update_helper(transaction=transaction, api=api) api.execute_sql.assert_any_call( request=self._execute_update_expected_request(database), @@ -887,13 +885,13 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ threads.append( threading.Thread( target=self._batch_update_helper, - kwargs={"transaction": transaction, "database": database, "api": api}, + kwargs={"transaction": transaction, "api": api}, ) ) threads.append( threading.Thread( target=self._batch_update_helper, - kwargs={"transaction": transaction, "database": database, "api": api}, + kwargs={"transaction": transaction, "api": api}, ) ) for thread in threads: diff --git a/tests/unit/test_streamed.py b/tests/unit/test_streamed.py index 83aa25a9d1..dac64c4745 100644 --- a/tests/unit/test_streamed.py +++ b/tests/unit/test_streamed.py @@ -27,21 +27,10 @@ def _getTargetClass(self): def _make_one(self, *args, **kwargs): return self._getTargetClass()(*args, **kwargs) - def test_ctor_defaults(self): + def test_ctor(self): iterator = _MockCancellableIterator() streamed = self._make_one(iterator) self.assertIs(streamed._response_iterator, iterator) - self.assertIsNone(streamed._source) - self.assertEqual(list(streamed), []) - self.assertIsNone(streamed.metadata) - self.assertIsNone(streamed.stats) - - def test_ctor_w_source(self): - iterator = _MockCancellableIterator() - source = object() - streamed = self._make_one(iterator, source=source) - self.assertIs(streamed._response_iterator, iterator) - self.assertIs(streamed._source, source) self.assertEqual(list(streamed), []) self.assertIsNone(streamed.metadata) self.assertIsNone(streamed.stats) @@ -790,46 +779,21 @@ def test_consume_next_empty(self): def test_consume_next_first_set_partial(self): from google.cloud.spanner_v1 import TypeCode - TXN_ID = b"DEADBEEF" - FIELDS = [ - self._make_scalar_field("full_name", TypeCode.STRING), - self._make_scalar_field("age", TypeCode.INT64), - self._make_scalar_field("married", TypeCode.BOOL), - ] - metadata = self._make_result_set_metadata(FIELDS, transaction_id=TXN_ID) - BARE = ["Phred Phlyntstone", 42] - VALUES = [self._make_value(bare) for bare in BARE] - result_set = self._make_partial_result_set(VALUES, metadata=metadata) - iterator = _MockCancellableIterator(result_set) - source = mock.Mock(_transaction_id=None, spec=["_transaction_id"]) - streamed = self._make_one(iterator, source=source) - streamed._consume_next() - self.assertEqual(list(streamed), []) - self.assertEqual(streamed._current_row, BARE) - self.assertEqual(streamed.metadata, metadata) - self.assertEqual(source._transaction_id, TXN_ID) - - def test_consume_next_first_set_partial_existing_txn_id(self): - from google.cloud.spanner_v1 import TypeCode - - TXN_ID = b"DEADBEEF" FIELDS = [ self._make_scalar_field("full_name", TypeCode.STRING), self._make_scalar_field("age", TypeCode.INT64), self._make_scalar_field("married", TypeCode.BOOL), ] - metadata = self._make_result_set_metadata(FIELDS, transaction_id=b"") + metadata = self._make_result_set_metadata(FIELDS) BARE = ["Phred Phlyntstone", 42] VALUES = [self._make_value(bare) for bare in BARE] result_set = self._make_partial_result_set(VALUES, metadata=metadata) iterator = _MockCancellableIterator(result_set) - source = mock.Mock(_transaction_id=TXN_ID, spec=["_transaction_id"]) - streamed = self._make_one(iterator, source=source) + streamed = self._make_one(iterator) streamed._consume_next() self.assertEqual(list(streamed), []) self.assertEqual(streamed._current_row, BARE) self.assertEqual(streamed.metadata, metadata) - self.assertEqual(source._transaction_id, TXN_ID) def test_consume_next_w_partial_result(self): from google.cloud.spanner_v1 import TypeCode diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index ddc91ea522..6e33ddb697 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -11,23 +11,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import datetime - -import mock - -from google.cloud.spanner_v1 import RequestOptions -from google.cloud.spanner_v1 import DefaultTransactionOptions -from google.cloud.spanner_v1 import Type -from google.cloud.spanner_v1 import TypeCode from google.api_core.retry import Retry from google.api_core import gapic_v1 +from google.cloud.spanner_v1 import ( + CommitRequest, + CommitResponse, + RequestOptions, + TransactionOptions, + Type, + TypeCode, + ResultSetStats, +) + +from google.cloud.spanner_v1.transaction import Transaction + +from tests._builders import ( + build_precommit_token_pb, + build_result_set_metadata_pb, + build_result_set_pb, + build_transaction, + build_transaction_pb, +) from tests._helpers import ( HAS_OPENTELEMETRY_INSTALLED, LIB_VERSION, OpenTelemetryBase, StatusCode, - enrich_with_otel_scope, ) TABLE_NAME = "citizens" @@ -47,225 +59,220 @@ PARAMS = {"age": 30} PARAM_TYPES = {"age": Type(code=TypeCode.INT64)} +TRANSACTION_ID = b"transaction-id" -class TestTransaction(OpenTelemetryBase): - PROJECT_ID = "project-id" - INSTANCE_ID = "instance-id" - INSTANCE_NAME = "projects/" + PROJECT_ID + "/instances/" + INSTANCE_ID - DATABASE_ID = "database-id" - DATABASE_NAME = INSTANCE_NAME + "/databases/" + DATABASE_ID - SESSION_ID = "session-id" - SESSION_NAME = DATABASE_NAME + "/sessions/" + SESSION_ID - TRANSACTION_ID = b"DEADBEEF" - TRANSACTION_TAG = "transaction-tag" - - BASE_ATTRIBUTES = { - "db.type": "spanner", - "db.url": "spanner.googleapis.com", - "db.instance": "testing", - "net.host.name": "spanner.googleapis.com", - "gcp.client.service": "spanner", - "gcp.client.version": LIB_VERSION, - "gcp.client.repo": "googleapis/python-spanner", - } - enrich_with_otel_scope(BASE_ATTRIBUTES) - - def _getTargetClass(self): - from google.cloud.spanner_v1.transaction import Transaction +PRECOMMIT_TOKEN_0 = build_precommit_token_pb(precommit_token=b"0", seq_num=0) +PRECOMMIT_TOKEN_1 = build_precommit_token_pb(precommit_token=b"1", seq_num=1) +PRECOMMIT_TOKEN_2 = build_precommit_token_pb(precommit_token=b"2", seq_num=2) - return Transaction +TRANSACTION_TAG = "transaction-tag" - def _make_one(self, session, *args, **kwargs): - transaction = self._getTargetClass()(session, *args, **kwargs) - session._transaction = transaction - return transaction - - def _make_spanner_api(self): - from google.cloud.spanner_v1 import SpannerClient - - return mock.create_autospec(SpannerClient, instance=True) +class TestTransaction(OpenTelemetryBase): def test_ctor_session_w_existing_txn(self): - session = _Session() - session._transaction = object() + transaction = build_transaction() with self.assertRaises(ValueError): - self._make_one(session) + Transaction(transaction._session) def test_ctor_defaults(self): - session = _Session() - transaction = self._make_one(session) + from tests._builders import build_session + + session = build_session() + transaction = Transaction(session=session) + self.assertIs(transaction._session, session) self.assertIsNone(transaction._transaction_id) self.assertIsNone(transaction.committed) self.assertFalse(transaction.rolled_back) self.assertTrue(transaction._multi_use) - self.assertEqual(transaction._execute_sql_count, 0) + self.assertEqual(transaction._execute_sql_request_count, 0) def test__check_state_already_committed(self): - session = _Session() - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID - transaction.committed = object() + transaction = build_transaction() + transaction.begin() + transaction.commit() + with self.assertRaises(ValueError): transaction._check_state() def test__check_state_already_rolled_back(self): - session = _Session() - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID - transaction.rolled_back = True + transaction = build_transaction() + transaction.begin() + transaction.rollback() + with self.assertRaises(ValueError): transaction._check_state() def test__check_state_ok(self): - session = _Session() - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID + transaction = build_transaction() + transaction._check_state() # does not raise + + transaction.begin() transaction._check_state() # does not raise def test__make_txn_selector(self): - session = _Session() - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID + transaction = build_transaction() + + begin_transaction = transaction._session._database.spanner_api.begin_transaction + begin_transaction.return_value = build_transaction_pb(id=TRANSACTION_ID) + + transaction.begin() + selector = transaction._make_txn_selector() - self.assertEqual(selector.id, self.TRANSACTION_ID) + self.assertEqual(selector.id, TRANSACTION_ID) def test_begin_already_begun(self): - session = _Session() - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID + transaction = build_transaction() + transaction.begin() + + self.reset() with self.assertRaises(ValueError): transaction.begin() self.assertNoSpans() def test_begin_already_rolled_back(self): - session = _Session() - transaction = self._make_one(session) - transaction.rolled_back = True + transaction = build_transaction() + transaction.begin() + transaction.rollback() + + self.reset() with self.assertRaises(ValueError): transaction.begin() self.assertNoSpans() def test_begin_already_committed(self): - session = _Session() - transaction = self._make_one(session) - transaction.committed = object() + transaction = build_transaction() + transaction.begin() + transaction.commit() + + self.reset() with self.assertRaises(ValueError): transaction.begin() self.assertNoSpans() def test_begin_w_other_error(self): - database = _Database() - database.spanner_api = self._make_spanner_api() + transaction = build_transaction() + + database = transaction._session._database database.spanner_api.begin_transaction.side_effect = RuntimeError() - session = _Session(database) - transaction = self._make_one(session) + self.reset() with self.assertRaises(RuntimeError): transaction.begin() self.assertSpanAttributes( "CloudSpanner.Transaction.begin", status=StatusCode.ERROR, - attributes=TestTransaction.BASE_ATTRIBUTES, + attributes=_build_base_attributes(database), ) def test_begin_ok(self): - from google.cloud.spanner_v1 import Transaction as TransactionPB + transaction = build_transaction() + session = transaction._session + database = session._database - transaction_pb = TransactionPB(id=self.TRANSACTION_ID) - database = _Database() - api = database.spanner_api = _FauxSpannerAPI( - _begin_transaction_response=transaction_pb - ) - session = _Session(database) - transaction = self._make_one(session) + api = database.spanner_api + begin_transaction = api.begin_transaction + begin_transaction.return_value = build_transaction_pb(id=TRANSACTION_ID) - txn_id = transaction.begin() + self.reset() + transaction_id = transaction.begin() - self.assertEqual(txn_id, self.TRANSACTION_ID) - self.assertEqual(transaction._transaction_id, self.TRANSACTION_ID) + self.assertEqual(transaction_id, TRANSACTION_ID) + self.assertEqual(transaction._transaction_id, TRANSACTION_ID) + self.assertIsNone(transaction._precommit_token) - session_id, txn_options, metadata = api._begun - self.assertEqual(session_id, session.name) - self.assertTrue(type(txn_options).pb(txn_options).HasField("read_write")) - self.assertEqual( - metadata, - [ + begin_transaction.assert_called_once_with( + session=session.name, + options=TransactionOptions(read_write=TransactionOptions.ReadWrite()), + metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), ], ) self.assertSpanAttributes( - "CloudSpanner.Transaction.begin", attributes=TestTransaction.BASE_ATTRIBUTES + "CloudSpanner.Transaction.begin", + attributes=_build_base_attributes(database), ) def test_begin_w_retry(self): - from google.cloud.spanner_v1 import ( - Transaction as TransactionPB, - ) from google.api_core.exceptions import InternalServerError - database = _Database() - api = database.spanner_api = self._make_spanner_api() - database.spanner_api.begin_transaction.side_effect = [ + transaction = build_transaction() + + api = transaction._session._database.spanner_api + begin_transaction = api.begin_transaction + begin_transaction.side_effect = [ InternalServerError("Received unexpected EOS on DATA frame from server"), - TransactionPB(id=self.TRANSACTION_ID), + build_transaction_pb(id=TRANSACTION_ID), ] - session = _Session(database) - transaction = self._make_one(session) + transaction_id = transaction.begin() + + self.assertEqual(begin_transaction.call_count, 2) + self.assertEqual(transaction_id, TRANSACTION_ID) + self.assertEqual(transaction._transaction_id, TRANSACTION_ID) + + def test_begin_w_precommit_token(self): + transaction = build_transaction() + + api = transaction._session._database.spanner_api + api.begin_transaction.return_value = build_transaction_pb( + id=TRANSACTION_ID, precommit_token=PRECOMMIT_TOKEN_0 + ) + transaction.begin() - self.assertEqual(api.begin_transaction.call_count, 2) + self.assertEqual(transaction._precommit_token, PRECOMMIT_TOKEN_0) def test_rollback_not_begun(self): - database = _Database() - api = database.spanner_api = self._make_spanner_api() - session = _Session(database) - transaction = self._make_one(session) + transaction = build_transaction() + self.reset() transaction.rollback() self.assertTrue(transaction.rolled_back) # Since there was no transaction to be rolled back, rollback rpc is not called. + api = transaction._session._database.spanner_api api.rollback.assert_not_called() self.assertNoSpans() def test_rollback_already_committed(self): - session = _Session() - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID - transaction.committed = object() + transaction = build_transaction() + transaction.begin() + transaction.commit() + + self.reset() with self.assertRaises(ValueError): transaction.rollback() self.assertNoSpans() def test_rollback_already_rolled_back(self): - session = _Session() - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID - transaction.rolled_back = True + transaction = build_transaction() + transaction.rollback() + + self.reset() with self.assertRaises(ValueError): transaction.rollback() self.assertNoSpans() def test_rollback_w_other_error(self): - database = _Database() - database.spanner_api = self._make_spanner_api() - database.spanner_api.rollback.side_effect = RuntimeError("other error") - session = _Session(database) - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID + transaction = build_transaction() + + transaction.begin() transaction.insert(TABLE_NAME, COLUMNS, VALUES) + database = transaction._session._database + database.spanner_api.rollback.side_effect = RuntimeError() + + self.reset() with self.assertRaises(RuntimeError): transaction.rollback() @@ -274,31 +281,28 @@ def test_rollback_w_other_error(self): self.assertSpanAttributes( "CloudSpanner.Transaction.rollback", status=StatusCode.ERROR, - attributes=TestTransaction.BASE_ATTRIBUTES, + attributes=_build_base_attributes(database), ) def test_rollback_ok(self): - from google.protobuf.empty_pb2 import Empty - - empty_pb = Empty() - database = _Database() - api = database.spanner_api = _FauxSpannerAPI(_rollback_response=empty_pb) - session = _Session(database) - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID + transaction = build_transaction() + session = transaction._session + database = session._database + api = database.spanner_api + + transaction.begin() transaction.replace(TABLE_NAME, COLUMNS, VALUES) + self.reset() transaction.rollback() self.assertTrue(transaction.rolled_back) self.assertIsNone(session._transaction) - session_id, txn_id, metadata = api._rolled_back - self.assertEqual(session_id, session.name) - self.assertEqual(txn_id, self.TRANSACTION_ID) - self.assertEqual( - metadata, - [ + api.rollback.assert_called_once_with( + session=session.name, + transaction_id=transaction._transaction_id, + metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), ], @@ -306,14 +310,13 @@ def test_rollback_ok(self): self.assertSpanAttributes( "CloudSpanner.Transaction.rollback", - attributes=TestTransaction.BASE_ATTRIBUTES, + attributes=_build_base_attributes(database), ) def test_commit_not_begun(self): - database = _Database() - database.spanner_api = self._make_spanner_api() - session = _Session(database) - transaction = self._make_one(session) + transaction = build_transaction() + + self.reset() with self.assertRaises(ValueError): transaction.commit() @@ -340,12 +343,11 @@ def test_commit_not_begun(self): assert got_span_events_statuses == want_span_events_statuses def test_commit_already_committed(self): - database = _Database() - database.spanner_api = self._make_spanner_api() - session = _Session(database) - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID - transaction.committed = object() + transaction = build_transaction() + transaction.begin() + transaction.commit() + + self.reset() with self.assertRaises(ValueError): transaction.commit() @@ -372,12 +374,10 @@ def test_commit_already_committed(self): assert got_span_events_statuses == want_span_events_statuses def test_commit_already_rolled_back(self): - database = _Database() - database.spanner_api = self._make_spanner_api() - session = _Session(database) - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID - transaction.rolled_back = True + transaction = build_transaction() + transaction.rollback() + + self.reset() with self.assertRaises(ValueError): transaction.commit() @@ -404,14 +404,15 @@ def test_commit_already_rolled_back(self): assert got_span_events_statuses == want_span_events_statuses def test_commit_w_other_error(self): - database = _Database() - database.spanner_api = self._make_spanner_api() - database.spanner_api.commit.side_effect = RuntimeError() - session = _Session(database) - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID + transaction = build_transaction() + + transaction.begin() transaction.replace(TABLE_NAME, COLUMNS, VALUES) + database = transaction._session._database + database.spanner_api.commit.side_effect = RuntimeError() + + self.reset() with self.assertRaises(RuntimeError): transaction.commit() @@ -420,7 +421,7 @@ def test_commit_w_other_error(self): self.assertSpanAttributes( "CloudSpanner.Transaction.commit", status=StatusCode.ERROR, - attributes=dict(TestTransaction.BASE_ATTRIBUTES, num_mutations=1), + attributes=dict(_build_base_attributes(database), num_mutations=1), ) def _commit_helper( @@ -432,81 +433,90 @@ def _commit_helper( ): import datetime + from google.cloud.spanner_v1 import CommitRequest from google.cloud.spanner_v1 import CommitResponse from google.cloud.spanner_v1.keyset import KeySet from google.cloud._helpers import UTC - now = datetime.datetime.utcnow().replace(tzinfo=UTC) - keys = [[0], [1], [2]] - keyset = KeySet(keys=keys) + # [A] Build transaction + # --------------------- + + transaction = build_transaction() + session = transaction._session + database = session._database + api = database.spanner_api + + transaction.transaction_tag = TRANSACTION_TAG + + # Build response + # -------------- + + now = datetime.datetime.now(tz=UTC) + + # TODO - test retry where precommit token is returned response = CommitResponse(commit_timestamp=now) if return_commit_stats: response.commit_stats.mutation_count = 4 - database = _Database() - api = database.spanner_api = _FauxSpannerAPI(_commit_response=response) - session = _Session(database) - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID - transaction.transaction_tag = self.TRANSACTION_TAG + + commit = api.commit + commit.return_value = response + + # [C] Execute commit + # ------------------ + + transaction.begin() if mutate: + keys = [[0], [1], [2]] + keyset = KeySet(keys=keys) transaction.delete(TABLE_NAME, keyset) + self.reset() transaction.commit( return_commit_stats=return_commit_stats, request_options=request_options, max_commit_delay=max_commit_delay_in, ) + # [D] Verify results + # ------------------ + self.assertEqual(transaction.committed, now) self.assertIsNone(session._transaction) - ( - session_id, - mutations, - txn_id, - actual_request_options, - max_commit_delay, - metadata, - ) = api._committed - if request_options is None: - expected_request_options = RequestOptions( - transaction_tag=self.TRANSACTION_TAG - ) + expected_request_options = RequestOptions(transaction_tag=TRANSACTION_TAG) elif type(request_options) is dict: expected_request_options = RequestOptions(request_options) - expected_request_options.transaction_tag = self.TRANSACTION_TAG + expected_request_options.transaction_tag = TRANSACTION_TAG expected_request_options.request_tag = None else: expected_request_options = request_options - expected_request_options.transaction_tag = self.TRANSACTION_TAG + expected_request_options.transaction_tag = TRANSACTION_TAG expected_request_options.request_tag = None - self.assertEqual(max_commit_delay_in, max_commit_delay) - self.assertEqual(session_id, session.name) - self.assertEqual(txn_id, self.TRANSACTION_ID) - self.assertEqual(mutations, transaction._mutations) - self.assertEqual( - metadata, - [ - ("google-cloud-resource-prefix", database.name), - ("x-goog-spanner-route-to-leader", "true"), - ], + expected_request = CommitRequest( + session=session.name, + transaction_id=transaction._transaction_id, + mutations=transaction._mutations, + return_commit_stats=return_commit_stats, + max_commit_delay=max_commit_delay_in, + request_options=expected_request_options, + precommit_token=transaction._precommit_token, + ) + expected_metadata = [ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ] + + commit.assert_called_once_with( + request=expected_request, + metadata=expected_metadata, ) - self.assertEqual(actual_request_options, expected_request_options) if return_commit_stats: self.assertEqual(transaction.commit_stats.mutation_count, 4) - self.assertSpanAttributes( - "CloudSpanner.Transaction.commit", - attributes=dict( - TestTransaction.BASE_ATTRIBUTES, - num_mutations=len(transaction._mutations), - ), - ) - if not HAS_OPENTELEMETRY_INSTALLED: return @@ -565,9 +575,7 @@ def test__make_params_pb_w_params_w_param_types(self): from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1._helpers import _make_value_pb - session = _Session() - transaction = self._make_one(session) - + transaction = build_transaction() params_pb = transaction._make_params_pb(PARAMS, PARAM_TYPES) expected_params = Struct( @@ -576,13 +584,13 @@ def test__make_params_pb_w_params_w_param_types(self): self.assertEqual(params_pb, expected_params) def test_execute_update_other_error(self): - database = _Database() - database.spanner_api = self._make_spanner_api() - database.spanner_api.execute_sql.side_effect = RuntimeError() - session = _Session(database) - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID + transaction = build_transaction() + transaction.begin() + + api = transaction._session._database.spanner_api + api.execute_sql.side_effect = RuntimeError() + self.reset() with self.assertRaises(RuntimeError): transaction.execute_update(DML_QUERY) @@ -593,10 +601,11 @@ def _execute_update_helper( request_options=None, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT, + begin=True, + use_multiplexed=False, ): from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1 import ( - ResultSet, ResultSetStats, ) from google.cloud.spanner_v1 import TransactionSelector @@ -606,22 +615,47 @@ def _execute_update_helper( ) from google.cloud.spanner_v1 import ExecuteSqlRequest + # [A] Build transaction + # --------------------- + + transaction = build_transaction() + session = transaction._session + database = session._database + api = database.spanner_api + + transaction.transaction_tag = TRANSACTION_TAG + transaction._execute_sql_request_count = count + + if begin: + transaction.begin() + + # [B] Build results + # ----------------- + + # If the transaction had not already begun, the first result set will include + # metadata with information about the transaction. Precommit tokens will be + # included in the result sets if the transaction is on a multiplexed session. + transaction_id = TRANSACTION_ID if not begin else None + metadata_pb = build_result_set_metadata_pb(transaction={"id": transaction_id}) + precommit_token_pb = PRECOMMIT_TOKEN_0 if use_multiplexed else None + + api.execute_sql.return_value = build_result_set_pb( + stats=ResultSetStats(row_count_exact=1), + metadata=metadata_pb, + precommit_token=precommit_token_pb, + ) + + # [C] Execute SQL + # --------------- + MODE = 2 # PROFILE - stats_pb = ResultSetStats(row_count_exact=1) - database = _Database() - api = database.spanner_api = self._make_spanner_api() - api.execute_sql.return_value = ResultSet(stats=stats_pb) - session = _Session(database) - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID - transaction.transaction_tag = self.TRANSACTION_TAG - transaction._execute_sql_count = count if request_options is None: request_options = RequestOptions() elif type(request_options) is dict: request_options = RequestOptions(request_options) + self.reset() row_count = transaction.execute_update( DML_QUERY_WITH_PARAM, PARAMS, @@ -633,9 +667,19 @@ def _execute_update_helper( timeout=timeout, ) + # [D] Verify results + # ------------------ + self.assertEqual(row_count, 1) - expected_transaction = TransactionSelector(id=self.TRANSACTION_ID) + expected_transaction_selector_pb = ( + TransactionSelector(id=transaction._transaction_id) + if begin + else TransactionSelector( + begin=TransactionOptions(read_write=TransactionOptions.ReadWrite()) + ) + ) + expected_params = Struct( fields={key: _make_value_pb(value) for (key, value) in PARAMS.items()} ) @@ -646,12 +690,12 @@ def _execute_update_helper( expected_query_options, query_options ) expected_request_options = request_options - expected_request_options.transaction_tag = self.TRANSACTION_TAG + expected_request_options.transaction_tag = TRANSACTION_TAG expected_request = ExecuteSqlRequest( - session=self.SESSION_NAME, + session=session.name, sql=DML_QUERY_WITH_PARAM, - transaction=expected_transaction, + transaction=expected_transaction_selector_pb, params=expected_params, param_types=PARAM_TYPES, query_mode=MODE, @@ -669,8 +713,8 @@ def _execute_update_helper( ], ) - self.assertEqual(transaction._execute_sql_count, count + 1) - want_span_attributes = dict(TestTransaction.BASE_ATTRIBUTES) + self.assertEqual(transaction._execute_sql_request_count, count + 1) + want_span_attributes = dict(_build_base_attributes(database)) want_span_attributes["db.statement"] = DML_QUERY_WITH_PARAM self.assertSpanAttributes( "CloudSpanner.Transaction.execute_update", @@ -678,8 +722,14 @@ def _execute_update_helper( attributes=want_span_attributes, ) - def test_execute_update_new_transaction(self): - self._execute_update_helper() + if not begin: + self.assertEqual(transaction._transaction_id, TRANSACTION_ID) + + if use_multiplexed: + self.assertEqual(transaction._precommit_token, PRECOMMIT_TOKEN_0) + + def test_execute_update_wo_begin(self): + self._execute_update_helper(begin=False) def test_execute_update_w_request_tag_success(self): request_options = RequestOptions( @@ -722,17 +772,15 @@ def test_execute_update_w_timeout_and_retry_params(self): self._execute_update_helper(retry=Retry(deadline=60), timeout=2.0) def test_execute_update_error(self): - database = _Database() - database.spanner_api = self._make_spanner_api() - database.spanner_api.execute_sql.side_effect = RuntimeError() - session = _Session(database) - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID + transaction = build_transaction() + + api = transaction._session._database.spanner_api + api.execute_sql.side_effect = RuntimeError() with self.assertRaises(RuntimeError): transaction.execute_update(DML_QUERY) - self.assertEqual(transaction._execute_sql_count, 1) + self.assertEqual(transaction._execute_sql_request_count, 1) def test_execute_update_w_query_options(self): from google.cloud.spanner_v1 import ExecuteSqlRequest @@ -748,16 +796,8 @@ def test_execute_update_w_request_options(self): ) ) - def test_batch_update_other_error(self): - database = _Database() - database.spanner_api = self._make_spanner_api() - database.spanner_api.execute_batch_dml.side_effect = RuntimeError() - session = _Session(database) - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID - - with self.assertRaises(RuntimeError): - transaction.batch_update(statements=[DML_QUERY]) + def test_execute_update_w_precommit_token(self): + self._execute_update_helper(use_multiplexed=True) def _batch_update_helper( self, @@ -766,17 +806,34 @@ def _batch_update_helper( request_options=None, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT, + begin=True, + use_multiplexed=False, ): from google.rpc.status_pb2 import Status from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1 import param_types - from google.cloud.spanner_v1 import ResultSet - from google.cloud.spanner_v1 import ResultSetStats from google.cloud.spanner_v1 import ExecuteBatchDmlRequest from google.cloud.spanner_v1 import ExecuteBatchDmlResponse from google.cloud.spanner_v1 import TransactionSelector from google.cloud.spanner_v1._helpers import _make_value_pb + # [A] Build transaction + # --------------------- + + transaction = build_transaction() + session = transaction._session + database = session._database + api = database.spanner_api + + transaction.transaction_tag = TRANSACTION_TAG + transaction._execute_sql_request_count = count + + if begin: + transaction.begin() + + # [B] Build results + # ----------------- + insert_dml = "INSERT INTO table(pkey, desc) VALUES (%pkey, %desc)" insert_params = {"pkey": 12345, "desc": "DESCRIPTION"} insert_param_types = {"pkey": param_types.INT64, "desc": param_types.STRING} @@ -789,11 +846,15 @@ def _batch_update_helper( delete_dml, ] + # These precommit tokens are intentionally returned with sequence numbers out of order. + precommit_tokens = [PRECOMMIT_TOKEN_2, PRECOMMIT_TOKEN_0, PRECOMMIT_TOKEN_1] + stats_pbs = [ ResultSetStats(row_count_exact=1), ResultSetStats(row_count_exact=2), ResultSetStats(row_count_exact=3), ] + if error_after is not None: stats_pbs = stats_pbs[:error_after] expected_status = Status(code=400) @@ -801,24 +862,38 @@ def _batch_update_helper( expected_status = Status(code=200) expected_row_counts = [stats.row_count_exact for stats in stats_pbs] - response = ExecuteBatchDmlResponse( + result_sets = [] + for i in range(len(stats_pbs)): + result_set_args = {"stats": stats_pbs[i]} + + # If the transaction had not already begun, the first result + # set will include metadata with information about the transaction. + if not begin and i == 0: + result_set_args["metadata"] = build_result_set_metadata_pb( + transaction={"id": TRANSACTION_ID} + ) + + # Precommit tokens will be included in the result + # sets if the transaction is on a multiplexed session. + if use_multiplexed: + result_set_args["precommit_token"] = precommit_tokens[i] + + result_sets.append(build_result_set_pb(**result_set_args)) + + api.execute_batch_dml.return_value = ExecuteBatchDmlResponse( status=expected_status, - result_sets=[ResultSet(stats=stats_pb) for stats_pb in stats_pbs], + result_sets=result_sets, ) - database = _Database() - api = database.spanner_api = self._make_spanner_api() - api.execute_batch_dml.return_value = response - session = _Session(database) - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID - transaction.transaction_tag = self.TRANSACTION_TAG - transaction._execute_sql_count = count + + # [C] Execute batch DML + # --------------------- if request_options is None: request_options = RequestOptions() elif type(request_options) is dict: request_options = RequestOptions(request_options) + self.reset() status, row_counts = transaction.batch_update( dml_statements, request_options=request_options, @@ -826,10 +901,20 @@ def _batch_update_helper( timeout=timeout, ) + # [D] Verify results + # ------------------ + self.assertEqual(status, expected_status) self.assertEqual(row_counts, expected_row_counts) - expected_transaction = TransactionSelector(id=self.TRANSACTION_ID) + expected_transaction_selector_pb = ( + TransactionSelector(id=transaction._transaction_id) + if begin + else TransactionSelector( + begin=TransactionOptions(read_write=TransactionOptions.ReadWrite()) + ) + ) + expected_insert_params = Struct( fields={ key: _make_value_pb(value) for (key, value) in insert_params.items() @@ -845,11 +930,11 @@ def _batch_update_helper( ExecuteBatchDmlRequest.Statement(sql=delete_dml), ] expected_request_options = request_options - expected_request_options.transaction_tag = self.TRANSACTION_TAG + expected_request_options.transaction_tag = TRANSACTION_TAG expected_request = ExecuteBatchDmlRequest( - session=self.SESSION_NAME, - transaction=expected_transaction, + session=session.name, + transaction=expected_transaction_selector_pb, statements=expected_statements, seqno=count, request_options=expected_request_options, @@ -864,7 +949,16 @@ def _batch_update_helper( timeout=timeout, ) - self.assertEqual(transaction._execute_sql_count, count + 1) + self.assertEqual(transaction._execute_sql_request_count, count + 1) + + if not begin: + self.assertEqual(transaction._transaction_id, TRANSACTION_ID) + + if use_multiplexed: + self.assertEqual(transaction._precommit_token, PRECOMMIT_TOKEN_2) + + def test_batch_update_wo_begin(self): + self._batch_update_helper(begin=False) def test_batch_update_wo_errors(self): self._batch_update_helper( @@ -908,12 +1002,11 @@ def test_batch_update_error(self): from google.cloud.spanner_v1 import Type from google.cloud.spanner_v1 import TypeCode - database = _Database() - api = database.spanner_api = self._make_spanner_api() + transaction = build_transaction() + transaction.begin() + + api = transaction._session._database.spanner_api api.execute_batch_dml.side_effect = RuntimeError() - session = _Session(database) - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID insert_dml = "INSERT INTO table(pkey, desc) VALUES (%pkey, %desc)" insert_params = {"pkey": 12345, "desc": "DESCRIPTION"} @@ -933,7 +1026,7 @@ def test_batch_update_error(self): with self.assertRaises(RuntimeError): transaction.batch_update(dml_statements) - self.assertEqual(transaction._execute_sql_count, 1) + self.assertEqual(transaction._execute_sql_request_count, 1) def test_batch_update_w_timeout_param(self): self._batch_update_helper(timeout=2.0) @@ -944,52 +1037,46 @@ def test_batch_update_w_retry_param(self): def test_batch_update_w_timeout_and_retry_params(self): self._batch_update_helper(retry=gapic_v1.method.DEFAULT, timeout=2.0) + def test_batch_update_w_precommit_token(self): + self._batch_update_helper(use_multiplexed=True) + def test_context_mgr_success(self): - import datetime - from google.cloud.spanner_v1 import CommitResponse - from google.cloud.spanner_v1 import Transaction as TransactionPB - from google.cloud._helpers import UTC + transaction = build_transaction() + session = transaction._session + database = session._database + api = database.spanner_api - transaction_pb = TransactionPB(id=self.TRANSACTION_ID) - now = datetime.datetime.utcnow().replace(tzinfo=UTC) - response = CommitResponse(commit_timestamp=now) - database = _Database() - api = database.spanner_api = _FauxSpannerAPI( - _begin_transaction_response=transaction_pb, _commit_response=response - ) - session = _Session(database) - transaction = self._make_one(session) + begin = api.begin_transaction + transaction_id = TRANSACTION_ID + begin.return_value = build_transaction_pb(id=transaction_id) + + commit = api.commit + now = datetime.datetime.now(tz=datetime.timezone.utc) + commit.return_value = CommitResponse(commit_timestamp=now) with transaction: transaction.insert(TABLE_NAME, COLUMNS, VALUES) self.assertEqual(transaction.committed, now) - session_id, mutations, txn_id, _, _, metadata = api._committed - self.assertEqual(session_id, self.SESSION_NAME) - self.assertEqual(txn_id, self.TRANSACTION_ID) - self.assertEqual(mutations, transaction._mutations) - self.assertEqual( - metadata, - [ - ("google-cloud-resource-prefix", database.name), - ("x-goog-spanner-route-to-leader", "true"), - ], + expected_request = CommitRequest( + session=session.name, + transaction_id=transaction_id, + mutations=transaction._mutations, + request_options=RequestOptions(), ) + expected_metadata = [ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ] - def test_context_mgr_failure(self): - from google.protobuf.empty_pb2 import Empty - - empty_pb = Empty() - from google.cloud.spanner_v1 import Transaction as TransactionPB - - transaction_pb = TransactionPB(id=self.TRANSACTION_ID) - database = _Database() - api = database.spanner_api = _FauxSpannerAPI( - _begin_transaction_response=transaction_pb, _rollback_response=empty_pb + commit.assert_called_once_with( + request=expected_request, + metadata=expected_metadata, ) - session = _Session(database) - transaction = self._make_one(session) + + def test_context_mgr_failure(self): + transaction = build_transaction() with self.assertRaises(Exception): with transaction: @@ -1000,74 +1087,26 @@ def test_context_mgr_failure(self): # Rollback rpc will not be called as there is no transaction id to be rolled back, rolled_back flag will be marked as true. self.assertTrue(transaction.rolled_back) self.assertEqual(len(transaction._mutations), 1) - self.assertEqual(api._committed, None) - - -class _Client(object): - def __init__(self): - from google.cloud.spanner_v1 import ExecuteSqlRequest - - self._query_options = ExecuteSqlRequest.QueryOptions(optimizer_version="1") - self.directed_read_options = None - - -class _Instance(object): - def __init__(self): - self._client = _Client() - - -class _Database(object): - def __init__(self): - self.name = "testing" - self._instance = _Instance() - self._route_to_leader_enabled = True - self._directed_read_options = None - self.default_transaction_options = DefaultTransactionOptions() - - -class _Session(object): - _transaction = None - def __init__(self, database=None, name=TestTransaction.SESSION_NAME): - self._database = database - self.name = name + api = transaction._session._database.spanner_api + api.commit.assert_not_called() - @property - def session_id(self): - return self.name +def _build_base_attributes(database) -> dict: + """Builds and returns the base attributes for the given database.""" -class _FauxSpannerAPI(object): - _committed = None - - def __init__(self, **kwargs): - self.__dict__.update(**kwargs) + base_attributes = { + "db.type": "spanner", + "db.url": "spanner.googleapis.com", + "db.instance": database.name, + "net.host.name": "spanner.googleapis.com", + "gcp.client.service": "spanner", + "gcp.client.version": LIB_VERSION, + "gcp.client.repo": "googleapis/python-spanner", + } - def begin_transaction(self, session=None, options=None, metadata=None): - self._begun = (session, options, metadata) - return self._begin_transaction_response + from tests._helpers import enrich_with_otel_scope - def rollback(self, session=None, transaction_id=None, metadata=None): - self._rolled_back = (session, transaction_id, metadata) - return self._rollback_response + enrich_with_otel_scope(base_attributes) - def commit( - self, - request=None, - metadata=None, - ): - assert not request.single_use_transaction - - max_commit_delay = None - if type(request).pb(request).HasField("max_commit_delay"): - max_commit_delay = request.max_commit_delay - - self._committed = ( - request.session, - request.mutations, - request.transaction_id, - request.request_options, - max_commit_delay, - metadata, - ) - return self._commit_response + return base_attributes