Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit 65757b5

Browse files
committed
feat(x-goog-spanner-request-id): implement request_id generation and propagation
Generates a request_id that is then injected inside metadata that's sent over to the Cloud Spanner backend. Officially inject the first set of x-goog-spanner-request-id values into header metadata Add request-id interceptor to use in asserting tests Wrap Snapshot methods with x-goog-request-id metadata injector Setup scaffolding for XGoogRequestIdHeader checks Wire up XGoogSpannerRequestIdInterceptor for TestDatabase checks Inject header in more Session using spots plus more tests Base for tests with retries on abort More plumbing for Transaction and Database Update unit tests for Transaction Wrap more in Transaction + update tests Update tests Plumb in more tests Update TestDatabase Fixes #1261
1 parent f2483e1 commit 65757b5

File tree

17 files changed

+1168
-128
lines changed

17 files changed

+1168
-128
lines changed

google/cloud/spanner_v1/client.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from google.cloud.spanner_v1._helpers import _merge_query_options
4949
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
5050
from google.cloud.spanner_v1.instance import Instance
51+
from google.cloud.spanner_v1._helpers import AtomicCounter
5152

5253
_CLIENT_INFO = client_info.ClientInfo(client_library_version=__version__)
5354
EMULATOR_ENV_VAR = "SPANNER_EMULATOR_HOST"
@@ -147,6 +148,8 @@ class Client(ClientWithProject):
147148
SCOPE = (SPANNER_ADMIN_SCOPE,)
148149
"""The scopes required for Google Cloud Spanner."""
149150

151+
NTH_CLIENT = AtomicCounter()
152+
150153
def __init__(
151154
self,
152155
project=None,
@@ -199,6 +202,12 @@ def __init__(
199202
self._route_to_leader_enabled = route_to_leader_enabled
200203
self._directed_read_options = directed_read_options
201204
self._observability_options = observability_options
205+
self._nth_client_id = Client.NTH_CLIENT.increment()
206+
self._nth_request = AtomicCounter()
207+
208+
@property
209+
def _next_nth_request(self):
210+
return self._nth_request.increment()
202211

203212
@property
204213
def credentials(self):

google/cloud/spanner_v1/database.py

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@
5050
from google.cloud.spanner_v1 import SpannerClient
5151
from google.cloud.spanner_v1._helpers import _merge_query_options
5252
from google.cloud.spanner_v1._helpers import (
53+
AtomicCounter,
5354
_metadata_with_prefix,
5455
_metadata_with_leader_aware_routing,
56+
_metadata_with_request_id,
5557
)
5658
from google.cloud.spanner_v1.batch import Batch
5759
from google.cloud.spanner_v1.batch import MutationGroups
@@ -149,6 +151,9 @@ class Database(object):
149151

150152
_spanner_api: SpannerClient = None
151153

154+
__transport_lock = threading.Lock()
155+
__transports_to_channel_id = dict()
156+
152157
def __init__(
153158
self,
154159
database_id,
@@ -443,6 +448,31 @@ def spanner_api(self):
443448
)
444449
return self._spanner_api
445450

451+
@property
452+
def _channel_id(self):
453+
"""
454+
Helper to retrieve the associated channelID for the spanner_api.
455+
This property is paramount to x-goog-spanner-request-id.
456+
"""
457+
with self.__transport_lock:
458+
api = self.spanner_api
459+
channel_id = self.__transports_to_channel_id.get(api._transport, None)
460+
if channel_id is None:
461+
channel_id = len(self.__transports_to_channel_id) + 1
462+
self.__transports_to_channel_id[api._transport] = channel_id
463+
464+
return channel_id
465+
466+
def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]):
467+
client_id = self._nth_client_id
468+
return _metadata_with_request_id(
469+
self._nth_client_id,
470+
self._channel_id,
471+
nth_request,
472+
nth_attempt,
473+
prior_metadata,
474+
)
475+
446476
def __eq__(self, other):
447477
if not isinstance(other, self.__class__):
448478
return NotImplemented
@@ -698,10 +728,20 @@ def execute_partitioned_dml(
698728
_metadata_with_leader_aware_routing(self._route_to_leader_enabled)
699729
)
700730

731+
# Attempt will be incremented inside _restart_on_unavailable.
732+
begin_txn_nth_request = self._next_nth_request
733+
begin_txn_attempt = AtomicCounter(1)
734+
partial_nth_request = self._next_nth_request
735+
partial_attempt = AtomicCounter(0)
736+
701737
def execute_pdml():
702738
with SessionCheckout(self._pool) as session:
703739
txn = api.begin_transaction(
704-
session=session.name, options=txn_options, metadata=metadata
740+
session=session.name,
741+
options=txn_options,
742+
metadata=self.metadata_with_request_id(
743+
begin_txn_nth_request, begin_txn_attempt.value, metadata
744+
),
705745
)
706746

707747
txn_selector = TransactionSelector(id=txn.id)
@@ -714,17 +754,24 @@ def execute_pdml():
714754
query_options=query_options,
715755
request_options=request_options,
716756
)
717-
method = functools.partial(
718-
api.execute_streaming_sql,
719-
metadata=metadata,
720-
)
757+
758+
def wrapped_method(*args, **kwargs):
759+
partial_attempt.increment()
760+
method = functools.partial(
761+
api.execute_streaming_sql,
762+
metadata=self.metadata_with_request_id(
763+
partial_nth_request, partial_attempt.value, metadata
764+
),
765+
)
766+
return method(*args, **kwargs)
721767

722768
iterator = _restart_on_unavailable(
723-
method=method,
769+
method=wrapped_method,
724770
trace_name="CloudSpanner.ExecuteStreamingSql",
725771
request=request,
726772
transaction_selector=txn_selector,
727773
observability_options=self.observability_options,
774+
attempt=begin_txn_attempt,
728775
)
729776

730777
result_set = StreamedResultSet(iterator)
@@ -734,6 +781,14 @@ def execute_pdml():
734781

735782
return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)()
736783

784+
@property
785+
def _next_nth_request(self):
786+
return self._instance._client._next_nth_request
787+
788+
@property
789+
def _nth_client_id(self):
790+
return self._instance._client._nth_client_id
791+
737792
def session(self, labels=None, database_role=None):
738793
"""Factory to create a session for this database.
739794

google/cloud/spanner_v1/instance.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,7 @@ def database(
501501
proto_descriptors=proto_descriptors,
502502
)
503503
else:
504+
print("enabled interceptors")
504505
return TestDatabase(
505506
database_id,
506507
self,

google/cloud/spanner_v1/pool.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ def bind(self, database):
243243
"CloudSpanner.FixedPool.BatchCreateSessions",
244244
observability_options=observability_options,
245245
) as span:
246+
attempt = 1
246247
returned_session_count = 0
247248
while not self._sessions.full():
248249
request.session_count = requested_session_count - self._sessions.qsize()
@@ -251,9 +252,12 @@ def bind(self, database):
251252
f"Creating {request.session_count} sessions",
252253
span_event_attributes,
253254
)
255+
all_metadata = database.metadata_with_request_id(
256+
database._next_nth_request, attempt, metadata
257+
)
254258
resp = api.batch_create_sessions(
255259
request=request,
256-
metadata=metadata,
260+
metadata=all_metadata,
257261
)
258262

259263
add_span_event(

google/cloud/spanner_v1/session.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,8 @@ def exists(self):
193193
current_span, "Checking if Session exists", {"session.id": self._session_id}
194194
)
195195

196-
api = self._database.spanner_api
196+
database = self._database
197+
api = database.spanner_api
197198
metadata = _metadata_with_prefix(self._database.name)
198199
if self._database._route_to_leader_enabled:
199200
metadata.append(
@@ -202,12 +203,16 @@ def exists(self):
202203
)
203204
)
204205

206+
all_metadata = database.metadata_with_request_id(
207+
database._next_nth_request, 1, metadata
208+
)
209+
205210
observability_options = getattr(self._database, "observability_options", None)
206211
with trace_call(
207212
"CloudSpanner.GetSession", self, observability_options=observability_options
208213
) as span:
209214
try:
210-
api.get_session(name=self.name, metadata=metadata)
215+
api.get_session(name=self.name, metadata=all_metadata)
211216
if span:
212217
span.set_attribute("session_found", True)
213218
except NotFound:
@@ -237,8 +242,11 @@ def delete(self):
237242
current_span, "Deleting Session", {"session.id": self._session_id}
238243
)
239244

240-
api = self._database.spanner_api
241-
metadata = _metadata_with_prefix(self._database.name)
245+
database = self._database
246+
api = database.spanner_api
247+
metadata = database.metadata_with_request_id(
248+
database._next_nth_request, 1, _metadata_with_prefix(database.name)
249+
)
242250
observability_options = getattr(self._database, "observability_options", None)
243251
with trace_call(
244252
"CloudSpanner.DeleteSession",
@@ -259,7 +267,10 @@ def ping(self):
259267
if self._session_id is None:
260268
raise ValueError("Session ID not set by back-end")
261269
api = self._database.spanner_api
262-
metadata = _metadata_with_prefix(self._database.name)
270+
database = self._database
271+
metadata = database.metadata_with_request_id(
272+
database._next_nth_request, 1, _metadata_with_prefix(database.name)
273+
)
263274
request = ExecuteSqlRequest(session=self.name, sql="SELECT 1")
264275
api.execute_sql(request=request, metadata=metadata)
265276
self._last_use_time = datetime.now()

0 commit comments

Comments
 (0)