Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit fc6a374

Browse files
committed
Wire up XGoogSpannerRequestIdInterceptor for TestDatabase checks
1 parent f5459e5 commit fc6a374

File tree

9 files changed

+115
-91
lines changed

9 files changed

+115
-91
lines changed

google/cloud/spanner_v1/database.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,9 @@ class Database(object):
151151

152152
_spanner_api: SpannerClient = None
153153

154+
__transport_lock = threading.Lock()
155+
__transports_to_channel_id = dict()
156+
154157
def __init__(
155158
self,
156159
database_id,
@@ -445,6 +448,31 @@ def spanner_api(self):
445448
)
446449
return self._spanner_api
447450

451+
@property
452+
def _channel_id(self):
453+
"""
454+
Helper to retrieve the associated channelID for the spanner_api.
455+
This property is paramount to x-goog-spanner-request-id.
456+
"""
457+
with self.__transport_lock:
458+
api = self.spanner_api
459+
channel_id = self.__transports_to_channel_id.get(api._transport, None)
460+
if channel_id is None:
461+
channel_id = len(self.__transports_to_channel_id) + 1
462+
self.__transports_to_channel_id[api._transport] = channel_id
463+
464+
return channel_id
465+
466+
def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]):
467+
client_id = self._nth_client_id
468+
return _metadata_with_request_id(
469+
self._nth_client_id,
470+
self._channel_id,
471+
nth_request,
472+
nth_attempt,
473+
prior_metadata,
474+
)
475+
448476
def __eq__(self, other):
449477
if not isinstance(other, self.__class__):
450478
return NotImplemented
@@ -706,10 +734,8 @@ def execute_partitioned_dml(
706734

707735
def execute_pdml():
708736
with SessionCheckout(self._pool) as session:
709-
channel_id = getattr(session, "_channel_id", 0)
710-
client_id = getattr(self, "_nth_client_id", 0)
711-
all_metadata = _metadata_with_request_id(
712-
client_id, channel_id, nth_request, attempt.value, metadata
737+
all_metadata = self.metadata_with_request_id(
738+
nth_request, attempt.value, metadata
713739
)
714740
txn = api.begin_transaction(
715741
session=session.name, options=txn_options, metadata=all_metadata

google/cloud/spanner_v1/instance.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,7 @@ def database(
501501
proto_descriptors=proto_descriptors,
502502
)
503503
else:
504+
print("enabled interceptors")
504505
return TestDatabase(
505506
database_id,
506507
self,

google/cloud/spanner_v1/pool.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import datetime
1818
import queue
1919
import time
20-
import threading
2120

2221
from google.cloud.exceptions import NotFound
2322
from google.cloud.spanner_v1 import BatchCreateSessionsRequest
@@ -54,8 +53,6 @@ def __init__(self, labels=None, database_role=None):
5453
labels = {}
5554
self._labels = labels
5655
self._database_role = database_role
57-
self.__lock = threading.Lock()
58-
self._session_id_to_channel_id = dict()
5956

6057
@property
6158
def labels(self):
@@ -131,19 +128,10 @@ def _new_session(self):
131128
:rtype: :class:`~google.cloud.spanner_v1.session.Session`
132129
:returns: new session instance.
133130
"""
134-
session = self._database.session(
131+
return self._database.session(
135132
labels=self.labels, database_role=self.database_role
136133
)
137134

138-
session_id = getattr(session, "_session_id", None)
139-
if session_id:
140-
with self.__lock:
141-
channel_id = len(self._session_id_to_channel_id) + 1
142-
self._session_id_to_channel_id[session._session_id] = channel_id
143-
session._channel_id = channel_id
144-
145-
return session
146-
147135
def session(self, **kwargs):
148136
"""Check out a session from the pool.
149137
@@ -255,6 +243,7 @@ def bind(self, database):
255243
"CloudSpanner.FixedPool.BatchCreateSessions",
256244
observability_options=observability_options,
257245
) as span:
246+
attempt = 1
258247
returned_session_count = 0
259248
while not self._sessions.full():
260249
request.session_count = requested_session_count - self._sessions.qsize()
@@ -263,9 +252,12 @@ def bind(self, database):
263252
f"Creating {request.session_count} sessions",
264253
span_event_attributes,
265254
)
255+
all_metadata = database.metadata_with_request_id(
256+
database._next_nth_request, attempt, metadata
257+
)
266258
resp = api.batch_create_sessions(
267259
request=request,
268-
metadata=metadata,
260+
metadata=all_metadata,
269261
)
270262

271263
add_span_event(

google/cloud/spanner_v1/session.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ def __init__(self, database, labels=None, database_role=None):
7575
self._labels = labels
7676
self._database_role = database_role
7777
self._last_use_time = datetime.utcnow()
78-
self.__channel_id = 0
7978

8079
def __lt__(self, other):
8180
return self._session_id < other._session_id

google/cloud/spanner_v1/snapshot.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -328,10 +328,8 @@ def read(
328328

329329
def wrapped_restart(*args, **kwargs):
330330
attempt.increment()
331-
channel_id = getattr(self._session, "_channel_id", 0)
332-
client_id = getattr(database, "_nth_client_id", 0)
333-
all_metadata = _metadata_with_request_id(
334-
client_id, channel_id, nth_request, attempt.value, metadata
331+
all_metadata = database.metadata_with_request_id(
332+
nth_request, attempt.value, metadata
335333
)
336334

337335
restart = functools.partial(
@@ -557,10 +555,8 @@ def execute_sql(
557555

558556
def wrapped_restart(*args, **kwargs):
559557
attempt.increment()
560-
channel_id = getattr(self._session, "_channel_id", 0)
561-
client_id = getattr(database, "_nth_client_id", 0)
562-
all_metadata = _metadata_with_request_id(
563-
client_id, channel_id, nth_request, attempt.value, metadata
558+
all_metadata = database.metadata_with_request_id(
559+
nth_request, attempt.value, metadata
564560
)
565561

566562
restart = functools.partial(
@@ -717,10 +713,8 @@ def partition_read(
717713

718714
def wrapped_method(*args, **kwargs):
719715
attempt.increment()
720-
channel_id = getattr(self._session, "_channel_id", 0)
721-
client_id = getattr(database, "_nth_client_id", 0)
722-
all_metadata = _metadata_with_request_id(
723-
client_id, channel_id, nth_request, attempt.value, metadata
716+
all_metadata = database.metadata_with_request_id(
717+
nth_request, attempt.value, metadata
724718
)
725719
method = functools.partial(
726720
api.partition_read,
@@ -832,12 +826,9 @@ def partition_query(
832826

833827
def wrapped_method(*args, **kwargs):
834828
attempt.increment()
835-
channel_id = getattr(self._session, "_channel_id", 0)
836-
client_id = getattr(database, "_nth_client_id", 0)
837-
all_metadata = _metadata_with_request_id(
838-
client_id, channel_id, nth_request, attempt.value, metadata
829+
all_metadata = database.metadata_with_request_id(
830+
nth_request, attempt.value, metadata
839831
)
840-
841832
method = functools.partial(
842833
api.partition_query,
843834
request=request,
@@ -991,12 +982,9 @@ def begin(self):
991982

992983
def wrapped_method(*args, **kwargs):
993984
attempt.increment()
994-
channel_id = getattr(self._session, "_channel_id", 0)
995-
client_id = getattr(database, "_nth_client_id", 0)
996-
all_metadata = _metadata_with_request_id(
997-
client_id, channel_id, nth_request, attempt.value, metadata
985+
all_metadata = database.metadata_with_request_id(
986+
nth_request, attempt.value, metadata
998987
)
999-
1000988
method = functools.partial(
1001989
api.begin_transaction,
1002990
session=self._session.name,

google/cloud/spanner_v1/testing/database_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ class TestDatabase(Database):
3535
currently, and we don't want to make changes in the Database class for
3636
testing purpose as this is a hack to use interceptors in tests."""
3737

38+
_interceptors = []
39+
3840
def __init__(
3941
self,
4042
database_id,
@@ -61,11 +63,9 @@ def __init__(
6163

6264
self._method_count_interceptor = MethodCountInterceptor()
6365
self._method_abort_interceptor = MethodAbortInterceptor()
64-
self._x_goog_request_id_interceptor = XGoogRequestIDHeaderInterceptor()
6566
self._interceptors = [
6667
self._method_count_interceptor,
6768
self._method_abort_interceptor,
68-
self._x_goog_request_id_interceptor,
6969
]
7070

7171
@property
@@ -77,6 +77,8 @@ def spanner_api(self):
7777
client_options = client._client_options
7878
if self._instance.emulator_host is not None:
7979
channel = grpc.insecure_channel(self._instance.emulator_host)
80+
self._x_goog_request_id_interceptor = XGoogRequestIDHeaderInterceptor()
81+
self._interceptors.append(self._x_goog_request_id_interceptor)
8082
channel = grpc.intercept_channel(channel, *self._interceptors)
8183
transport = SpannerGrpcTransport(channel=channel)
8284
self._spanner_api = SpannerClient(

google/cloud/spanner_v1/testing/interceptors.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,16 +82,28 @@ def intercept(self, method, request_or_iterator, call_details):
8282
break
8383

8484
if not x_goog_request_id:
85-
raise Exception(f"Missing {x_goog_request_id}")
86-
87-
streaming = hasattr(request_or_iterator, "__iter__", False)
85+
raise Exception("Missing x_goog_request_id header")
86+
87+
response_or_iterator = method(request_or_iterator, call_details)
88+
streaming = getattr(response_or_iterator, "__iter__", None) is not None
89+
print(
90+
"intercept got",
91+
x_goog_request_id,
92+
call_details.method,
93+
"streaming",
94+
streaming,
95+
)
8896
with self.__lock:
8997
if streaming:
90-
self._stream_req_segments.append(x_goog_request_id)
98+
self._stream_req_segments.append(
99+
(call_details.method, parse_request_id(x_goog_request_id))
100+
)
91101
else:
92-
self._unary_req_segments.append(x_goog_request_id)
102+
self._unary_req_segments.append(
103+
(call_details.method, parse_request_id(x_goog_request_id))
104+
)
93105

94-
return method(request_or_iterator, call_details)
106+
return response_or_iterator
95107

96108
@property
97109
def unary_request_ids(self):
@@ -105,3 +117,18 @@ def reset(self):
105117
self._stream_req_segments.clear()
106118
self._unary_req_segments.clear()
107119
pass
120+
121+
122+
def parse_request_id(request_id_str):
123+
splits = request_id_str.split(".")
124+
version, rand_process_id, client_id, channel_id, nth_request, nth_attempt = list(
125+
map(lambda v: int(v), splits)
126+
)
127+
return (
128+
version,
129+
rand_process_id,
130+
client_id,
131+
channel_id,
132+
nth_request,
133+
nth_attempt,
134+
)

tests/mockserver_tests/mock_server_test_base.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@ def __init__(self, *args, **kwargs):
118118
self._client = None
119119
self._instance = None
120120
self._database = None
121-
self._interceptors = None
122121

123122
@classmethod
124123
def setup_class(cls):
@@ -147,19 +146,11 @@ def teardown_method(self, *args, **kwargs):
147146
@property
148147
def client(self) -> Client:
149148
if self._client is None:
150-
api_endpoint = "localhost:" + str(MockServerTestBase.port)
151-
channel = grpc.insecure_channel(api_endpoint)
152-
transport = None
153-
if self._interceptors and len(self._interceptors) > 0:
154-
channel = grpc.intercept_channel(channel, *self._interceptors)
155-
transport = SpannerGrpcTransport(channel=channel)
156-
157149
self._client = Client(
158150
project="p",
159151
credentials=AnonymousCredentials(),
160152
client_options=ClientOptions(
161-
transport=transport,
162-
api_endpoint=api_endpoint if transport is None else None,
153+
api_endpoint="localhost:" + str(MockServerTestBase.port),
163154
),
164155
)
165156
return self._client
@@ -174,6 +165,8 @@ def instance(self) -> Instance:
174165
def database(self) -> Database:
175166
if self._database is None:
176167
self._database = self.instance.database(
177-
"test-database", pool=FixedSizePool(size=10)
168+
"test-database",
169+
pool=FixedSizePool(size=10),
170+
enable_interceptors_in_tests=True,
178171
)
179172
return self._database

0 commit comments

Comments
 (0)