Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit 529333a

Browse files
committed
More plumbing for Transaction and Database
1 parent 4a37f4c commit 529333a

File tree

4 files changed

+221
-38
lines changed

4 files changed

+221
-38
lines changed

google/cloud/spanner_v1/database.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -728,17 +728,20 @@ def execute_partitioned_dml(
728728
_metadata_with_leader_aware_routing(self._route_to_leader_enabled)
729729
)
730730

731-
nth_request = getattr(self, "_next_nth_request", 0)
732731
# Attempt will be incremented inside _restart_on_unavailable.
733-
attempt = AtomicCounter(1)
732+
begin_txn_nth_request = self._next_nth_request
733+
begin_txn_attempt = AtomicCounter(1)
734+
partial_nth_request = self._next_nth_request
735+
partial_attempt = AtomicCounter(0)
734736

735737
def execute_pdml():
736738
with SessionCheckout(self._pool) as session:
737-
all_metadata = self.metadata_with_request_id(
738-
nth_request, attempt.value, metadata
739-
)
740739
txn = api.begin_transaction(
741-
session=session.name, options=txn_options, metadata=all_metadata
740+
session=session.name,
741+
options=txn_options,
742+
metadata=self.metadata_with_request_id(
743+
begin_txn_nth_request, begin_txn_attempt.value, metadata
744+
),
742745
)
743746

744747
txn_selector = TransactionSelector(id=txn.id)
@@ -751,18 +754,24 @@ def execute_pdml():
751754
query_options=query_options,
752755
request_options=request_options,
753756
)
754-
method = functools.partial(
755-
api.execute_streaming_sql,
756-
metadata=metadata,
757-
)
757+
758+
def wrapped_method(*args, **kwargs):
759+
partial_attempt.increment()
760+
method = functools.partial(
761+
api.execute_streaming_sql,
762+
metadata=self.metadata_with_request_id(
763+
partial_nth_request, partial_attempt.value, metadata
764+
),
765+
)
766+
return method(*args, **kwargs)
758767

759768
iterator = _restart_on_unavailable(
760-
method=method,
769+
method=wrapped_method,
761770
trace_name="CloudSpanner.ExecuteStreamingSql",
762771
request=request,
763772
transaction_selector=txn_selector,
764773
observability_options=self.observability_options,
765-
attempt=attempt,
774+
attempt=begin_txn_attempt,
766775
)
767776

768777
result_set = StreamedResultSet(iterator)

google/cloud/spanner_v1/request_id_header.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,6 @@ def generate_rand_uint64():
3838

3939
def with_request_id(client_id, channel_id, nth_request, attempt, other_metadata=[]):
4040
req_id = f"{REQ_ID_VERSION}.{REQ_RAND_PROCESS_ID}.{client_id}.{channel_id}.{nth_request}.{attempt}"
41-
other_metadata.append((REQ_ID_HEADER_KEY, req_id))
42-
return other_metadata
41+
all_metadata = other_metadata.copy()
42+
all_metadata.append((REQ_ID_HEADER_KEY, req_id))
43+
return all_metadata

google/cloud/spanner_v1/transaction.py

Lines changed: 48 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from google.cloud.spanner_v1 import ExecuteSqlRequest
3131
from google.cloud.spanner_v1 import TransactionSelector
3232
from google.cloud.spanner_v1 import TransactionOptions
33+
from google.cloud.spanner_v1._helpers import AtomicCounter
3334
from google.cloud.spanner_v1.snapshot import _SnapshotBase
3435
from google.cloud.spanner_v1.batch import _BatchBase
3536
from google.cloud.spanner_v1._opentelemetry_tracing import add_span_event, trace_call
@@ -197,21 +198,29 @@ def rollback(self):
197198
database._route_to_leader_enabled
198199
)
199200
)
200-
all_metadata = database.metadata_with_request_id(database._next_nth_request, 1, metadata)
201+
201202
observability_options = getattr(database, "observability_options", None)
202203
with trace_call(
203204
f"CloudSpanner.{type(self).__name__}.rollback",
204205
self._session,
205206
observability_options=observability_options,
206207
):
207-
method = functools.partial(
208-
api.rollback,
209-
session=self._session.name,
210-
transaction_id=self._transaction_id,
211-
metadata=all_metadata,
212-
)
208+
attempt = AtomicCounter(0)
209+
nth_request = database._next_nth_request
210+
211+
def wrapped_method(*args, **kwargs):
212+
attempt.increment()
213+
method = functools.partial(
214+
api.rollback,
215+
session=self._session.name,
216+
transaction_id=self._transaction_id,
217+
metadata=database.metadata_with_request_id(
218+
nth_request, attempt.value, metadata
219+
),
220+
)
221+
213222
_retry(
214-
method,
223+
wrapped_method,
215224
allowed_exceptions={InternalServerError: _check_rst_stream_error},
216225
)
217226
self.rolled_back = True
@@ -286,11 +295,19 @@ def commit(
286295
) as span:
287296
add_span_event(span, "Starting Commit")
288297

289-
method = functools.partial(
290-
api.commit,
291-
request=request,
292-
metadata=database.metadata_with_request_id(database._next_nth_request, 1, metadata),
293-
)
298+
attempt = AtomicCounter(0)
299+
nth_request = database._next_nth_request
300+
301+
def wrapped_method(*args, **kwargs):
302+
attempt.increment()
303+
method = functools.partial(
304+
api.commit,
305+
request=request,
306+
metadata=database.metadata_with_request_id(
307+
nth_request, attempt.value, metadata
308+
),
309+
)
310+
return method(*args, **kwargs)
294311

295312
def beforeNextRetry(nthRetry, delayInSeconds):
296313
add_span_event(
@@ -300,7 +317,7 @@ def beforeNextRetry(nthRetry, delayInSeconds):
300317
)
301318

302319
response = _retry(
303-
method,
320+
wrapped_method,
304321
allowed_exceptions={InternalServerError: _check_rst_stream_error},
305322
beforeNextRetry=beforeNextRetry,
306323
)
@@ -434,19 +451,27 @@ def execute_update(
434451
request_options=request_options,
435452
)
436453

437-
method = functools.partial(
438-
api.execute_sql,
439-
request=request,
440-
metadata=metadata,
441-
retry=retry,
442-
timeout=timeout,
443-
)
454+
nth_request = database._next_nth_request
455+
attempt = AtomicCounter(0)
456+
457+
def wrapped_method(*args, **kwargs):
458+
attempt.increment()
459+
method = functools.partial(
460+
api.execute_sql,
461+
request=request,
462+
metadata=database.metadata_with_request_id(
463+
nth_request, attempt.value, metadata
464+
),
465+
retry=retry,
466+
timeout=timeout,
467+
)
468+
return method(*args, **kwargs)
444469

445470
if self._transaction_id is None:
446471
# lock is added to handle the inline begin for first rpc
447472
with self._lock:
448473
response = self._execute_request(
449-
method,
474+
wrapped_method,
450475
request,
451476
f"CloudSpanner.{type(self).__name__}.execute_update",
452477
self._session,
@@ -463,7 +488,7 @@ def execute_update(
463488
self._transaction_id = response.metadata.transaction.id
464489
else:
465490
response = self._execute_request(
466-
method,
491+
wrapped_method,
467492
request,
468493
f"CloudSpanner.{type(self).__name__}.execute_update",
469494
self._session,

tests/unit/test_request_id_header.py

Lines changed: 149 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from google.cloud.spanner_v1.testing.interceptors import XGoogRequestIDHeaderInterceptor
2323
from google.cloud.spanner_v1 import (
2424
BatchCreateSessionsRequest,
25+
BeginTransactionRequest,
2526
ExecuteSqlRequest,
2627
)
2728
from google.api_core.exceptions import Aborted
@@ -195,7 +196,7 @@ def select1():
195196
]
196197
assert got_stream_segments == want_stream_segments
197198

198-
def test_retries_on_abort(self):
199+
def test_database_run_in_transaction_retries_on_abort(self):
199200
counters = dict(aborted=0)
200201
want_failed_attempts = 2
201202

@@ -217,10 +218,157 @@ def select_in_txn(txn):
217218

218219
self.database.run_in_transaction(select_in_txn)
219220

221+
def test_database_execute_partitioned_dml_request_id(self):
222+
add_select1_result()
223+
if not getattr(self.database, "_interceptors", None):
224+
self.database._interceptors = MockServerTestBase._interceptors
225+
_ = self.database.execute_partitioned_dml("select 1")
226+
227+
requests = self.spanner_service.requests
228+
self.assertEqual(3, len(requests), msg=requests)
229+
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
230+
self.assertTrue(isinstance(requests[1], BeginTransactionRequest))
231+
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
232+
233+
# Now ensure monotonicity of the received request-id segments.
234+
got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers()
235+
want_unary_segments = [
236+
(
237+
"/google.spanner.v1.Spanner/BatchCreateSessions",
238+
(1, REQ_RAND_PROCESS_ID, 1, 1, 1, 1),
239+
),
240+
(
241+
"/google.spanner.v1.Spanner/BeginTransaction",
242+
(1, REQ_RAND_PROCESS_ID, 1, 1, 2, 1),
243+
),
244+
]
245+
want_stream_segments = [
246+
(
247+
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
248+
(1, REQ_RAND_PROCESS_ID, 1, 1, 3, 1),
249+
)
250+
]
251+
252+
assert got_unary_segments == want_unary_segments
253+
assert got_stream_segments == want_stream_segments
254+
255+
def test_snapshot_read(self):
256+
add_select1_result()
257+
if not getattr(self.database, "_interceptors", None):
258+
self.database._interceptors = MockServerTestBase._interceptors
259+
with self.database.snapshot() as snapshot:
260+
results = snapshot.read("select 1")
261+
result_list = []
262+
for row in results:
263+
result_list.append(row)
264+
self.assertEqual(1, row[0])
265+
self.assertEqual(1, len(result_list))
266+
267+
requests = self.spanner_service.requests
268+
self.assertEqual(2, len(requests), msg=requests)
269+
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
270+
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
271+
272+
requests = self.spanner_service.requests
273+
self.assertEqual(n * 2, len(requests), msg=requests)
274+
275+
client_id = self.database._nth_client_id
276+
channel_id = self.database._channel_id
277+
got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers()
278+
279+
want_unary_segments = [
280+
(
281+
"/google.spanner.v1.Spanner/BatchCreateSessions",
282+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 1, 1),
283+
),
284+
(
285+
"/google.spanner.v1.Spanner/GetSession",
286+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 3, 1),
287+
),
288+
(
289+
"/google.spanner.v1.Spanner/GetSession",
290+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 5, 1),
291+
),
292+
(
293+
"/google.spanner.v1.Spanner/GetSession",
294+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 7, 1),
295+
),
296+
(
297+
"/google.spanner.v1.Spanner/GetSession",
298+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 9, 1),
299+
),
300+
(
301+
"/google.spanner.v1.Spanner/GetSession",
302+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 11, 1),
303+
),
304+
(
305+
"/google.spanner.v1.Spanner/GetSession",
306+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 13, 1),
307+
),
308+
(
309+
"/google.spanner.v1.Spanner/GetSession",
310+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 15, 1),
311+
),
312+
(
313+
"/google.spanner.v1.Spanner/GetSession",
314+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 17, 1),
315+
),
316+
(
317+
"/google.spanner.v1.Spanner/GetSession",
318+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 19, 1),
319+
),
320+
]
321+
assert got_unary_segments == want_unary_segments
322+
323+
want_stream_segments = [
324+
(
325+
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
326+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 2, 1),
327+
),
328+
(
329+
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
330+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 4, 1),
331+
),
332+
(
333+
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
334+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 6, 1),
335+
),
336+
(
337+
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
338+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 8, 1),
339+
),
340+
(
341+
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
342+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 10, 1),
343+
),
344+
(
345+
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
346+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 12, 1),
347+
),
348+
(
349+
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
350+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 14, 1),
351+
),
352+
(
353+
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
354+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 16, 1),
355+
),
356+
(
357+
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
358+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 18, 1),
359+
),
360+
(
361+
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
362+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 20, 1),
363+
),
364+
]
365+
assert got_stream_segments == want_stream_segments
366+
220367
def canonicalize_request_id_headers(self):
221368
src = self.database._x_goog_request_id_interceptor
222369
return src._stream_req_segments, src._unary_req_segments
223370

371+
224372
class FauxCall:
225373
def __init__(self, code, details="FauxCall"):
226374
self._code = code

0 commit comments

Comments
 (0)