1616import threading
1717
1818from google .cloud .spanner_v1 import (
19- BatchCreateSessionsRequest ,
2019 CreateSessionRequest ,
2120 ExecuteSqlRequest ,
2221 BeginTransactionRequest ,
@@ -58,20 +57,17 @@ def test_snapshot_execute_sql(self):
5857 NTH_CLIENT = self .database ._nth_client_id
5958 CHANNEL_ID = self .database ._channel_id
6059 got_stream_segments , got_unary_segments = self .canonicalize_request_id_headers ()
61- # Filter out CreateSessionRequest unary segments for comparison
62- filtered_unary_segments = [
63- seg for seg in got_unary_segments if not seg [0 ].endswith ("/CreateSession" )
64- ]
60+ # With multiplexed sessions, we expect one CreateSession request
6561 want_unary_segments = [
6662 (
67- "/google.spanner.v1.Spanner/BatchCreateSessions " ,
63+ "/google.spanner.v1.Spanner/CreateSession " ,
6864 (1 , REQ_RAND_PROCESS_ID , NTH_CLIENT , CHANNEL_ID , 1 , 1 ),
6965 )
7066 ]
7167 # Dynamically determine the expected sequence number for ExecuteStreamingSql
7268 session_requests_before = 0
7369 for req in requests :
74- if isinstance (req , ( BatchCreateSessionsRequest , CreateSessionRequest ) ):
70+ if isinstance (req , CreateSessionRequest ):
7571 session_requests_before += 1
7672 elif isinstance (req , ExecuteSqlRequest ):
7773 break
@@ -88,7 +84,7 @@ def test_snapshot_execute_sql(self):
8884 ),
8985 )
9086 ]
91- assert filtered_unary_segments == want_unary_segments
87+ assert got_unary_segments == want_unary_segments
9288 assert got_stream_segments == want_stream_segments
9389
9490 def test_snapshot_read_concurrent (self ):
@@ -118,45 +114,32 @@ def select1():
118114 for thread in threads :
119115 thread .join ()
120116 requests = self .spanner_service .requests
121- # Allow for an extra request due to multiplexed session creation
122- expected_min = 2 + n
123- expected_max = expected_min + 1
117+ # With multiplexed sessions: 1 CreateSession + (n + 1) ExecuteSql
118+ expected_min = 1 + n + 1
119+ expected_max = expected_min
124120 assert (
125121 expected_min <= len (requests ) <= expected_max
126- ), f"Expected { expected_min } or { expected_max } requests, got { len (requests )} : { requests } "
122+ ), f"Expected { expected_min } requests, got { len (requests )} : { requests } "
127123 client_id = db ._nth_client_id
128124 channel_id = db ._channel_id
129125 got_stream_segments , got_unary_segments = self .canonicalize_request_id_headers ()
130126 want_unary_segments = [
131127 (
132- "/google.spanner.v1.Spanner/BatchCreateSessions " ,
128+ "/google.spanner.v1.Spanner/CreateSession " ,
133129 (1 , REQ_RAND_PROCESS_ID , client_id , channel_id , 1 , 1 ),
134130 ),
135131 ]
136132 assert any (seg == want_unary_segments [0 ] for seg in got_unary_segments )
137133
138- # Dynamically determine the expected sequence numbers for ExecuteStreamingSql
139- session_requests_before = 0
140- for req in requests :
141- if isinstance (req , (BatchCreateSessionsRequest , CreateSessionRequest )):
142- session_requests_before += 1
143- elif isinstance (req , ExecuteSqlRequest ):
144- break
145- want_stream_segments = [
146- (
147- "/google.spanner.v1.Spanner/ExecuteStreamingSql" ,
148- (
149- 1 ,
150- REQ_RAND_PROCESS_ID ,
151- client_id ,
152- channel_id ,
153- session_requests_before + i ,
154- 1 ,
155- ),
156- )
157- for i in range (1 , n + 2 )
158- ]
159- assert got_stream_segments == want_stream_segments
134+ # Verify we have the expected number of ExecuteStreamingSql segments
135+ # (n + 1 = 11 for initial + 10 concurrent)
136+ assert len (got_stream_segments ) == n + 1
137+ # Verify all segments are for ExecuteStreamingSql
138+ for seg in got_stream_segments :
139+ assert seg [0 ] == "/google.spanner.v1.Spanner/ExecuteStreamingSql"
140+ # Verify the segment has correct client_id and channel_id
141+ assert seg [1 ][2 ] == client_id
142+ assert seg [1 ][3 ] == channel_id
160143
161144 def test_database_run_in_transaction_retries_on_abort (self ):
162145 counters = dict (aborted = 0 )
@@ -192,33 +175,22 @@ def test_database_execute_partitioned_dml_request_id(self):
192175 got_stream_segments , got_unary_segments = self .canonicalize_request_id_headers ()
193176 NTH_CLIENT = self .database ._nth_client_id
194177 CHANNEL_ID = self .database ._channel_id
195- # Allow for extra unary segments due to session creation
196- filtered_unary_segments = [
197- seg for seg in got_unary_segments if not seg [0 ].endswith ("/CreateSession" )
198- ]
199178 # Find the actual sequence number for BeginTransaction
200179 begin_txn_seq = None
201- for seg in filtered_unary_segments :
180+ for seg in got_unary_segments :
202181 if seg [0 ].endswith ("/BeginTransaction" ):
203182 begin_txn_seq = seg [1 ][4 ]
204183 break
205184 want_unary_segments = [
206185 (
207- "/google.spanner.v1.Spanner/BatchCreateSessions " ,
186+ "/google.spanner.v1.Spanner/CreateSession " ,
208187 (1 , REQ_RAND_PROCESS_ID , NTH_CLIENT , CHANNEL_ID , 1 , 1 ),
209188 ),
210189 (
211190 "/google.spanner.v1.Spanner/BeginTransaction" ,
212191 (1 , REQ_RAND_PROCESS_ID , NTH_CLIENT , CHANNEL_ID , begin_txn_seq , 1 ),
213192 ),
214193 ]
215- # Dynamically determine the expected sequence number for ExecuteStreamingSql
216- session_requests_before = 0
217- for req in requests :
218- if isinstance (req , (BatchCreateSessionsRequest , CreateSessionRequest )):
219- session_requests_before += 1
220- elif isinstance (req , ExecuteSqlRequest ):
221- break
222194 # Find the actual sequence number for ExecuteStreamingSql
223195 exec_sql_seq = got_stream_segments [0 ][1 ][4 ] if got_stream_segments else None
224196 want_stream_segments = [
@@ -227,12 +199,12 @@ def test_database_execute_partitioned_dml_request_id(self):
227199 (1 , REQ_RAND_PROCESS_ID , NTH_CLIENT , CHANNEL_ID , exec_sql_seq , 1 ),
228200 )
229201 ]
230- assert all (seg in filtered_unary_segments for seg in want_unary_segments )
202+ assert all (seg in got_unary_segments for seg in want_unary_segments )
231203 assert got_stream_segments == want_stream_segments
232204
233205 def test_unary_retryable_error (self ):
234206 add_select1_result ()
235- add_error (SpannerServicer .BatchCreateSessions .__name__ , unavailable_status ())
207+ add_error (SpannerServicer .CreateSession .__name__ , unavailable_status ())
236208
237209 if not getattr (self .database , "_interceptors" , None ):
238210 self .database ._interceptors = MockServerTestBase ._interceptors
0 commit comments