@@ -737,27 +737,73 @@ class InterceptingHeaderInjector:
737737 def __init__ (self , original_callable : Callable ):
738738 self ._original_callable = original_callable
739739
740- def __call__ (self , * args , ** kwargs ):
741- metadata = kwargs .get ("metadata" , [])
742- # Find all the headers that match the x-goog-spanner-request-id
743- # header an on each retry increment the value.
744- all_metadata = []
745- for key , value in metadata :
746- if key is REQ_ID_HEADER_KEY :
747- # Otherwise now increment the count for the attempt number.
748- splits = value .split ("." )
749- attempt_plus_one = int (splits [- 1 ]) + 1
750- splits [- 1 ] = str (attempt_plus_one )
751- value_before = value
752- value = "." .join (splits )
753- print ("incrementing value on retry from" , value_before , "to" , value )
754-
755- all_metadata .append (
756- (
757- key ,
758- value ,
759- )
760- )
761740
762- kwargs ["metadata" ] = all_metadata
763- return self ._original_callable (* args , ** kwargs )
741+ patched = {}
742+
743+
744+ def inject_retry_header_control (api ):
745+ # For each method, add an _attempt value that'll then be
746+ # retrieved for each retry.
747+ # 1. Patch the __getattribute__ method to match items in our manifest.
748+ target = type (api )
749+ hex_id = hex (id (target ))
750+ if patched .get (hex_id , None ) is not None :
751+ return
752+
753+ orig_getattribute = getattr (target , "__getattribute__" )
754+
755+ def patched_getattribute (* args , ** kwargs ):
756+ attr = orig_getattribute (* args , ** kwargs )
757+
758+ # 0. If we already patched it, we can return immediately.
759+ if getattr (attr , "_patched" , None ) is not None :
760+ return attr
761+
762+ # 1. Skip over non-methods.
763+ if not callable (attr ):
764+ return attr
765+
766+ # 2. Skip modifying private and mangled methods.
767+ mangled_or_private = attr .__name__ .startswith ("_" )
768+ if mangled_or_private :
769+ return attr
770+
771+ print ("\033 [35mattr" , attr , "hex_id" , hex (id (attr )), "\033 [00m" )
772+
773+ # 3. Wrap the callable attribute and then capture its metadata keyed argument.
774+ def wrapped_attr (* args , ** kwargs ):
775+ metadata = kwargs .get ("metadata" , [])
776+ if not metadata :
777+ # Increment the reinvocation count.
778+ print ("not metatadata" , attr .__name__ )
779+ wrapped_attr ._attempt += 1
780+ return attr (* args , ** kwargs )
781+
782+ # 4. Find all the headers that match the target header key.
783+ all_metadata = []
784+ for key , value in metadata :
785+ if key is REQ_ID_HEADER_KEY :
786+ print ("key" , key , "value" , value , "attempt" , wrapped_attr ._attempt )
787+ # 5. Increment the original_attempt with that of our re-invocation count.
788+ splits = value .split ("." )
789+ hdr_attempt_plus_reinvocation = (
790+ int (splits [- 1 ]) + wrapped_attr ._attempt
791+ )
792+ splits [- 1 ] = str (hdr_attempt_plus_reinvocation )
793+ value = "." .join (splits )
794+
795+ all_metadata .append ((key , value ))
796+
797+ # Increment the reinvocation count.
798+ wrapped_attr ._attempt += 1
799+
800+ kwargs ["metadata" ] = all_metadata
801+ print ("\033 [34mwrap_callable" , hex (id (attr )), attr .__name__ , "\033 [00m" )
802+ return attr (* args , ** kwargs )
803+
804+ wrapped_attr ._attempt = 0
805+ wrapped_attr ._patched = True
806+ return wrapped_attr
807+
808+ setattr (target , "__getattribute__" , patched_getattribute )
809+ patched [hex_id ] = True
0 commit comments