Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit df8f81f

Browse files
committed
Correctly handle wrapping by class for api objects
1 parent 0495082 commit df8f81f

File tree

8 files changed

+118
-84
lines changed

8 files changed

+118
-84
lines changed

google/cloud/spanner_v1/_helpers.py

Lines changed: 69 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -737,27 +737,73 @@ class InterceptingHeaderInjector:
737737
def __init__(self, original_callable: Callable):
738738
self._original_callable = original_callable
739739

740-
def __call__(self, *args, **kwargs):
741-
metadata = kwargs.get("metadata", [])
742-
# Find all the headers that match the x-goog-spanner-request-id
743-
# header an on each retry increment the value.
744-
all_metadata = []
745-
for key, value in metadata:
746-
if key is REQ_ID_HEADER_KEY:
747-
# Otherwise now increment the count for the attempt number.
748-
splits = value.split(".")
749-
attempt_plus_one = int(splits[-1]) + 1
750-
splits[-1] = str(attempt_plus_one)
751-
value_before = value
752-
value = ".".join(splits)
753-
print("incrementing value on retry from", value_before, "to", value)
754-
755-
all_metadata.append(
756-
(
757-
key,
758-
value,
759-
)
760-
)
761740

762-
kwargs["metadata"] = all_metadata
763-
return self._original_callable(*args, **kwargs)
741+
patched = {}
742+
743+
744+
def inject_retry_header_control(api):
745+
# For each method, add an _attempt value that'll then be
746+
# retrieved for each retry.
747+
# 1. Patch the __getattribute__ method to match items in our manifest.
748+
target = type(api)
749+
hex_id = hex(id(target))
750+
if patched.get(hex_id, None) is not None:
751+
return
752+
753+
orig_getattribute = getattr(target, "__getattribute__")
754+
755+
def patched_getattribute(*args, **kwargs):
756+
attr = orig_getattribute(*args, **kwargs)
757+
758+
# 0. If we already patched it, we can return immediately.
759+
if getattr(attr, "_patched", None) is not None:
760+
return attr
761+
762+
# 1. Skip over non-methods.
763+
if not callable(attr):
764+
return attr
765+
766+
# 2. Skip modifying private and mangled methods.
767+
mangled_or_private = attr.__name__.startswith("_")
768+
if mangled_or_private:
769+
return attr
770+
771+
print("\033[35mattr", attr, "hex_id", hex(id(attr)), "\033[00m")
772+
773+
# 3. Wrap the callable attribute and then capture its metadata keyed argument.
774+
def wrapped_attr(*args, **kwargs):
775+
metadata = kwargs.get("metadata", [])
776+
if not metadata:
777+
# Increment the reinvocation count.
778+
print("not metatadata", attr.__name__)
779+
wrapped_attr._attempt += 1
780+
return attr(*args, **kwargs)
781+
782+
# 4. Find all the headers that match the target header key.
783+
all_metadata = []
784+
for key, value in metadata:
785+
if key is REQ_ID_HEADER_KEY:
786+
print("key", key, "value", value, "attempt", wrapped_attr._attempt)
787+
# 5. Increment the original_attempt with that of our re-invocation count.
788+
splits = value.split(".")
789+
hdr_attempt_plus_reinvocation = (
790+
int(splits[-1]) + wrapped_attr._attempt
791+
)
792+
splits[-1] = str(hdr_attempt_plus_reinvocation)
793+
value = ".".join(splits)
794+
795+
all_metadata.append((key, value))
796+
797+
# Increment the reinvocation count.
798+
wrapped_attr._attempt += 1
799+
800+
kwargs["metadata"] = all_metadata
801+
print("\033[34mwrap_callable", hex(id(attr)), attr.__name__, "\033[00m")
802+
return attr(*args, **kwargs)
803+
804+
wrapped_attr._attempt = 0
805+
wrapped_attr._patched = True
806+
return wrapped_attr
807+
808+
setattr(target, "__getattribute__", patched_getattribute)
809+
patched[hex_id] = True

google/cloud/spanner_v1/database.py

Lines changed: 7 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
_metadata_with_prefix,
5656
_metadata_with_leader_aware_routing,
5757
_metadata_with_request_id,
58-
InterceptingHeaderInjector,
58+
inject_retry_header_control,
5959
)
6060
from google.cloud.spanner_v1.batch import Batch
6161
from google.cloud.spanner_v1.batch import MutationGroups
@@ -434,42 +434,14 @@ def logger(self):
434434
@property
435435
def spanner_api(self):
436436
"""Helper for session-related API calls."""
437-
api = self._generate_spanner_api()
437+
api = self.__generate_spanner_api()
438438
if not api:
439439
return api
440440

441-
# Now wrap each method's __call__ method with our wrapped one.
442-
# This is how to deal with the fact that there are no proper gRPC
443-
# interceptors for Python hence the remedy is to replace callables
444-
# with our custom wrapper.
445-
attrs = dir(api)
446-
for attr_name in attrs:
447-
mangled = attr_name.startswith("__")
448-
if mangled:
449-
continue
450-
451-
non_public = attr_name.startswith("_")
452-
if non_public:
453-
continue
454-
455-
attr = getattr(api, attr_name)
456-
callable_attr = callable(attr)
457-
if callable_attr is None:
458-
continue
459-
460-
# We should only be looking at bound methods to SpannerClient
461-
# as those are the RPC invoking methods that need to be wrapped
462-
463-
is_method = type(attr).__name__ == "method"
464-
if not is_method:
465-
continue
466-
467-
print("attr_name", attr_name, "callable_attr", attr)
468-
setattr(api, attr_name, InterceptingHeaderInjector(attr))
469-
441+
inject_retry_header_control(api)
470442
return api
471443

472-
def _generate_spanner_api(self):
444+
def __generate_spanner_api(self):
473445
"""Helper for session-related API calls."""
474446
if self._spanner_api is None:
475447
client_info = self._instance._client._client_info
@@ -804,7 +776,9 @@ def execute_pdml():
804776
session=session.name,
805777
options=txn_options,
806778
metadata=self.metadata_with_request_id(
807-
begin_txn_nth_request, begin_txn_attempt.increment(), metadata
779+
begin_txn_nth_request,
780+
begin_txn_attempt.increment(),
781+
metadata,
808782
),
809783
)
810784

google/cloud/spanner_v1/pool.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,8 @@ def create_sessions(attempt):
268268
metadata=all_metadata,
269269
)
270270

271-
resp = retry_on_unavailable(create_sessions)
271+
resp = retry_on_unavailable(create_sessions, "fixedpool")
272+
# print("resp.FixedPool", resp)
272273

273274
add_span_event(
274275
span,
@@ -581,7 +582,8 @@ def create_sessions(attempt):
581582
metadata=all_metadata,
582583
)
583584

584-
resp = retry_on_unavailable(create_sessions)
585+
resp = retry_on_unavailable(create_sessions, "pingpool")
586+
print("resp.PingingPool", resp)
585587

586588
add_span_event(
587589
span,
@@ -822,20 +824,23 @@ def __exit__(self, *ignored):
822824
self._pool.put(self._session)
823825

824826

825-
def retry_on_unavailable(fn, max=6):
827+
def retry_on_unavailable(fn, kind, max=6):
826828
"""
827829
Retries `fn` to a maximum of `max` times on encountering UNAVAILABLE exceptions,
828830
each time passing in the iteration's ordinal number to signal
829831
the nth attempt. It retries with exponential backoff with jitter.
830832
"""
831833
last_exc = None
832834
for i in range(max):
835+
print("retry_on_unavailable", kind, i)
833836
try:
834837
return fn(i + 1)
835838
except ServiceUnavailable as exc:
839+
print("exc", exc)
836840
last_exc = exc
837841
time.sleep(i**2 + random.random())
838-
except:
842+
except Exception as e:
843+
print("got exception", e)
839844
raise
840845

841846
raise last_exc

google/cloud/spanner_v1/services/spanner/transports/grpc.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,9 @@ def batch_create_sessions(
413413
request_serializer=spanner.BatchCreateSessionsRequest.serialize,
414414
response_deserializer=spanner.BatchCreateSessionsResponse.deserialize,
415415
)
416-
return self._stubs["batch_create_sessions"]
416+
fn = self._stubs["batch_create_sessions"]
417+
print("\033[32minvoking batch_create_sessionhex_id", hex(id(fn)), "\033[00m")
418+
return fn
417419

418420
@property
419421
def get_session(self) -> Callable[[spanner.GetSessionRequest], spanner.Session]:

google/cloud/spanner_v1/snapshot.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -332,9 +332,7 @@ def read(
332332
)
333333

334334
nth_request = getattr(database, "_next_nth_request", 0)
335-
all_metadata = database.metadata_with_request_id(
336-
nth_request, 1, metadata
337-
)
335+
all_metadata = database.metadata_with_request_id(nth_request, 1, metadata)
338336

339337
restart = functools.partial(
340338
api.streaming_read,
@@ -574,15 +572,19 @@ def execute_sql(
574572
if not isinstance(nth_request, int):
575573
raise Exception(f"failed to get an integer back: {nth_request}")
576574

577-
restart = functools.partial(
578-
api.execute_streaming_sql,
579-
request=request,
580-
metadata=database.metadata_with_request_id(
581-
nth_request, 1, metadata
582-
),
583-
retry=retry,
584-
timeout=timeout,
585-
)
575+
attempt = AtomicCounter(0)
576+
577+
def wrapped_restart(*args, **kwargs):
578+
restart = functools.partial(
579+
api.execute_streaming_sql,
580+
request=request,
581+
metadata=database.metadata_with_request_id(
582+
nth_request, attempt.increment(), metadata
583+
),
584+
retry=retry,
585+
timeout=timeout,
586+
)
587+
return restart(*args, **kwargs)
586588

587589
trace_attributes = {"db.statement": sql}
588590
observability_options = getattr(database, "observability_options", None)
@@ -591,7 +593,7 @@ def execute_sql(
591593
# lock is added to handle the inline begin for first rpc
592594
with self._lock:
593595
return self._get_streamed_result_set(
594-
restart,
596+
wrapped_restart,
595597
request,
596598
metadata,
597599
trace_attributes,
@@ -601,7 +603,7 @@ def execute_sql(
601603
)
602604
else:
603605
return self._get_streamed_result_set(
604-
restart,
606+
wrapped_restart,
605607
request,
606608
metadata,
607609
trace_attributes,
@@ -733,9 +735,7 @@ def partition_read(
733735
metadata=metadata,
734736
), MetricsCapture():
735737
nth_request = getattr(database, "_next_nth_request", 0)
736-
all_metadata = database.metadata_with_request_id(
737-
nth_request, 1, metadata
738-
)
738+
all_metadata = database.metadata_with_request_id(nth_request, 1, metadata)
739739
method = functools.partial(
740740
api.partition_read,
741741
request=request,
@@ -842,9 +842,7 @@ def partition_query(
842842
metadata=metadata,
843843
), MetricsCapture():
844844
nth_request = getattr(database, "_next_nth_request", 0)
845-
all_metadata = database.metadata_with_request_id(
846-
nth_request, 1, metadata
847-
)
845+
all_metadata = database.metadata_with_request_id(nth_request, 1, metadata)
848846
method = functools.partial(
849847
api.partition_query,
850848
request=request,
@@ -994,9 +992,7 @@ def begin(self):
994992
metadata=metadata,
995993
), MetricsCapture():
996994
nth_request = getattr(database, "_next_nth_request", 0)
997-
all_metadata = database.metadata_with_request_id(
998-
nth_request, 1, metadata
999-
)
995+
all_metadata = database.metadata_with_request_id(nth_request, 1, metadata)
1000996
method = functools.partial(
1001997
api.begin_transaction,
1002998
session=self._session.name,

google/cloud/spanner_v1/testing/database_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
MethodAbortInterceptor,
2828
XGoogRequestIDHeaderInterceptor,
2929
)
30+
from google.cloud.spanner_v1._helpers import inject_retry_header_control
3031

3132

3233
class TestDatabase(Database):
@@ -70,6 +71,14 @@ def __init__(
7071

7172
@property
7273
def spanner_api(self):
74+
api = self.__generate_spanner_api()
75+
if not api:
76+
return api
77+
78+
inject_retry_header_control(api)
79+
return api
80+
81+
def __generate_spanner_api(self):
7382
"""Helper for session-related API calls."""
7483
if self._spanner_api is None:
7584
client = self._instance._client

google/cloud/spanner_v1/testing/mock_spanner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def pop_error(self, context):
5353
name = inspect.currentframe().f_back.f_code.co_name
5454
error: _Status | None = self.errors.pop(name, None)
5555
if error:
56+
print("context.abort_with_status", error)
5657
context.abort_with_status(error)
5758

5859
def get_result_as_partial_result_sets(

tests/mockserver_tests/test_request_id_header.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ def test_unary_retryable_error(self):
309309
)
310310
]
311311

312+
print("got_unaries", got_unary_segments)
312313
assert got_unary_segments == want_unary_segments
313314
assert got_stream_segments == want_stream_segments
314315

0 commit comments

Comments
 (0)