Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit 64bd582

Browse files
committed
chore(x-goog-spanner-request-id): more updates for batch_write + mockserver tests
This change plumbs in some x-goog-spanner-request-id updates for batch_write and some tests too. Updates #1261
1 parent 3a91671 commit 64bd582

File tree

10 files changed

+545
-66
lines changed

10 files changed

+545
-66
lines changed

google/cloud/spanner_v1/_helpers.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,20 @@ def _check_rst_stream_error(exc):
587587
raise
588588

589589

590+
def _check_unavailable(exc):
591+
resumable_error = (
592+
any(
593+
resumable_message in exc.message
594+
for resumable_message in (
595+
"INTERNAL",
596+
"Service unavailable",
597+
)
598+
),
599+
)
600+
if not resumable_error:
601+
raise
602+
603+
590604
def _metadata_with_leader_aware_routing(value, **kw):
591605
"""Create RPC metadata containing a leader aware routing header
592606

google/cloud/spanner_v1/batch.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
_metadata_with_prefix,
2727
_metadata_with_leader_aware_routing,
2828
_merge_Transaction_Options,
29+
AtomicCounter,
2930
)
3031
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
3132
from google.cloud.spanner_v1 import RequestOptions
@@ -385,13 +386,22 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
385386
observability_options=observability_options,
386387
metadata=metadata,
387388
), MetricsCapture():
388-
method = functools.partial(
389-
api.batch_write,
390-
request=request,
391-
metadata=metadata,
392-
)
389+
attempt = AtomicCounter(0)
390+
nth_request = database._next_nth_request
391+
392+
def wrapped_method(*args, **kwargs):
393+
return functools.partial(
394+
api.batch_write,
395+
request=request,
396+
metadata=database.metadata_with_request_id(
397+
nth_request,
398+
attempt.increment(),
399+
metadata,
400+
),
401+
)(*args, **kwargs)
402+
393403
response = _retry(
394-
method,
404+
wrapped_method,
395405
allowed_exceptions={
396406
InternalServerError: _check_rst_stream_error,
397407
},

google/cloud/spanner_v1/pool.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,22 @@
1515
"""Pools managing shared Session objects."""
1616

1717
import datetime
18+
import functools
1819
import queue
1920
import time
2021

22+
from google.api_core.exceptions import InternalServerError
23+
from google.api_core.exceptions import ServiceUnavailable
2124
from google.cloud.exceptions import NotFound
2225
from google.cloud.spanner_v1 import BatchCreateSessionsRequest
2326
from google.cloud.spanner_v1 import Session
2427
from google.cloud.spanner_v1._helpers import (
28+
_check_rst_stream_error,
29+
_check_unavailable,
2530
_metadata_with_prefix,
2631
_metadata_with_leader_aware_routing,
32+
_retry,
33+
AtomicCounter,
2734
)
2835
from google.cloud.spanner_v1._opentelemetry_tracing import (
2936
add_span_event,
@@ -254,11 +261,25 @@ def bind(self, database):
254261
f"Creating {request.session_count} sessions",
255262
span_event_attributes,
256263
)
257-
resp = api.batch_create_sessions(
258-
request=request,
259-
metadata=database.metadata_with_request_id(
260-
database._next_nth_request, 1, metadata
261-
),
264+
attempt = AtomicCounter(0)
265+
nth_request = database._next_nth_request
266+
267+
def wrapped_method(*args, **kwargs):
268+
method = functools.partial(
269+
api.batch_create_sessions,
270+
request=request,
271+
metadata=database.metadata_with_request_id(
272+
nth_request, attempt.increment(), metadata
273+
),
274+
)
275+
return method(*args, **kwargs)
276+
277+
resp = _retry(
278+
wrapped_method,
279+
allowed_exceptions={
280+
InternalServerError: _check_rst_stream_error,
281+
ServiceUnavailable: _check_unavailable,
282+
},
262283
)
263284

264285
add_span_event(
@@ -561,11 +582,23 @@ def bind(self, database):
561582
) as span, MetricsCapture():
562583
returned_session_count = 0
563584
while returned_session_count < self.size:
564-
resp = api.batch_create_sessions(
565-
request=request,
566-
metadata=database.metadata_with_request_id(
567-
database._next_nth_request, 1, metadata
568-
),
585+
attempt = AtomicCounter(0)
586+
nth_request = database._next_nth_request
587+
588+
def wrapped_method(*args, **kwargs):
589+
return api.batch_create_sessions(
590+
request=request,
591+
metadata=database.metadata_with_request_id(
592+
database._next_nth_request, attempt.increment(), metadata
593+
),
594+
)
595+
596+
resp = _retry(
597+
wrapped_method,
598+
allowed_exceptions={
599+
InternalServerError: _check_rst_stream_error,
600+
ServiceUnavailable: _check_unavailable,
601+
},
569602
)
570603

571604
add_span_event(

google/cloud/spanner_v1/testing/interceptors.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,6 @@ def reset(self):
7171

7272

7373
class XGoogRequestIDHeaderInterceptor(ClientInterceptor):
74-
# TODO:(@odeke-em): delete this guard when PR #1367 is merged.
75-
X_GOOG_REQUEST_ID_FUNCTIONALITY_MERGED = True
76-
7774
def __init__(self):
7875
self._unary_req_segments = []
7976
self._stream_req_segments = []
@@ -87,24 +84,23 @@ def intercept(self, method, request_or_iterator, call_details):
8784
x_goog_request_id = value
8885
break
8986

90-
if self.X_GOOG_REQUEST_ID_FUNCTIONALITY_MERGED and not x_goog_request_id:
87+
if not x_goog_request_id:
9188
raise Exception(
9289
f"Missing {X_GOOG_REQUEST_ID} header in {call_details.method}"
9390
)
9491

9592
response_or_iterator = method(request_or_iterator, call_details)
9693
streaming = getattr(response_or_iterator, "__iter__", None) is not None
9794

98-
if self.X_GOOG_REQUEST_ID_FUNCTIONALITY_MERGED:
99-
with self.__lock:
100-
if streaming:
101-
self._stream_req_segments.append(
102-
(call_details.method, parse_request_id(x_goog_request_id))
103-
)
104-
else:
105-
self._unary_req_segments.append(
106-
(call_details.method, parse_request_id(x_goog_request_id))
107-
)
95+
with self.__lock:
96+
if streaming:
97+
self._stream_req_segments.append(
98+
(call_details.method, parse_request_id(x_goog_request_id))
99+
)
100+
else:
101+
self._unary_req_segments.append(
102+
(call_details.method, parse_request_id(x_goog_request_id))
103+
)
108104

109105
return response_or_iterator
110106

tests/mockserver_tests/mock_server_test_base.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,27 @@ def aborted_status() -> _Status:
6161
def unavailable_status() -> _Status:
6262
error = status_pb2.Status(
6363
code=code_pb2.UNAVAILABLE,
64+
message="Received unexpected EOS on DATA frame from server",
65+
)
66+
retry_info = RetryInfo(retry_delay=Duration(seconds=0, nanos=1))
67+
status = _Status(
68+
code=code_to_grpc_status_code(error.code),
69+
details=error.message,
70+
trailing_metadata=(
71+
("grpc-status-details-bin", error.SerializeToString()),
72+
(
73+
"google.rpc.retryinfo-bin",
74+
retry_info.SerializeToString(),
75+
),
76+
),
77+
)
78+
return status
79+
80+
81+
# Creates an INTERNAL status with the smallest possible retry delay.
82+
def internal_status() -> _Status:
83+
error = status_pb2.Status(
84+
code=code_pb2.INTERNAL,
6485
message="Service unavailable.",
6586
)
6687
retry_info = RetryInfo(retry_delay=Duration(seconds=0, nanos=1))

0 commit comments

Comments
 (0)