@@ -311,6 +311,37 @@ def __init__(self, key, value):
311311 self .request_metadata = []
312312 self .response_metadata = []
313313
314+ class _StreamWrapper :
315+ def __init__ (self , stream , interceptor ):
316+ self ._stream = stream
317+ self ._interceptor = interceptor
318+
319+ def __iter__ (self ):
320+ return self ._stream .__iter__ ()
321+
322+ def __next__ (self ):
323+ return self ._stream .__next__ ()
324+
325+ def cancel (self ):
326+ return self ._stream .cancel ()
327+
328+ def is_active (self ):
329+ return self ._stream .is_active ()
330+
331+ def time_remaining (self ):
332+ return self ._stream .time_remaining ()
333+
334+ def trailing_metadata (self ):
335+ metadata = self ._stream .trailing_metadata ()
336+ self ._interceptor .response_metadata = metadata
337+ return metadata
338+
339+ def initial_metadata (self ):
340+ return self ._stream .initial_metadata ()
341+
342+ def add_callback (self , callback ):
343+ return self ._stream .add_callback (callback )
344+
314345 def _add_request_metadata (self , client_call_details ):
315346 if client_call_details .metadata is not None :
316347 # https://grpc.github.io/grpc/python/glossary.html#term-metadata.
@@ -320,11 +351,6 @@ def _add_request_metadata(self, client_call_details):
320351 client_call_details .metadata .append ((self ._key , self ._value ))
321352 self .request_metadata = client_call_details .metadata
322353
323- def _read_response_metadata_stream (self ):
324- # Access the metadata via the original stream object
325- if hasattr (self , "_original_stream" ):
326- self .response_metadata = self ._original_stream .trailing_metadata ()
327-
328354 def intercept_unary_unary (self , continuation , client_call_details , request ):
329355 self ._add_request_metadata (client_call_details )
330356 response = continuation (client_call_details , request )
@@ -335,8 +361,7 @@ def intercept_unary_unary(self, continuation, client_call_details, request):
335361 def intercept_unary_stream (self , continuation , client_call_details , request ):
336362 self ._add_request_metadata (client_call_details )
337363 response_it = continuation (client_call_details , request )
338- self ._original_stream = response_it
339- return response_it
364+ return self ._StreamWrapper (response_it , self )
340365
341366 def intercept_stream_unary (
342367 self , continuation , client_call_details , request_iterator
@@ -350,8 +375,7 @@ def intercept_stream_stream(
350375 ):
351376 self ._add_request_metadata (client_call_details )
352377 response_it = continuation (client_call_details , request_iterator )
353- self ._original_stream = response_it
354- return response_it
378+ return self ._StreamWrapper (response_it , self )
355379
356380
357381class EchoMetadataClientGrpcAsyncInterceptor (
@@ -366,6 +390,37 @@ def __init__(self, key, value):
366390 self .request_metadata = []
367391 self .response_metadata = []
368392
393+ class _AsyncStreamWrapper :
394+ def __init__ (self , stream , interceptor ):
395+ self ._stream = stream
396+ self ._interceptor = interceptor
397+
398+ def __aiter__ (self ):
399+ return self ._stream .__aiter__ ()
400+
401+ def __anext__ (self ):
402+ return self ._stream .__anext__ ()
403+
404+ def cancel (self ):
405+ return self ._stream .cancel ()
406+
407+ def done (self ):
408+ return self ._stream .done ()
409+
410+ def add_done_callback (self , callback ):
411+ return self ._stream .add_done_callback (callback )
412+
413+ async def initial_metadata (self ):
414+ return await self ._stream .initial_metadata ()
415+
416+ async def trailing_metadata (self ):
417+ metadata = await self ._stream .trailing_metadata ()
418+ self ._interceptor .response_metadata = metadata
419+ return metadata
420+
421+ async def debug_string (self ):
422+ return await self ._stream .debug_string ()
423+
369424 async def _add_request_metadata (self , client_call_details ):
370425 if client_call_details .metadata is not None :
371426 # As of gRPC 1.75.0 and newer,
@@ -393,22 +448,22 @@ async def intercept_unary_unary(self, continuation, client_call_details, request
393448
394449 async def intercept_unary_stream (self , continuation , client_call_details , request ):
395450 self ._add_request_metadata (client_call_details )
396- response_it = continuation (client_call_details , request )
397- return response_it
451+ response_it = await continuation (client_call_details , request )
452+ return self . _AsyncStreamWrapper ( response_it , self )
398453
399454 async def intercept_stream_unary (
400455 self , continuation , client_call_details , request_iterator
401456 ):
402457 self ._add_request_metadata (client_call_details )
403- response = continuation (client_call_details , request_iterator )
458+ response = await continuation (client_call_details , request_iterator )
404459 return response
405460
406461 async def intercept_stream_stream (
407462 self , continuation , client_call_details , request_iterator
408463 ):
409464 self ._add_request_metadata (client_call_details )
410- response_it = continuation (client_call_details , request_iterator )
411- return response_it
465+ response_it = await continuation (client_call_details , request_iterator )
466+ return self . _AsyncStreamWrapper ( response_it , self )
412467
413468
414469@pytest .fixture
0 commit comments