Skip to content

Commit d62e8c6

Browse files
Added JWT obtaining logic
1 parent f58956a commit d62e8c6

2 files changed

Lines changed: 87 additions & 46 deletions

File tree

Lines changed: 85 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,93 @@
1-
import datetime
2-
import decimal
3-
import re
4-
from enum import Enum
1+
import json
2+
import logging
53

4+
from conductor.asyncio_client.adapters.models import GenerateTokenRequest
5+
from conductor.asyncio_client.http import rest
66
from conductor.asyncio_client.http.api_client import ApiClient
7+
from conductor.asyncio_client.http.exceptions import ApiException
8+
9+
logger = logging.getLogger(__name__)
710

811

912
class ApiClientAdapter(ApiClient):
10-
def __deserialize(self, data, klass):
11-
"""Deserializes dict, list, str into an object.
13+
async def call_api(
14+
self,
15+
method,
16+
url,
17+
header_params=None,
18+
body=None,
19+
post_params=None,
20+
_request_timeout=None,
21+
) -> rest.RESTResponse:
22+
"""Makes the HTTP request (synchronous)
23+
:param method: Method to call.
24+
:param url: Path to method endpoint.
25+
:param header_params: Header parameters to be
26+
placed in the request header.
27+
:param body: Request body.
28+
:param post_params dict: Request post form parameters,
29+
for `application/x-www-form-urlencoded`, `multipart/form-data`.
30+
:param _request_timeout: timeout setting for this request.
31+
:return: RESTResponse
32+
"""
33+
34+
try:
35+
response_data = await self.rest_client.request(
36+
method,
37+
url,
38+
headers=header_params,
39+
body=body,
40+
post_params=post_params,
41+
_request_timeout=_request_timeout,
42+
)
43+
if response_data.status == 401:
44+
token = await self.refresh_authorization_token()
45+
header_params["X-Authorization"] = token
46+
response_data = await self.rest_client.request(
47+
method,
48+
url,
49+
headers=header_params,
50+
body=body,
51+
post_params=post_params,
52+
_request_timeout=_request_timeout,
53+
)
54+
except ApiException as e:
55+
raise e
56+
57+
return response_data
58+
59+
async def refresh_authorization_token(self):
60+
obtain_new_token_response = await self.obtain_new_token()
61+
token = obtain_new_token_response.get("token")
62+
self.configuration.api_key["api_key"] = token
63+
return token
64+
65+
async def obtain_new_token(self):
66+
body = GenerateTokenRequest(
67+
key_id=self.configuration.auth_key,
68+
key_secret=self.configuration.auth_secret,
69+
)
70+
_param = self.param_serialize(
71+
method="POST",
72+
resource_path="/token",
73+
body=body.to_dict(),
74+
)
75+
response = await self.call_api(
76+
*_param,
77+
)
78+
await response.read()
79+
return json.loads(response.data)
80+
81+
@classmethod
82+
def get_default(cls):
83+
"""Return new instance of ApiClient.
1284
13-
:param data: dict, list or str.
14-
:param klass: class literal, or string of class name.
85+
This method returns newly created, based on default constructor,
86+
object of ApiClient class or returns a copy of default
87+
ApiClient.
1588
16-
:return: object.
89+
:return: The ApiClient object.
1790
"""
18-
if data is None:
19-
return None
20-
21-
if isinstance(klass, str):
22-
if klass.startswith("List["):
23-
m = re.match(r"List\[(.*)]", klass)
24-
assert m is not None, "Malformed List type definition"
25-
sub_kls = m.group(1)
26-
return [self.__deserialize(sub_data, sub_kls) for sub_data in data]
27-
28-
if klass.startswith("Dict["):
29-
m = re.match(r"Dict\[([^,]*), (.*)]", klass)
30-
assert m is not None, "Malformed Dict type definition"
31-
sub_kls = m.group(2)
32-
return {k: self.__deserialize(v, sub_kls) for k, v in data.items()}
33-
34-
# convert str to class
35-
if klass in self.NATIVE_TYPES_MAPPING:
36-
klass = self.NATIVE_TYPES_MAPPING[klass]
37-
else:
38-
klass = getattr(conductor.asyncio_client.adapters.models, klass)
39-
40-
if klass in self.PRIMITIVE_TYPES:
41-
return self.__deserialize_primitive(data, klass)
42-
elif klass == object:
43-
return self.__deserialize_object(data)
44-
elif klass == datetime.date:
45-
return self.__deserialize_date(data)
46-
elif klass == datetime.datetime:
47-
return self.__deserialize_datetime(data)
48-
elif klass == decimal.Decimal:
49-
return decimal.Decimal(data)
50-
elif issubclass(klass, Enum):
51-
return self.__deserialize_enum(data, klass)
52-
else:
53-
return self.__deserialize_model(data, klass)
91+
if cls._default is None:
92+
cls._default = ApiClientAdapter()
93+
return cls._default

src/conductor/asyncio_client/http/api_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,8 @@ def __deserialize(self, data, klass):
456456
if klass in self.NATIVE_TYPES_MAPPING:
457457
klass = self.NATIVE_TYPES_MAPPING[klass]
458458
else:
459-
klass = getattr(conductor.asyncio_client.http.models, klass)
459+
# Looking for our adapters instead of autogenerated models
460+
klass = getattr(conductor.asyncio_client.adapters.models, klass)
460461

461462
if klass in self.PRIMITIVE_TYPES:
462463
return self.__deserialize_primitive(data, klass)

0 commit comments

Comments
 (0)