-
Notifications
You must be signed in to change notification settings - Fork 29
Expand file tree
/
Copy pathdurabletask_grpc_interceptor.py
More file actions
128 lines (110 loc) · 5.58 KB
/
Copy pathdurabletask_grpc_interceptor.py
File metadata and controls
128 lines (110 loc) · 5.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from importlib.metadata import version
import grpc
from azure.core.credentials import TokenCredential
from azure.core.credentials_async import AsyncTokenCredential
from durabletask.azuremanaged.internal.access_token_manager import (
AccessTokenManager,
AsyncAccessTokenManager,
)
from durabletask.internal.grpc_interceptor import (
DefaultAsyncClientInterceptorImpl,
DefaultClientInterceptorImpl,
_AsyncClientCallDetails,
_ClientCallDetails,
)
class DTSDefaultClientInterceptorImpl(DefaultClientInterceptorImpl):
    """gRPC client interceptor that attaches Durable Task Scheduler headers.

    Implements the UnaryUnary, UnaryStream, StreamUnary and StreamStream client
    interceptor interfaces (through ``DefaultClientInterceptorImpl``) and adds
    the following metadata to outgoing calls:

    * ``taskhub``       - the task hub name,
    * ``x-user-agent``  - ``durabletask-python/<package version>``,
    * ``workerid``      - only when a worker id is supplied,
    * ``authorization`` - ``Bearer <token>`` when a credential is supplied and
      a token was obtained.
    """

    def __init__(
            self,
            token_credential: TokenCredential | None,
            taskhub_name: str,
            worker_id: str | None = None):
        """Build the header list and, when a credential is given, fetch a token.

        Args:
            token_credential: Credential used to obtain bearer tokens, or
                ``None`` to send requests without an ``authorization`` header.
            taskhub_name: Value sent in the ``taskhub`` header.
            worker_id: Optional value sent in the ``workerid`` header.
        """
        try:
            package_version = version('durabletask-azuremanaged')
        except Exception:
            # Best effort: failing to read package metadata must not break
            # client construction, so fall back to a placeholder.
            package_version = "unknown"

        # 'user-agent' is a reserved header; 'x-user-agent' is used instead.
        headers = [
            ("taskhub", taskhub_name),
            ("x-user-agent", f"durabletask-python/{package_version}"),
        ]
        if worker_id is not None:
            headers.append(("workerid", worker_id))

        # The same list object is handed to the base class, so later in-place
        # updates (e.g. refreshed tokens) are visible through self._metadata.
        self._metadata = headers
        super().__init__(headers)

        self._token_manager = None
        if token_credential is None:
            return

        self._token_credential = token_credential
        self._token_manager = AccessTokenManager(token_credential=token_credential)
        initial_token = self._token_manager.get_access_token()
        if initial_token is not None:
            self._upsert_authorization_header(initial_token.token)

    def _upsert_authorization_header(self, token: str) -> None:
        """Set the ``authorization`` header to ``Bearer <token>``.

        The first header whose key equals ``authorization`` (case-insensitive)
        is replaced in place; if none exists, a new header is appended.
        """
        bearer_header = ("authorization", f"Bearer {token}")
        for position, (header_name, _) in enumerate(self._metadata):
            if header_name.lower() == "authorization":
                self._metadata[position] = bearer_header
                break
        else:
            self._metadata.append(bearer_header)

    def _intercept_call(
            self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails:
        """Refresh the bearer token if a credential was configured, then let the
        base implementation merge the stored headers into the call details."""
        manager = self._token_manager
        if manager is not None:
            # get_access_token() is generally cheap: it checks expiry and
            # returns the cached token without a network call while valid.
            if (current_token := manager.get_access_token()) is not None:
                self._upsert_authorization_header(current_token.token)
        return super()._intercept_call(client_call_details)
class DTSAsyncDefaultClientInterceptorImpl(DefaultAsyncClientInterceptorImpl):
    """Async version of DTSDefaultClientInterceptorImpl for use with grpc.aio channels.

    Adds DTS-specific metadata to all async calls:

    * ``taskhub``       - the task hub name,
    * ``x-user-agent``  - ``durabletask-python/<package version>``,
    * ``workerid``      - only when a worker id is supplied,
    * ``authorization`` - ``Bearer <token>`` once a token has been obtained.
    """

    def __init__(
            self,
            token_credential: AsyncTokenCredential | None,
            taskhub_name: str,
            worker_id: str | None = None):
        """Build the static header list and prepare (but not yet run) token acquisition.

        Args:
            token_credential: Async credential used to obtain bearer tokens, or
                ``None`` to send requests without an ``authorization`` header.
            taskhub_name: Value sent in the ``taskhub`` header.
            worker_id: Optional value sent in the ``workerid`` header. Mirrors
                the sync interceptor's parameter; defaults to ``None`` so
                existing callers are unaffected.
        """
        try:
            # Get the version of the azuremanaged package
            sdk_version = version('durabletask-azuremanaged')
        except Exception:
            # Fallback if version cannot be determined (best effort; must not
            # break client construction).
            sdk_version = "unknown"
        user_agent = f"durabletask-python/{sdk_version}"
        self._metadata = [
            ("taskhub", taskhub_name),
            ("x-user-agent", user_agent)]  # 'user-agent' is a reserved header; use 'x-user-agent'
        if worker_id is not None:
            self._metadata.append(("workerid", worker_id))
        # The same list object is handed to the base class, so in-place token
        # updates below are visible through self._metadata.
        super().__init__(self._metadata)

        # Token acquisition is deferred to the first _intercept_call invocation
        # rather than happening in __init__, because get_token() on an
        # AsyncTokenCredential is async and cannot be awaited in a constructor.
        self._token_manager = None
        if token_credential is not None:
            self._token_credential = token_credential
            self._token_manager = AsyncAccessTokenManager(token_credential=self._token_credential)

    def _upsert_authorization_header(self, token: str) -> None:
        """Set the ``authorization`` header to ``Bearer <token>``.

        Replaces the first header whose key equals ``authorization``
        (case-insensitive) in place; appends a new header if none exists.
        """
        header = ("authorization", f"Bearer {token}")
        for i, (key, _) in enumerate(self._metadata):
            if key.lower() == "authorization":
                self._metadata[i] = header
                break
        else:
            self._metadata.append(header)

    async def _intercept_call(
            self, client_call_details: _AsyncClientCallDetails) -> grpc.aio.ClientCallDetails:
        """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
        call details.

        If a credential was configured, the token is refreshed first. Header
        merging itself is delegated to the base implementation.
        """
        # Refresh the auth token if a credential was provided. The call to
        # get_access_token() is generally cheap, checking the expiry time and returning
        # the cached value without a network call when still valid.
        if self._token_manager is not None:
            access_token = await self._token_manager.get_access_token()
            if access_token is not None:
                self._upsert_authorization_header(access_token.token)
        return await super()._intercept_call(client_call_details)