Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit 0495082

Browse files
committed
Implement interceptor to wrap and increase x-goog-spanner-request-id attempts per retry
This monkey patches SpannerClient methods to have an interceptor that increases the attempts per retry. The prelude though is that any callers to it must pass in the attempt value 0 so that each pass through will correctly increase the attempt field's value.
1 parent 9ce98c3 commit 0495082

File tree

6 files changed

+159
-147
lines changed

6 files changed

+159
-147
lines changed

google/cloud/spanner_v1/_helpers.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from google.cloud.spanner_v1 import ExecuteSqlRequest
3434
from google.cloud.spanner_v1 import JsonObject
3535
from google.cloud.spanner_v1 import TransactionOptions
36-
from google.cloud.spanner_v1.request_id_header import with_request_id
36+
from google.cloud.spanner_v1.request_id_header import REQ_ID_HEADER_KEY, with_request_id
3737
from google.rpc.error_details_pb2 import RetryInfo
3838

3939
try:
@@ -45,6 +45,7 @@
4545
HAS_OPENTELEMETRY_INSTALLED = False
4646
from typing import List, Tuple
4747
import random
48+
from typing import Callable
4849

4950
# Validation error messages
5051
NUMERIC_MAX_SCALE_ERR_MSG = (
@@ -730,3 +731,33 @@ def _merge_Transaction_Options(
730731

731732
# Convert protobuf object back into a TransactionOptions instance
732733
return TransactionOptions(merged_pb)
734+
735+
736+
class InterceptingHeaderInjector:
737+
def __init__(self, original_callable: Callable):
738+
self._original_callable = original_callable
739+
740+
def __call__(self, *args, **kwargs):
741+
metadata = kwargs.get("metadata", [])
742+
# Find all the headers that match the x-goog-spanner-request-id
743+
# header an on each retry increment the value.
744+
all_metadata = []
745+
for key, value in metadata:
746+
if key is REQ_ID_HEADER_KEY:
747+
# Otherwise now increment the count for the attempt number.
748+
splits = value.split(".")
749+
attempt_plus_one = int(splits[-1]) + 1
750+
splits[-1] = str(attempt_plus_one)
751+
value_before = value
752+
value = ".".join(splits)
753+
print("incrementing value on retry from", value_before, "to", value)
754+
755+
all_metadata.append(
756+
(
757+
key,
758+
value,
759+
)
760+
)
761+
762+
kwargs["metadata"] = all_metadata
763+
return self._original_callable(*args, **kwargs)

google/cloud/spanner_v1/batch.py

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -253,18 +253,16 @@ def commit(
253253
attempt = AtomicCounter(0)
254254
next_nth_request = database._next_nth_request
255255

256-
def wrapped_method(*args, **kwargs):
257-
all_metadata = database.metadata_with_request_id(
258-
next_nth_request,
259-
attempt.increment(),
260-
metadata,
261-
)
262-
method = functools.partial(
263-
api.commit,
264-
request=request,
265-
metadata=all_metadata,
266-
)
267-
return method(*args, **kwargs)
256+
all_metadata = database.metadata_with_request_id(
257+
next_nth_request,
258+
attempt.increment(),
259+
metadata,
260+
)
261+
method = functools.partial(
262+
api.commit,
263+
request=request,
264+
metadata=all_metadata,
265+
)
268266

269267
deadline = time.time() + kwargs.get(
270268
"timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS
@@ -384,21 +382,17 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
384382
observability_options=observability_options,
385383
metadata=metadata,
386384
), MetricsCapture():
387-
attempt = AtomicCounter(0)
388385
next_nth_request = database._next_nth_request
389-
390-
def wrapped_method(*args, **kwargs):
391-
all_metadata = database.metadata_with_request_id(
392-
next_nth_request,
393-
attempt.increment(),
394-
metadata,
395-
)
396-
method = functools.partial(
397-
api.batch_write,
398-
request=request,
399-
metadata=all_metadata,
400-
)
401-
return method(*args, **kwargs)
386+
all_metadata = database.metadata_with_request_id(
387+
next_nth_request,
388+
0,
389+
metadata,
390+
)
391+
method = functools.partial(
392+
api.batch_write,
393+
request=request,
394+
metadata=all_metadata,
395+
)
402396

403397
response = _retry(
404398
method,

google/cloud/spanner_v1/database.py

Lines changed: 48 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
_metadata_with_prefix,
5656
_metadata_with_leader_aware_routing,
5757
_metadata_with_request_id,
58+
InterceptingHeaderInjector,
5859
)
5960
from google.cloud.spanner_v1.batch import Batch
6061
from google.cloud.spanner_v1.batch import MutationGroups
@@ -432,6 +433,43 @@ def logger(self):
432433

433434
@property
434435
def spanner_api(self):
436+
"""Helper for session-related API calls."""
437+
api = self._generate_spanner_api()
438+
if not api:
439+
return api
440+
441+
# Now wrap each method's __call__ method with our wrapped one.
442+
# This is how to deal with the fact that there are no proper gRPC
443+
# interceptors for Python hence the remedy is to replace callables
444+
# with our custom wrapper.
445+
attrs = dir(api)
446+
for attr_name in attrs:
447+
mangled = attr_name.startswith("__")
448+
if mangled:
449+
continue
450+
451+
non_public = attr_name.startswith("_")
452+
if non_public:
453+
continue
454+
455+
attr = getattr(api, attr_name)
456+
callable_attr = callable(attr)
457+
if callable_attr is None:
458+
continue
459+
460+
# We should only be looking at bound methods to SpannerClient
461+
# as those are the RPC invoking methods that need to be wrapped
462+
463+
is_method = type(attr).__name__ == "method"
464+
if not is_method:
465+
continue
466+
467+
print("attr_name", attr_name, "callable_attr", attr)
468+
setattr(api, attr_name, InterceptingHeaderInjector(attr))
469+
470+
return api
471+
472+
def _generate_spanner_api(self):
435473
"""Helper for session-related API calls."""
436474
if self._spanner_api is None:
437475
client_info = self._instance._client._client_info
@@ -762,11 +800,11 @@ def execute_pdml():
762800
) as span, MetricsCapture():
763801
with SessionCheckout(self._pool) as session:
764802
add_span_event(span, "Starting BeginTransaction")
765-
begin_txn_attempt.increment()
766803
txn = api.begin_transaction(
767-
session=session.name, options=txn_options,
804+
session=session.name,
805+
options=txn_options,
768806
metadata=self.metadata_with_request_id(
769-
begin_txn_nth_request, begin_txn_attempt.value, metadata
807+
begin_txn_nth_request, begin_txn_attempt.increment(), metadata
770808
),
771809
)
772810

@@ -781,37 +819,21 @@ def execute_pdml():
781819
request_options=request_options,
782820
)
783821

784-
def wrapped_method(*args, **kwargs):
785-
partial_attempt.increment()
786-
method = functools.partial(
787-
api.execute_streaming_sql,
788-
metadata=self.metadata_with_request_id(
789-
partial_nth_request, partial_attempt.value, metadata
790-
),
791-
)
792-
return method(*args, **kwargs)
822+
method = functools.partial(
823+
api.execute_streaming_sql,
824+
metadata=self.metadata_with_request_id(
825+
partial_nth_request, partial_attempt.increment(), metadata
826+
),
827+
)
793828

794829
iterator = _restart_on_unavailable(
795-
method=wrapped_method,
830+
method=method,
796831
trace_name="CloudSpanner.ExecuteStreamingSql",
797832
request=request,
798833
metadata=metadata,
799834
transaction_selector=txn_selector,
800835
observability_options=self.observability_options,
801-
attempt=begin_txn_attempt,
802836
)
803-
<<<<<<< HEAD
804-
=======
805-
return method(*args, **kwargs)
806-
807-
iterator = _restart_on_unavailable(
808-
method=wrapped_method,
809-
trace_name="CloudSpanner.ExecuteStreamingSql",
810-
request=request,
811-
transaction_selector=txn_selector,
812-
observability_options=self.observability_options,
813-
)
814-
>>>>>>> 54df502... Update tests
815837

816838
result_set = StreamedResultSet(iterator)
817839
list(result_set) # consume all partials

google/cloud/spanner_v1/pool.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -266,11 +266,6 @@ def create_sessions(attempt):
266266
return api.batch_create_sessions(
267267
request=request,
268268
metadata=all_metadata,
269-
# Manually passing retry=None because otherwise any
270-
# UNAVAILABLE retry will be retried without replenishing
271-
# the metadata, hence this allows us to manually update
272-
# the metadata using retry_on_unavailable.
273-
retry=None,
274269
)
275270

276271
resp = retry_on_unavailable(create_sessions)
@@ -584,13 +579,6 @@ def create_sessions(attempt):
584579
return api.batch_create_sessions(
585580
request=request,
586581
metadata=all_metadata,
587-
# Manually passing retry=None because otherwise any
588-
# UNAVAILABLE retry will be retried without replenishing
589-
# the metadata, hence this allows us to manually update
590-
# the metadata using retry_on_unavailable.
591-
# TODO: Figure out how to intercept and monkey patch the internals
592-
# of the gRPC transport.
593-
retry=None,
594582
)
595583

596584
resp = retry_on_unavailable(create_sessions)

0 commit comments

Comments
 (0)