3535 _merge_query_options ,
3636 _metadata_with_prefix ,
3737 _metadata_with_leader_aware_routing ,
38+ _metadata_with_request_id ,
3839 _retry ,
3940 _check_rst_stream_error ,
4041 _SessionWrapper ,
42+ AtomicCounter ,
4143)
4244from google .cloud .spanner_v1 ._opentelemetry_tracing import trace_call
4345from google .cloud .spanner_v1 .streamed import StreamedResultSet
@@ -320,13 +322,26 @@ def read(
320322 data_boost_enabled = data_boost_enabled ,
321323 directed_read_options = directed_read_options ,
322324 )
323- restart = functools .partial (
324- api .streaming_read ,
325- request = request ,
326- metadata = metadata ,
327- retry = retry ,
328- timeout = timeout ,
329- )
325+
326+ nth_request = getattr (database , "_next_nth_request" , 0 )
327+ attempt = AtomicCounter (0 )
328+
329+ def wrapped_restart (* args , ** kwargs ):
330+ attempt .increment ()
331+ channel_id = getattr (self ._session , "_channel_id" , 0 )
332+ client_id = getattr (database , "_nth_client_id" , 0 )
333+ all_metadata = _metadata_with_request_id (
334+ client_id , channel_id , nth_request , attempt .value , metadata
335+ )
336+
337+ restart = functools .partial (
338+ api .streaming_read ,
339+ request = request ,
340+ metadata = all_metadata ,
341+ retry = retry ,
342+ timeout = timeout ,
343+ )
344+ return restart (* args , ** kwargs )
330345
331346 trace_attributes = {"table_id" : table , "columns" : columns }
332347 observability_options = getattr (database , "observability_options" , None )
@@ -335,7 +350,7 @@ def read(
335350 # lock is added to handle the inline begin for first rpc
336351 with self ._lock :
337352 iterator = _restart_on_unavailable (
338- restart ,
353+ wrapped_restart ,
339354 request ,
340355 "CloudSpanner.ReadOnlyTransaction" ,
341356 self ._session ,
@@ -357,7 +372,7 @@ def read(
357372 )
358373 else :
359374 iterator = _restart_on_unavailable (
360- restart ,
375+ wrapped_restart ,
361376 request ,
362377 "CloudSpanner.ReadOnlyTransaction" ,
363378 self ._session ,
@@ -536,13 +551,27 @@ def execute_sql(
536551 data_boost_enabled = data_boost_enabled ,
537552 directed_read_options = directed_read_options ,
538553 )
539- restart = functools .partial (
540- api .execute_streaming_sql ,
541- request = request ,
542- metadata = metadata ,
543- retry = retry ,
544- timeout = timeout ,
545- )
554+
555+ nth_request = getattr (database , "_next_nth_request" , 0 )
556+ attempt = AtomicCounter (0 )
557+
558+ def wrapped_restart (* args , ** kwargs ):
559+ attempt .increment ()
560+ channel_id = getattr (self ._session , "_channel_id" , 0 )
561+ client_id = getattr (database , "_nth_client_id" , 0 )
562+ all_metadata = _metadata_with_request_id (
563+ client_id , channel_id , nth_request , attempt .value , metadata
564+ )
565+
566+ restart = functools .partial (
567+ api .execute_streaming_sql ,
568+ request = request ,
569+ metadata = all_metadata ,
570+ retry = retry ,
571+ timeout = timeout ,
572+ )
573+
574+ return restart (* args , ** kwargs )
546575
547576 trace_attributes = {"db.statement" : sql }
548577 observability_options = getattr (database , "observability_options" , None )
@@ -551,7 +580,7 @@ def execute_sql(
551580 # lock is added to handle the inline begin for first rpc
552581 with self ._lock :
553582 return self ._get_streamed_result_set (
554- restart ,
583+ wrapped_restart ,
555584 request ,
556585 trace_attributes ,
557586 column_info ,
@@ -560,7 +589,7 @@ def execute_sql(
560589 )
561590 else :
562591 return self ._get_streamed_result_set (
563- restart ,
592+ wrapped_restart ,
564593 request ,
565594 trace_attributes ,
566595 column_info ,
@@ -683,15 +712,27 @@ def partition_read(
683712 trace_attributes ,
684713 observability_options = getattr (database , "observability_options" , None ),
685714 ):
686- method = functools .partial (
687- api .partition_read ,
688- request = request ,
689- metadata = metadata ,
690- retry = retry ,
691- timeout = timeout ,
692- )
715+ nth_request = getattr (database , "_next_nth_request" , 0 )
716+ attempt = AtomicCounter (0 )
717+
718+ def wrapped_method (* args , ** kwargs ):
719+ attempt .increment ()
720+ channel_id = getattr (self ._session , "_channel_id" , 0 )
721+ client_id = getattr (database , "_nth_client_id" , 0 )
722+ all_metadata = _metadata_with_request_id (
723+ client_id , channel_id , nth_request , attempt .value , metadata
724+ )
725+ method = functools .partial (
726+ api .partition_read ,
727+ request = request ,
728+ metadata = all_metadata ,
729+ retry = retry ,
730+ timeout = timeout ,
731+ )
732+ return method (* args , ** kwargs )
733+
693734 response = _retry (
694- method ,
735+ wrapped_method ,
695736 allowed_exceptions = {InternalServerError : _check_rst_stream_error },
696737 )
697738
@@ -786,15 +827,28 @@ def partition_query(
786827 trace_attributes ,
787828 observability_options = getattr (database , "observability_options" , None ),
788829 ):
789- method = functools .partial (
790- api .partition_query ,
791- request = request ,
792- metadata = metadata ,
793- retry = retry ,
794- timeout = timeout ,
795- )
830+ nth_request = getattr (database , "_next_nth_request" , 0 )
831+ attempt = AtomicCounter (0 )
832+
833+ def wrapped_method (* args , ** kwargs ):
834+ attempt .increment ()
835+ channel_id = getattr (self ._session , "_channel_id" , 0 )
836+ client_id = getattr (database , "_nth_client_id" , 0 )
837+ all_metadata = _metadata_with_request_id (
838+ client_id , channel_id , nth_request , attempt .value , metadata
839+ )
840+
841+ method = functools .partial (
842+ api .partition_query ,
843+ request = request ,
844+ metadata = all_metadata ,
845+ retry = retry ,
846+ timeout = timeout ,
847+ )
848+ return method (* args , ** kwargs )
849+
796850 response = _retry (
797- method ,
851+ wrapped_method ,
798852 allowed_exceptions = {InternalServerError : _check_rst_stream_error },
799853 )
800854
@@ -932,14 +986,27 @@ def begin(self):
932986 self ._session ,
933987 observability_options = getattr (database , "observability_options" , None ),
934988 ):
935- method = functools .partial (
936- api .begin_transaction ,
937- session = self ._session .name ,
938- options = txn_selector .begin ,
939- metadata = metadata ,
940- )
989+ nth_request = getattr (database , "_next_nth_request" , 0 )
990+ attempt = AtomicCounter (0 )
991+
992+ def wrapped_method (* args , ** kwargs ):
993+ attempt .increment ()
994+ channel_id = getattr (self ._session , "_channel_id" , 0 )
995+ client_id = getattr (database , "_nth_client_id" , 0 )
996+ all_metadata = _metadata_with_request_id (
997+ client_id , channel_id , nth_request , attempt .value , metadata
998+ )
999+
1000+ method = functools .partial (
1001+ api .begin_transaction ,
1002+ session = self ._session .name ,
1003+ options = txn_selector .begin ,
1004+ metadata = all_metadata ,
1005+ )
1006+ return method (* args , ** kwargs )
1007+
9411008 response = _retry (
942- method ,
1009+ wrapped_method ,
9431010 allowed_exceptions = {InternalServerError : _check_rst_stream_error },
9441011 )
9451012 self ._transaction_id = response .id
0 commit comments