Skip to content
This repository was archived by the owner on Mar 26, 2026. It is now read-only.

Commit fbaff28

Browse files
committed
feat: Add REST Interceptors to support reading metadata
1 parent 1fb1c76 commit fbaff28

File tree

13 files changed

+996
-0
lines changed

13 files changed

+996
-0
lines changed

gapic/templates/%namespace/%name_%version/%sub/services/%service/_shared_macros.j2

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,20 @@ class {{ async_method_name_prefix }}{{ service.name }}RestInterceptor:
389389
"""
390390
return response
391391

392+
{% if not method.server_streaming %}
393+
{{ async_prefix }}def post_{{ method.name|snake_case }}_with_metadata(self, response: {{method.output.ident}}, {{ client_method_metadata_argument() }}) -> Tuple[{{method.output.ident}}, {{ client_method_metadata_type() }}]:
394+
{% else %}
395+
{{ async_prefix }}def post_{{ method.name|snake_case }}_with_metadata(self, response: rest_streaming{{ async_suffix }}.{{ async_method_name_prefix }}ResponseIterator, {{ client_method_metadata_argument() }}) -> Tuple[rest_streaming{{ async_suffix }}.{{ async_method_name_prefix }}ResponseIterator, {{ client_method_metadata_type() }}]:
392396
{% endif %}
397+
"""Post-rpc interceptor for {{ method.name|snake_case }}
398+
399+
Override in a subclass to either manipulate or read, either the response
400+
or metadata after it is returned by the {{ service.name }} server but before
401+
it is returned to user code.
402+
"""
403+
return response, metadata
404+
405+
{% endif %}{# not method.void #}
393406
{% endfor %}
394407

395408
{% for name, signature in api.mixin_api_signatures.items() %}

gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,8 @@ class {{service.name}}RestTransport(_Base{{ service.name }}RestTransport):
256256
{% endif %}{# method.lro #}
257257
{#- TODO(https://github.com/googleapis/gapic-generator-python/issues/2274): Add debug log before intercepting a request #}
258258
resp = self._interceptor.post_{{ method.name|snake_case }}(resp)
259+
response_metadata = [(k, str(v)) for k, v in response.headers.items()]
260+
resp, _ = self._interceptor.post_{{ method.name|snake_case }}_with_metadata(resp, response_metadata)
259261
{# TODO(https://github.com/googleapis/gapic-generator-python/issues/2279): Add logging support for rest streaming. #}
260262
{% if not method.server_streaming %}
261263
if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor(logging.DEBUG): # pragma: NO COVER

gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest_asyncio.py.j2

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,8 @@ class Async{{service.name}}RestTransport(_Base{{ service.name }}RestTransport):
217217
json_format.Parse(content, pb_resp, ignore_unknown_fields=True)
218218
{% endif %}{# if method.server_streaming #}
219219
resp = await self._interceptor.post_{{ method.name|snake_case }}(resp)
220+
response_metadata = [(k, str(v)) for k, v in response.headers.items()]
221+
resp, _ = await self._interceptor.post_{{ method.name|snake_case }}_with_metadata(resp, response_metadata)
220222
{# TODO(https://github.com/googleapis/gapic-generator-python/issues/2279): Add logging support for rest streaming. #}
221223
{% if not method.server_streaming %}
222224
if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor(logging.DEBUG): # pragma: NO COVER

gapic/templates/tests/unit/gapic/%name_%version/%sub/test_macros.j2

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2218,11 +2218,13 @@ def test_initialize_client_w_{{transport_name}}():
22182218
{% endif %}
22192219
{% if not method.void %}
22202220
mock.patch.object(transports.{{async_method_prefix}}{{ service.name }}RestInterceptor, "post_{{method.name|snake_case}}") as post, \
2221+
mock.patch.object(transports.{{async_method_prefix}}{{ service.name }}RestInterceptor, "post_{{method.name|snake_case}}_with_metadata") as post_with_metadata, \
22212222
{% endif %}
22222223
mock.patch.object(transports.{{async_method_prefix}}{{ service.name }}RestInterceptor, "pre_{{ method.name|snake_case }}") as pre:
22232224
pre.assert_not_called()
22242225
{% if not method.void %}
22252226
post.assert_not_called()
2227+
post_with_metadata.assert_not_called()
22262228
{% endif %}
22272229
{% if method.input.ident.is_proto_plus_type %}
22282230
pb_message = {{ method.input.ident }}.pb({{ method.input.ident }}())
@@ -2265,13 +2267,15 @@ def test_initialize_client_w_{{transport_name}}():
22652267
pre.return_value = request, metadata
22662268
{% if not method.void %}
22672269
post.return_value = {{ method.output.ident }}()
2270+
post_with_metadata.return_value = {{ method.output.ident }}(), metadata
22682271
{% endif %}
22692272

22702273
{{await_prefix}}client.{{ method_name }}(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
22712274

22722275
pre.assert_called_once()
22732276
{% if not method.void %}
22742277
post.assert_called_once()
2278+
post_with_metadata.assert_called_once()
22752279
{% endif %}
22762280
{% endif %}{# end 'grpc' in transport #}
22772281
{% endmacro%}{# inteceptor_class_test #}

tests/integration/goldens/asset/google/cloud/asset_v1/services/asset_service/transports/rest.py

Lines changed: 231 additions & 0 deletions
Large diffs are not rendered by default.

tests/integration/goldens/asset/tests/unit/gapic/asset_v1/test_asset_service.py

Lines changed: 84 additions & 0 deletions
Large diffs are not rendered by default.

tests/integration/goldens/credentials/google/iam/credentials_v1/services/iam_credentials/transports/rest.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,15 @@ def post_generate_access_token(self, response: common.GenerateAccessTokenRespons
127127
"""
128128
return response
129129

130+
def post_generate_access_token_with_metadata(self, response: common.GenerateAccessTokenResponse, metadata: Sequence[Tuple[str, Union[str, bytes]]]) -> Tuple[common.GenerateAccessTokenResponse, Sequence[Tuple[str, Union[str, bytes]]]]:
131+
"""Post-rpc interceptor for generate_access_token
132+
133+
Override in a subclass to either manipulate or read, either the response
134+
or metadata after it is returned by the IAMCredentials server but before
135+
it is returned to user code.
136+
"""
137+
return response, metadata
138+
130139
def pre_generate_id_token(self, request: common.GenerateIdTokenRequest, metadata: Sequence[Tuple[str, Union[str, bytes]]]) -> Tuple[common.GenerateIdTokenRequest, Sequence[Tuple[str, Union[str, bytes]]]]:
131140
"""Pre-rpc interceptor for generate_id_token
132141
@@ -144,6 +153,15 @@ def post_generate_id_token(self, response: common.GenerateIdTokenResponse) -> co
144153
"""
145154
return response
146155

156+
def post_generate_id_token_with_metadata(self, response: common.GenerateIdTokenResponse, metadata: Sequence[Tuple[str, Union[str, bytes]]]) -> Tuple[common.GenerateIdTokenResponse, Sequence[Tuple[str, Union[str, bytes]]]]:
157+
"""Post-rpc interceptor for generate_id_token
158+
159+
Override in a subclass to either manipulate or read, either the response
160+
or metadata after it is returned by the IAMCredentials server but before
161+
it is returned to user code.
162+
"""
163+
return response, metadata
164+
147165
def pre_sign_blob(self, request: common.SignBlobRequest, metadata: Sequence[Tuple[str, Union[str, bytes]]]) -> Tuple[common.SignBlobRequest, Sequence[Tuple[str, Union[str, bytes]]]]:
148166
"""Pre-rpc interceptor for sign_blob
149167
@@ -161,6 +179,15 @@ def post_sign_blob(self, response: common.SignBlobResponse) -> common.SignBlobRe
161179
"""
162180
return response
163181

182+
def post_sign_blob_with_metadata(self, response: common.SignBlobResponse, metadata: Sequence[Tuple[str, Union[str, bytes]]]) -> Tuple[common.SignBlobResponse, Sequence[Tuple[str, Union[str, bytes]]]]:
183+
"""Post-rpc interceptor for sign_blob
184+
185+
Override in a subclass to either manipulate or read, either the response
186+
or metadata after it is returned by the IAMCredentials server but before
187+
it is returned to user code.
188+
"""
189+
return response, metadata
190+
164191
def pre_sign_jwt(self, request: common.SignJwtRequest, metadata: Sequence[Tuple[str, Union[str, bytes]]]) -> Tuple[common.SignJwtRequest, Sequence[Tuple[str, Union[str, bytes]]]]:
165192
"""Pre-rpc interceptor for sign_jwt
166193
@@ -178,6 +205,15 @@ def post_sign_jwt(self, response: common.SignJwtResponse) -> common.SignJwtRespo
178205
"""
179206
return response
180207

208+
def post_sign_jwt_with_metadata(self, response: common.SignJwtResponse, metadata: Sequence[Tuple[str, Union[str, bytes]]]) -> Tuple[common.SignJwtResponse, Sequence[Tuple[str, Union[str, bytes]]]]:
209+
"""Post-rpc interceptor for sign_jwt
210+
211+
Override in a subclass to either manipulate or read, either the response
212+
or metadata after it is returned by the IAMCredentials server but before
213+
it is returned to user code.
214+
"""
215+
return response, metadata
216+
181217

182218
@dataclasses.dataclass
183219
class IAMCredentialsRestStub:
@@ -375,6 +411,8 @@ def __call__(self,
375411
json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True)
376412

377413
resp = self._interceptor.post_generate_access_token(resp)
414+
response_metadata = [(k, str(v)) for k, v in response.headers.items()]
415+
resp, _ = self._interceptor.post_generate_access_token_with_metadata(resp, response_metadata)
378416
if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor(logging.DEBUG): # pragma: NO COVER
379417
try:
380418
response_payload = common.GenerateAccessTokenResponse.to_json(response)
@@ -495,6 +533,8 @@ def __call__(self,
495533
json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True)
496534

497535
resp = self._interceptor.post_generate_id_token(resp)
536+
response_metadata = [(k, str(v)) for k, v in response.headers.items()]
537+
resp, _ = self._interceptor.post_generate_id_token_with_metadata(resp, response_metadata)
498538
if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor(logging.DEBUG): # pragma: NO COVER
499539
try:
500540
response_payload = common.GenerateIdTokenResponse.to_json(response)
@@ -615,6 +655,8 @@ def __call__(self,
615655
json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True)
616656

617657
resp = self._interceptor.post_sign_blob(resp)
658+
response_metadata = [(k, str(v)) for k, v in response.headers.items()]
659+
resp, _ = self._interceptor.post_sign_blob_with_metadata(resp, response_metadata)
618660
if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor(logging.DEBUG): # pragma: NO COVER
619661
try:
620662
response_payload = common.SignBlobResponse.to_json(response)
@@ -735,6 +777,8 @@ def __call__(self,
735777
json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True)
736778

737779
resp = self._interceptor.post_sign_jwt(resp)
780+
response_metadata = [(k, str(v)) for k, v in response.headers.items()]
781+
resp, _ = self._interceptor.post_sign_jwt_with_metadata(resp, response_metadata)
738782
if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor(logging.DEBUG): # pragma: NO COVER
739783
try:
740784
response_payload = common.SignJwtResponse.to_json(response)

tests/integration/goldens/credentials/tests/unit/gapic/credentials_v1/test_iam_credentials.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3162,9 +3162,11 @@ def test_generate_access_token_rest_interceptors(null_interceptor):
31623162
with mock.patch.object(type(client.transport._session), "request") as req, \
31633163
mock.patch.object(path_template, "transcode") as transcode, \
31643164
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_generate_access_token") as post, \
3165+
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_generate_access_token_with_metadata") as post_with_metadata, \
31653166
mock.patch.object(transports.IAMCredentialsRestInterceptor, "pre_generate_access_token") as pre:
31663167
pre.assert_not_called()
31673168
post.assert_not_called()
3169+
post_with_metadata.assert_not_called()
31683170
pb_message = common.GenerateAccessTokenRequest.pb(common.GenerateAccessTokenRequest())
31693171
transcode.return_value = {
31703172
"method": "post",
@@ -3186,11 +3188,13 @@ def test_generate_access_token_rest_interceptors(null_interceptor):
31863188
]
31873189
pre.return_value = request, metadata
31883190
post.return_value = common.GenerateAccessTokenResponse()
3191+
post_with_metadata.return_value = common.GenerateAccessTokenResponse(), metadata
31893192

31903193
client.generate_access_token(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
31913194

31923195
pre.assert_called_once()
31933196
post.assert_called_once()
3197+
post_with_metadata.assert_called_once()
31943198

31953199

31963200
def test_generate_id_token_rest_bad_request(request_type=common.GenerateIdTokenRequest):
@@ -3264,9 +3268,11 @@ def test_generate_id_token_rest_interceptors(null_interceptor):
32643268
with mock.patch.object(type(client.transport._session), "request") as req, \
32653269
mock.patch.object(path_template, "transcode") as transcode, \
32663270
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_generate_id_token") as post, \
3271+
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_generate_id_token_with_metadata") as post_with_metadata, \
32673272
mock.patch.object(transports.IAMCredentialsRestInterceptor, "pre_generate_id_token") as pre:
32683273
pre.assert_not_called()
32693274
post.assert_not_called()
3275+
post_with_metadata.assert_not_called()
32703276
pb_message = common.GenerateIdTokenRequest.pb(common.GenerateIdTokenRequest())
32713277
transcode.return_value = {
32723278
"method": "post",
@@ -3288,11 +3294,13 @@ def test_generate_id_token_rest_interceptors(null_interceptor):
32883294
]
32893295
pre.return_value = request, metadata
32903296
post.return_value = common.GenerateIdTokenResponse()
3297+
post_with_metadata.return_value = common.GenerateIdTokenResponse(), metadata
32913298

32923299
client.generate_id_token(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
32933300

32943301
pre.assert_called_once()
32953302
post.assert_called_once()
3303+
post_with_metadata.assert_called_once()
32963304

32973305

32983306
def test_sign_blob_rest_bad_request(request_type=common.SignBlobRequest):
@@ -3368,9 +3376,11 @@ def test_sign_blob_rest_interceptors(null_interceptor):
33683376
with mock.patch.object(type(client.transport._session), "request") as req, \
33693377
mock.patch.object(path_template, "transcode") as transcode, \
33703378
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_sign_blob") as post, \
3379+
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_sign_blob_with_metadata") as post_with_metadata, \
33713380
mock.patch.object(transports.IAMCredentialsRestInterceptor, "pre_sign_blob") as pre:
33723381
pre.assert_not_called()
33733382
post.assert_not_called()
3383+
post_with_metadata.assert_not_called()
33743384
pb_message = common.SignBlobRequest.pb(common.SignBlobRequest())
33753385
transcode.return_value = {
33763386
"method": "post",
@@ -3392,11 +3402,13 @@ def test_sign_blob_rest_interceptors(null_interceptor):
33923402
]
33933403
pre.return_value = request, metadata
33943404
post.return_value = common.SignBlobResponse()
3405+
post_with_metadata.return_value = common.SignBlobResponse(), metadata
33953406

33963407
client.sign_blob(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
33973408

33983409
pre.assert_called_once()
33993410
post.assert_called_once()
3411+
post_with_metadata.assert_called_once()
34003412

34013413

34023414
def test_sign_jwt_rest_bad_request(request_type=common.SignJwtRequest):
@@ -3472,9 +3484,11 @@ def test_sign_jwt_rest_interceptors(null_interceptor):
34723484
with mock.patch.object(type(client.transport._session), "request") as req, \
34733485
mock.patch.object(path_template, "transcode") as transcode, \
34743486
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_sign_jwt") as post, \
3487+
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_sign_jwt_with_metadata") as post_with_metadata, \
34753488
mock.patch.object(transports.IAMCredentialsRestInterceptor, "pre_sign_jwt") as pre:
34763489
pre.assert_not_called()
34773490
post.assert_not_called()
3491+
post_with_metadata.assert_not_called()
34783492
pb_message = common.SignJwtRequest.pb(common.SignJwtRequest())
34793493
transcode.return_value = {
34803494
"method": "post",
@@ -3496,11 +3510,13 @@ def test_sign_jwt_rest_interceptors(null_interceptor):
34963510
]
34973511
pre.return_value = request, metadata
34983512
post.return_value = common.SignJwtResponse()
3513+
post_with_metadata.return_value = common.SignJwtResponse(), metadata
34993514

35003515
client.sign_jwt(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
35013516

35023517
pre.assert_called_once()
35033518
post.assert_called_once()
3519+
post_with_metadata.assert_called_once()
35043520

35053521
def test_initialize_client_w_rest():
35063522
client = IAMCredentialsClient(

0 commit comments

Comments
 (0)