2222from google .cloud .spanner_v1 .testing .interceptors import XGoogRequestIDHeaderInterceptor
2323from google .cloud .spanner_v1 import (
2424 BatchCreateSessionsRequest ,
25+ BeginTransactionRequest ,
2526 ExecuteSqlRequest ,
2627)
2728from google .api_core .exceptions import Aborted
@@ -195,7 +196,7 @@ def select1():
195196 ]
196197 assert got_stream_segments == want_stream_segments
197198
198- def test_retries_on_abort (self ):
199+ def test_database_run_in_transaction_retries_on_abort (self ):
199200 counters = dict (aborted = 0 )
200201 want_failed_attempts = 2
201202
@@ -217,10 +218,157 @@ def select_in_txn(txn):
217218
218219 self .database .run_in_transaction (select_in_txn )
219220
221+ def test_database_execute_partitioned_dml_request_id (self ):
222+ add_select1_result ()
223+ if not getattr (self .database , "_interceptors" , None ):
224+ self .database ._interceptors = MockServerTestBase ._interceptors
225+ _ = self .database .execute_partitioned_dml ("select 1" )
226+
227+ requests = self .spanner_service .requests
228+ self .assertEqual (3 , len (requests ), msg = requests )
229+ self .assertTrue (isinstance (requests [0 ], BatchCreateSessionsRequest ))
230+ self .assertTrue (isinstance (requests [1 ], BeginTransactionRequest ))
231+ self .assertTrue (isinstance (requests [2 ], ExecuteSqlRequest ))
232+
233+ # Now ensure monotonicity of the received request-id segments.
234+ got_stream_segments , got_unary_segments = self .canonicalize_request_id_headers ()
235+ want_unary_segments = [
236+ (
237+ "/google.spanner.v1.Spanner/BatchCreateSessions" ,
238+ (1 , REQ_RAND_PROCESS_ID , 1 , 1 , 1 , 1 ),
239+ ),
240+ (
241+ "/google.spanner.v1.Spanner/BeginTransaction" ,
242+ (1 , REQ_RAND_PROCESS_ID , 1 , 1 , 2 , 1 ),
243+ ),
244+ ]
245+ want_stream_segments = [
246+ (
247+ "/google.spanner.v1.Spanner/ExecuteStreamingSql" ,
248+ (1 , REQ_RAND_PROCESS_ID , 1 , 1 , 3 , 1 ),
249+ )
250+ ]
251+
252+ assert got_unary_segments == want_unary_segments
253+ assert got_stream_segments == want_stream_segments
254+
255+ def test_snapshot_read (self ):
256+ add_select1_result ()
257+ if not getattr (self .database , "_interceptors" , None ):
258+ self .database ._interceptors = MockServerTestBase ._interceptors
259+ with self .database .snapshot () as snapshot :
260+ results = snapshot .read ("select 1" )
261+ result_list = []
262+ for row in results :
263+ result_list .append (row )
264+ self .assertEqual (1 , row [0 ])
265+ self .assertEqual (1 , len (result_list ))
266+
267+ requests = self .spanner_service .requests
268+ self .assertEqual (2 , len (requests ), msg = requests )
269+ self .assertTrue (isinstance (requests [0 ], BatchCreateSessionsRequest ))
270+ self .assertTrue (isinstance (requests [1 ], ExecuteSqlRequest ))
271+
272+ requests = self .spanner_service .requests
273+ self .assertEqual (n * 2 , len (requests ), msg = requests )
274+
275+ client_id = self .database ._nth_client_id
276+ channel_id = self .database ._channel_id
277+ got_stream_segments , got_unary_segments = self .canonicalize_request_id_headers ()
278+
279+ want_unary_segments = [
280+ (
281+ "/google.spanner.v1.Spanner/BatchCreateSessions" ,
282+ (1 , REQ_RAND_PROCESS_ID , client_id , channel_id , 1 , 1 ),
283+ ),
284+ (
285+ "/google.spanner.v1.Spanner/GetSession" ,
286+ (1 , REQ_RAND_PROCESS_ID , client_id , channel_id , 3 , 1 ),
287+ ),
288+ (
289+ "/google.spanner.v1.Spanner/GetSession" ,
290+ (1 , REQ_RAND_PROCESS_ID , client_id , channel_id , 5 , 1 ),
291+ ),
292+ (
293+ "/google.spanner.v1.Spanner/GetSession" ,
294+ (1 , REQ_RAND_PROCESS_ID , client_id , channel_id , 7 , 1 ),
295+ ),
296+ (
297+ "/google.spanner.v1.Spanner/GetSession" ,
298+ (1 , REQ_RAND_PROCESS_ID , client_id , channel_id , 9 , 1 ),
299+ ),
300+ (
301+ "/google.spanner.v1.Spanner/GetSession" ,
302+ (1 , REQ_RAND_PROCESS_ID , client_id , channel_id , 11 , 1 ),
303+ ),
304+ (
305+ "/google.spanner.v1.Spanner/GetSession" ,
306+ (1 , REQ_RAND_PROCESS_ID , client_id , channel_id , 13 , 1 ),
307+ ),
308+ (
309+ "/google.spanner.v1.Spanner/GetSession" ,
310+ (1 , REQ_RAND_PROCESS_ID , client_id , channel_id , 15 , 1 ),
311+ ),
312+ (
313+ "/google.spanner.v1.Spanner/GetSession" ,
314+ (1 , REQ_RAND_PROCESS_ID , client_id , channel_id , 17 , 1 ),
315+ ),
316+ (
317+ "/google.spanner.v1.Spanner/GetSession" ,
318+ (1 , REQ_RAND_PROCESS_ID , client_id , channel_id , 19 , 1 ),
319+ ),
320+ ]
321+ assert got_unary_segments == want_unary_segments
322+
323+ want_stream_segments = [
324+ (
325+ "/google.spanner.v1.Spanner/ExecuteStreamingSql" ,
326+ (1 , REQ_RAND_PROCESS_ID , client_id , channel_id , 2 , 1 ),
327+ ),
328+ (
329+ "/google.spanner.v1.Spanner/ExecuteStreamingSql" ,
330+ (1 , REQ_RAND_PROCESS_ID , client_id , channel_id , 4 , 1 ),
331+ ),
332+ (
333+ "/google.spanner.v1.Spanner/ExecuteStreamingSql" ,
334+ (1 , REQ_RAND_PROCESS_ID , client_id , channel_id , 6 , 1 ),
335+ ),
336+ (
337+ "/google.spanner.v1.Spanner/ExecuteStreamingSql" ,
338+ (1 , REQ_RAND_PROCESS_ID , client_id , channel_id , 8 , 1 ),
339+ ),
340+ (
341+ "/google.spanner.v1.Spanner/ExecuteStreamingSql" ,
342+ (1 , REQ_RAND_PROCESS_ID , client_id , channel_id , 10 , 1 ),
343+ ),
344+ (
345+ "/google.spanner.v1.Spanner/ExecuteStreamingSql" ,
346+ (1 , REQ_RAND_PROCESS_ID , client_id , channel_id , 12 , 1 ),
347+ ),
348+ (
349+ "/google.spanner.v1.Spanner/ExecuteStreamingSql" ,
350+ (1 , REQ_RAND_PROCESS_ID , client_id , channel_id , 14 , 1 ),
351+ ),
352+ (
353+ "/google.spanner.v1.Spanner/ExecuteStreamingSql" ,
354+ (1 , REQ_RAND_PROCESS_ID , client_id , channel_id , 16 , 1 ),
355+ ),
356+ (
357+ "/google.spanner.v1.Spanner/ExecuteStreamingSql" ,
358+ (1 , REQ_RAND_PROCESS_ID , client_id , channel_id , 18 , 1 ),
359+ ),
360+ (
361+ "/google.spanner.v1.Spanner/ExecuteStreamingSql" ,
362+ (1 , REQ_RAND_PROCESS_ID , client_id , channel_id , 20 , 1 ),
363+ ),
364+ ]
365+ assert got_stream_segments == want_stream_segments
366+
220367 def canonicalize_request_id_headers (self ):
221368 src = self .database ._x_goog_request_id_interceptor
222369 return src ._stream_req_segments , src ._unary_req_segments
223370
371+
224372class FauxCall :
225373 def __init__ (self , code , details = "FauxCall" ):
226374 self ._code = code
0 commit comments