Skip to content
This repository was archived by the owner on Mar 26, 2026. It is now read-only.

Commit 79ece42

Browse files
committed
wip
1 parent 4fe6a57 commit 79ece42

1 file changed

Lines changed: 69 additions & 14 deletions

File tree

tests/system/conftest.py

Lines changed: 69 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,37 @@ def __init__(self, key, value):
311311
self.request_metadata = []
312312
self.response_metadata = []
313313

314+
class _StreamWrapper:
315+
def __init__(self, stream, interceptor):
316+
self._stream = stream
317+
self._interceptor = interceptor
318+
319+
def __iter__(self):
320+
return self._stream.__iter__()
321+
322+
def __next__(self):
323+
return self._stream.__next__()
324+
325+
def cancel(self):
326+
return self._stream.cancel()
327+
328+
def is_active(self):
329+
return self._stream.is_active()
330+
331+
def time_remaining(self):
332+
return self._stream.time_remaining()
333+
334+
def trailing_metadata(self):
335+
metadata = self._stream.trailing_metadata()
336+
self._interceptor.response_metadata = metadata
337+
return metadata
338+
339+
def initial_metadata(self):
340+
return self._stream.initial_metadata()
341+
342+
def add_callback(self, callback):
343+
return self._stream.add_callback(callback)
344+
314345
def _add_request_metadata(self, client_call_details):
315346
if client_call_details.metadata is not None:
316347
# https://grpc.github.io/grpc/python/glossary.html#term-metadata.
@@ -320,11 +351,6 @@ def _add_request_metadata(self, client_call_details):
320351
client_call_details.metadata.append((self._key, self._value))
321352
self.request_metadata = client_call_details.metadata
322353

323-
def _read_response_metadata_stream(self):
324-
# Access the metadata via the original stream object
325-
if hasattr(self, "_original_stream"):
326-
self.response_metadata = self._original_stream.trailing_metadata()
327-
328354
def intercept_unary_unary(self, continuation, client_call_details, request):
329355
self._add_request_metadata(client_call_details)
330356
response = continuation(client_call_details, request)
@@ -335,8 +361,7 @@ def intercept_unary_unary(self, continuation, client_call_details, request):
335361
def intercept_unary_stream(self, continuation, client_call_details, request):
336362
self._add_request_metadata(client_call_details)
337363
response_it = continuation(client_call_details, request)
338-
self._original_stream = response_it
339-
return response_it
364+
return self._StreamWrapper(response_it, self)
340365

341366
def intercept_stream_unary(
342367
self, continuation, client_call_details, request_iterator
@@ -350,8 +375,7 @@ def intercept_stream_stream(
350375
):
351376
self._add_request_metadata(client_call_details)
352377
response_it = continuation(client_call_details, request_iterator)
353-
self._original_stream = response_it
354-
return response_it
378+
return self._StreamWrapper(response_it, self)
355379

356380

357381
class EchoMetadataClientGrpcAsyncInterceptor(
@@ -366,6 +390,37 @@ def __init__(self, key, value):
366390
self.request_metadata = []
367391
self.response_metadata = []
368392

393+
class _AsyncStreamWrapper:
394+
def __init__(self, stream, interceptor):
395+
self._stream = stream
396+
self._interceptor = interceptor
397+
398+
def __aiter__(self):
399+
return self._stream.__aiter__()
400+
401+
def __anext__(self):
402+
return self._stream.__anext__()
403+
404+
def cancel(self):
405+
return self._stream.cancel()
406+
407+
def done(self):
408+
return self._stream.done()
409+
410+
def add_done_callback(self, callback):
411+
return self._stream.add_done_callback(callback)
412+
413+
async def initial_metadata(self):
414+
return await self._stream.initial_metadata()
415+
416+
async def trailing_metadata(self):
417+
metadata = await self._stream.trailing_metadata()
418+
self._interceptor.response_metadata = metadata
419+
return metadata
420+
421+
async def debug_string(self):
422+
return await self._stream.debug_string()
423+
369424
async def _add_request_metadata(self, client_call_details):
370425
if client_call_details.metadata is not None:
371426
# As of gRPC 1.75.0 and newer,
@@ -393,22 +448,22 @@ async def intercept_unary_unary(self, continuation, client_call_details, request
393448

394449
async def intercept_unary_stream(self, continuation, client_call_details, request):
395450
self._add_request_metadata(client_call_details)
396-
response_it = continuation(client_call_details, request)
397-
return response_it
451+
response_it = await continuation(client_call_details, request)
452+
return self._AsyncStreamWrapper(response_it, self)
398453

399454
async def intercept_stream_unary(
400455
self, continuation, client_call_details, request_iterator
401456
):
402457
self._add_request_metadata(client_call_details)
403-
response = continuation(client_call_details, request_iterator)
458+
response = await continuation(client_call_details, request_iterator)
404459
return response
405460

406461
async def intercept_stream_stream(
407462
self, continuation, client_call_details, request_iterator
408463
):
409464
self._add_request_metadata(client_call_details)
410-
response_it = continuation(client_call_details, request_iterator)
411-
return response_it
465+
response_it = await continuation(client_call_details, request_iterator)
466+
return self._AsyncStreamWrapper(response_it, self)
412467

413468

414469
@pytest.fixture

0 commit comments

Comments
 (0)