Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit 941fe1a

Browse files
committed
fix lint
1 parent 4a2fc45 commit 941fe1a

File tree

8 files changed

+75
-183
lines changed

8 files changed

+75
-183
lines changed

google/cloud/spanner_v1/database.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -868,9 +868,7 @@ def session(self, labels=None, database_role=None):
868868
# instead.
869869
role = database_role or self._database_role
870870
# Always use multiplexed sessions
871-
return Session(
872-
self, labels=labels, database_role=role, is_multiplexed=True
873-
)
871+
return Session(self, labels=labels, database_role=role, is_multiplexed=True)
874872

875873
def snapshot(self, **kw):
876874
"""Return an object which wraps a snapshot.

google/cloud/spanner_v1/database_sessions_manager.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,4 +210,3 @@ def _maintain_multiplexed_session(session_manager_ref) -> None:
210210
manager._multiplexed_session = manager._build_multiplexed_session()
211211

212212
session_created_time = time()
213-

tests/_helpers.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import unittest
2-
from os import getenv
32

43
import mock
54

@@ -36,21 +35,12 @@
3635

3736

3837
def is_multiplexed_enabled(transaction_type: TransactionType) -> bool:
39-
"""Returns whether multiplexed sessions are enabled for the given transaction type."""
38+
"""Returns whether multiplexed sessions are enabled for the given transaction type.
4039
41-
env_var = "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS"
42-
env_var_partitioned = "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_PARTITIONED_OPS"
43-
env_var_read_write = "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW"
44-
45-
def _getenv(val: str) -> bool:
46-
return getenv(val, "true").lower().strip() != "false"
47-
48-
if transaction_type is TransactionType.READ_ONLY:
49-
return _getenv(env_var)
50-
elif transaction_type is TransactionType.PARTITIONED:
51-
return _getenv(env_var) and _getenv(env_var_partitioned)
52-
else:
53-
return _getenv(env_var) and _getenv(env_var_read_write)
40+
Multiplexed sessions are now always enabled for all transaction types.
41+
This function is kept for backward compatibility with existing tests.
42+
"""
43+
return True
5444

5545

5646
def get_test_ot_exporter():

tests/mockserver_tests/mock_server_test_base.py

Lines changed: 16 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
SpannerServicer,
4242
start_mock_server,
4343
)
44-
from tests._helpers import is_multiplexed_enabled
4544

4645

4746
# Creates an aborted status with the smallest possible retry delay.
@@ -240,52 +239,30 @@ def assert_requests_sequence(
240239
transaction_type,
241240
allow_multiple_batch_create=True,
242241
):
243-
"""Assert that the requests sequence matches the expected types, accounting for multiplexed sessions and retries.
242+
"""Assert that the requests sequence matches the expected types, accounting for multiplexed sessions.
244243
245244
Args:
246245
requests: List of requests from spanner_service.requests
247246
expected_types: List of expected request types (excluding session creation requests)
248-
transaction_type: TransactionType enum value to check multiplexed session status
249-
allow_multiple_batch_create: If True, skip all leading BatchCreateSessionsRequest and one optional CreateSessionRequest
247+
transaction_type: TransactionType enum value (unused, kept for backward compatibility)
248+
allow_multiple_batch_create: If True, skip leading CreateSessionRequest (kept for backward compatibility)
250249
"""
251-
from google.cloud.spanner_v1 import (
252-
BatchCreateSessionsRequest,
253-
CreateSessionRequest,
254-
)
250+
from google.cloud.spanner_v1 import CreateSessionRequest
255251

256-
mux_enabled = is_multiplexed_enabled(transaction_type)
257252
idx = 0
258-
# Skip all leading BatchCreateSessionsRequest (for retries)
253+
# Skip CreateSessionRequest for multiplexed session
259254
if allow_multiple_batch_create:
260255
while idx < len(requests) and isinstance(
261-
requests[idx], BatchCreateSessionsRequest
262-
):
263-
idx += 1
264-
# For multiplexed, optionally skip a CreateSessionRequest
265-
if (
266-
mux_enabled
267-
and idx < len(requests)
268-
and isinstance(requests[idx], CreateSessionRequest)
256+
requests[idx], CreateSessionRequest
269257
):
270258
idx += 1
271259
else:
272-
if mux_enabled:
273-
self.assertTrue(
274-
isinstance(requests[idx], BatchCreateSessionsRequest),
275-
f"Expected BatchCreateSessionsRequest at index {idx}, got {type(requests[idx])}",
276-
)
277-
idx += 1
278-
self.assertTrue(
279-
isinstance(requests[idx], CreateSessionRequest),
280-
f"Expected CreateSessionRequest at index {idx}, got {type(requests[idx])}",
281-
)
282-
idx += 1
283-
else:
284-
self.assertTrue(
285-
isinstance(requests[idx], BatchCreateSessionsRequest),
286-
f"Expected BatchCreateSessionsRequest at index {idx}, got {type(requests[idx])}",
287-
)
288-
idx += 1
260+
# Expect exactly one CreateSessionRequest for multiplexed session
261+
self.assertTrue(
262+
isinstance(requests[idx], CreateSessionRequest),
263+
f"Expected CreateSessionRequest at index {idx}, got {type(requests[idx])}",
264+
)
265+
idx += 1
289266
# Check the rest of the expected request types
290267
for expected_type in expected_types:
291268
self.assertTrue(
@@ -303,13 +280,12 @@ def adjust_request_id_sequence(self, expected_segments, requests, transaction_ty
303280
Args:
304281
expected_segments: List of expected (method, (sequence_numbers)) tuples
305282
requests: List of actual requests from spanner_service.requests
306-
transaction_type: TransactionType enum value to check multiplexed session status
283+
transaction_type: TransactionType enum value (unused, kept for backward compatibility)
307284
308285
Returns:
309286
List of adjusted expected segments with corrected sequence numbers
310287
"""
311288
from google.cloud.spanner_v1 import (
312-
BatchCreateSessionsRequest,
313289
CreateSessionRequest,
314290
ExecuteSqlRequest,
315291
BeginTransactionRequest,
@@ -318,15 +294,13 @@ def adjust_request_id_sequence(self, expected_segments, requests, transaction_ty
318294
# Count session creation requests that come before the first non-session request
319295
session_requests_before = 0
320296
for req in requests:
321-
if isinstance(req, (BatchCreateSessionsRequest, CreateSessionRequest)):
297+
if isinstance(req, CreateSessionRequest):
322298
session_requests_before += 1
323299
elif isinstance(req, (ExecuteSqlRequest, BeginTransactionRequest)):
324300
break
325301

326-
# For multiplexed sessions, we expect 2 session requests (BatchCreateSessions + CreateSession)
327-
# For non-multiplexed, we expect 1 session request (BatchCreateSessions)
328-
mux_enabled = is_multiplexed_enabled(transaction_type)
329-
expected_session_requests = 2 if mux_enabled else 1
302+
# With multiplexed sessions, we expect 1 session request (CreateSession)
303+
expected_session_requests = 1
330304
extra_session_requests = session_requests_before - expected_session_requests
331305

332306
# Adjust sequence numbers based on extra session requests

tests/mockserver_tests/test_request_id_header.py

Lines changed: 22 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import threading
1717

1818
from google.cloud.spanner_v1 import (
19-
BatchCreateSessionsRequest,
2019
CreateSessionRequest,
2120
ExecuteSqlRequest,
2221
BeginTransactionRequest,
@@ -58,20 +57,17 @@ def test_snapshot_execute_sql(self):
5857
NTH_CLIENT = self.database._nth_client_id
5958
CHANNEL_ID = self.database._channel_id
6059
got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers()
61-
# Filter out CreateSessionRequest unary segments for comparison
62-
filtered_unary_segments = [
63-
seg for seg in got_unary_segments if not seg[0].endswith("/CreateSession")
64-
]
60+
# With multiplexed sessions, we expect one CreateSession request
6561
want_unary_segments = [
6662
(
67-
"/google.spanner.v1.Spanner/BatchCreateSessions",
63+
"/google.spanner.v1.Spanner/CreateSession",
6864
(1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 1, 1),
6965
)
7066
]
7167
# Dynamically determine the expected sequence number for ExecuteStreamingSql
7268
session_requests_before = 0
7369
for req in requests:
74-
if isinstance(req, (BatchCreateSessionsRequest, CreateSessionRequest)):
70+
if isinstance(req, CreateSessionRequest):
7571
session_requests_before += 1
7672
elif isinstance(req, ExecuteSqlRequest):
7773
break
@@ -88,7 +84,7 @@ def test_snapshot_execute_sql(self):
8884
),
8985
)
9086
]
91-
assert filtered_unary_segments == want_unary_segments
87+
assert got_unary_segments == want_unary_segments
9288
assert got_stream_segments == want_stream_segments
9389

9490
def test_snapshot_read_concurrent(self):
@@ -118,45 +114,32 @@ def select1():
118114
for thread in threads:
119115
thread.join()
120116
requests = self.spanner_service.requests
121-
# Allow for an extra request due to multiplexed session creation
122-
expected_min = 2 + n
123-
expected_max = expected_min + 1
117+
# With multiplexed sessions: 1 CreateSession + (n + 1) ExecuteSql
118+
expected_min = 1 + n + 1
119+
expected_max = expected_min
124120
assert (
125121
expected_min <= len(requests) <= expected_max
126-
), f"Expected {expected_min} or {expected_max} requests, got {len(requests)}: {requests}"
122+
), f"Expected {expected_min} requests, got {len(requests)}: {requests}"
127123
client_id = db._nth_client_id
128124
channel_id = db._channel_id
129125
got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers()
130126
want_unary_segments = [
131127
(
132-
"/google.spanner.v1.Spanner/BatchCreateSessions",
128+
"/google.spanner.v1.Spanner/CreateSession",
133129
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 1, 1),
134130
),
135131
]
136132
assert any(seg == want_unary_segments[0] for seg in got_unary_segments)
137133

138-
# Dynamically determine the expected sequence numbers for ExecuteStreamingSql
139-
session_requests_before = 0
140-
for req in requests:
141-
if isinstance(req, (BatchCreateSessionsRequest, CreateSessionRequest)):
142-
session_requests_before += 1
143-
elif isinstance(req, ExecuteSqlRequest):
144-
break
145-
want_stream_segments = [
146-
(
147-
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
148-
(
149-
1,
150-
REQ_RAND_PROCESS_ID,
151-
client_id,
152-
channel_id,
153-
session_requests_before + i,
154-
1,
155-
),
156-
)
157-
for i in range(1, n + 2)
158-
]
159-
assert got_stream_segments == want_stream_segments
134+
# Verify we have the expected number of ExecuteStreamingSql segments
135+
# (n + 1 = 11 for initial + 10 concurrent)
136+
assert len(got_stream_segments) == n + 1
137+
# Verify all segments are for ExecuteStreamingSql
138+
for seg in got_stream_segments:
139+
assert seg[0] == "/google.spanner.v1.Spanner/ExecuteStreamingSql"
140+
# Verify the segment has correct client_id and channel_id
141+
assert seg[1][2] == client_id
142+
assert seg[1][3] == channel_id
160143

161144
def test_database_run_in_transaction_retries_on_abort(self):
162145
counters = dict(aborted=0)
@@ -192,33 +175,22 @@ def test_database_execute_partitioned_dml_request_id(self):
192175
got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers()
193176
NTH_CLIENT = self.database._nth_client_id
194177
CHANNEL_ID = self.database._channel_id
195-
# Allow for extra unary segments due to session creation
196-
filtered_unary_segments = [
197-
seg for seg in got_unary_segments if not seg[0].endswith("/CreateSession")
198-
]
199178
# Find the actual sequence number for BeginTransaction
200179
begin_txn_seq = None
201-
for seg in filtered_unary_segments:
180+
for seg in got_unary_segments:
202181
if seg[0].endswith("/BeginTransaction"):
203182
begin_txn_seq = seg[1][4]
204183
break
205184
want_unary_segments = [
206185
(
207-
"/google.spanner.v1.Spanner/BatchCreateSessions",
186+
"/google.spanner.v1.Spanner/CreateSession",
208187
(1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 1, 1),
209188
),
210189
(
211190
"/google.spanner.v1.Spanner/BeginTransaction",
212191
(1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, begin_txn_seq, 1),
213192
),
214193
]
215-
# Dynamically determine the expected sequence number for ExecuteStreamingSql
216-
session_requests_before = 0
217-
for req in requests:
218-
if isinstance(req, (BatchCreateSessionsRequest, CreateSessionRequest)):
219-
session_requests_before += 1
220-
elif isinstance(req, ExecuteSqlRequest):
221-
break
222194
# Find the actual sequence number for ExecuteStreamingSql
223195
exec_sql_seq = got_stream_segments[0][1][4] if got_stream_segments else None
224196
want_stream_segments = [
@@ -227,12 +199,12 @@ def test_database_execute_partitioned_dml_request_id(self):
227199
(1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, exec_sql_seq, 1),
228200
)
229201
]
230-
assert all(seg in filtered_unary_segments for seg in want_unary_segments)
202+
assert all(seg in got_unary_segments for seg in want_unary_segments)
231203
assert got_stream_segments == want_stream_segments
232204

233205
def test_unary_retryable_error(self):
234206
add_select1_result()
235-
add_error(SpannerServicer.BatchCreateSessions.__name__, unavailable_status())
207+
add_error(SpannerServicer.CreateSession.__name__, unavailable_status())
236208

237209
if not getattr(self.database, "_interceptors", None):
238210
self.database._interceptors = MockServerTestBase._interceptors

tests/mockserver_tests/test_tags.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
MockServerTestBase,
2424
add_single_result,
2525
)
26-
from tests._helpers import is_multiplexed_enabled
2726
from google.cloud.spanner_v1.database_sessions_manager import TransactionType
2827

2928

@@ -100,8 +99,9 @@ def test_select_read_only_transaction_with_transaction_tag(self):
10099
TransactionType.READ_ONLY,
101100
)
102101
# Transaction tags are not supported for read-only transactions.
103-
mux_enabled = is_multiplexed_enabled(TransactionType.READ_ONLY)
104-
tag_idx = 3 if mux_enabled else 2
102+
# With multiplexed sessions: CreateSession, BeginTransaction, ExecuteSql, ExecuteSql
103+
# ExecuteSql requests start at index 2
104+
tag_idx = 2
105105
self.assertEqual("", requests[tag_idx].request_options.transaction_tag)
106106
self.assertEqual("", requests[tag_idx + 1].request_options.transaction_tag)
107107

@@ -155,8 +155,9 @@ def test_select_read_write_transaction_with_transaction_tag(self):
155155
],
156156
TransactionType.READ_WRITE,
157157
)
158-
mux_enabled = is_multiplexed_enabled(TransactionType.READ_WRITE)
159-
tag_idx = 3 if mux_enabled else 2
158+
# With multiplexed sessions: CreateSession, BeginTransaction, ExecuteSql, ExecuteSql, Commit
159+
# ExecuteSql requests start at index 2, Commit at index 4
160+
tag_idx = 2
160161
self.assertEqual(
161162
"my_transaction_tag", requests[tag_idx].request_options.transaction_tag
162163
)
@@ -187,8 +188,9 @@ def test_select_read_write_transaction_with_transaction_and_request_tag(self):
187188
],
188189
TransactionType.READ_WRITE,
189190
)
190-
mux_enabled = is_multiplexed_enabled(TransactionType.READ_WRITE)
191-
tag_idx = 3 if mux_enabled else 2
191+
# With multiplexed sessions: CreateSession, BeginTransaction, ExecuteSql, ExecuteSql, Commit
192+
# ExecuteSql requests start at index 2, Commit at index 4
193+
tag_idx = 2
192194
self.assertEqual(
193195
"my_transaction_tag", requests[tag_idx].request_options.transaction_tag
194196
)

tests/system/test_database_api.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -249,9 +249,7 @@ def test_update_ddl_w_operation_id(
249249
# https://github.com/GoogleCloudPlatform/google-cloud-python/issues/5629
250250
# )
251251
temp_db_id = _helpers.unique_id("update_ddl", separator="_")
252-
temp_db = shared_instance.database(
253-
temp_db_id, database_dialect=database_dialect
254-
)
252+
temp_db = shared_instance.database(temp_db_id, database_dialect=database_dialect)
255253
create_op = temp_db.create()
256254
databases_to_delete.append(temp_db)
257255
create_op.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout.

0 commit comments

Comments (0)