Skip to content

Commit f6e6114

Browse files
authored
Merge pull request #66 from equinor/new-auth-class
new auth class for SumoClient
2 parents 833d70b + 0d19ef5 commit f6e6114

4 files changed

Lines changed: 150 additions & 45 deletions

File tree

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
msal~=1.5.1
1+
msal~=1.17.0
22
requests~=2.24.0
33
pytest~=6.1.1
44
PyYAML>=5.4

src/sumo/wrapper/_new_auth.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import atexit
2+
import msal
3+
import os
4+
import sys
5+
import json
6+
from .config import AUTHORITY_HOST_URI
7+
8+
HOME_DIR = os.path.expanduser("~")
9+
10+
class NewAuth:
11+
def __init__(
12+
self,
13+
client_id,
14+
resource_id,
15+
tenant_id,
16+
interactive=False,
17+
refresh_token=None
18+
):
19+
self.interactive = interactive
20+
self.scope = resource_id + "/.default"
21+
self.refresh_token = refresh_token
22+
23+
self.token_path = os.path.join(
24+
HOME_DIR, ".sumo", str(resource_id) + ".token"
25+
)
26+
27+
self.cache = None
28+
29+
if not self.refresh_token:
30+
self.cache = self.__load_cache()
31+
atexit.register(self.__save_cache)
32+
33+
self.msal = msal.PublicClientApplication(
34+
client_id=client_id,
35+
authority=f"{AUTHORITY_HOST_URI}/{tenant_id}",
36+
token_cache=self.cache
37+
)
38+
39+
40+
def get_token(self):
41+
accounts = self.msal.get_accounts()
42+
result = None
43+
44+
if accounts:
45+
result = self.msal.acquire_token_silent([self.scope], account=accounts[0])
46+
47+
if not result:
48+
if self.refresh_token:
49+
result = self.msal.acquire_token_by_refresh_token(self.refresh_token, [self.scope])
50+
51+
if "error" in result:
52+
raise ValueError(
53+
"Failed to acquire token by refresh token. Err: %s" % json.dumps(result, indent=4)
54+
)
55+
else:
56+
if self.interactive:
57+
result = self.msal.acquire_token_interactive([self.scope])
58+
59+
if "error" in result:
60+
raise ValueError(
61+
"Failed to acquire token interactively. Err: %s" % json.dumps(result, indent=4)
62+
)
63+
else:
64+
flow = self.msal.initiate_device_flow([self.scope])
65+
66+
if "error" in flow:
67+
raise ValueError(
68+
"Failed to create device flow. Err: %s" % json.dumps(flow, indent=4)
69+
)
70+
71+
print(flow["message"])
72+
result = self.msal.acquire_token_by_device_flow(flow)
73+
74+
if "error" in result:
75+
raise ValueError(
76+
"Failed to acquire token by device flow. Err: %s" % json.dumps(result, indent=4)
77+
)
78+
79+
return result["access_token"]
80+
81+
82+
def __load_cache(self):
83+
cache = msal.SerializableTokenCache()
84+
85+
if os.path.isfile(self.token_path):
86+
with open(self.token_path, "r") as file:
87+
cache.deserialize(file.read())
88+
89+
return cache
90+
91+
92+
def __save_cache(self):
93+
if self.cache.has_state_changed:
94+
old_mask = os.umask(0o077)
95+
96+
dir_path = os.path.dirname(self.token_path)
97+
os.makedirs(dir_path, exist_ok=True)
98+
99+
with open(self.token_path, "w") as file:
100+
file.write(self.cache.serialize())
101+
102+
if not sys.platform.lower().startswith("win"):
103+
os.chmod(self.token_path, 0o600)
104+
os.chmod(dir_path, 0o700)
105+
106+
os.umask(old_mask)

src/sumo/wrapper/_sumo_client.py

Lines changed: 42 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,65 @@
11
import requests
2-
import logging
2+
import jwt
3+
import time
34

4-
from .config import APP_REGISTRATION, TENANT_ID, AUTHORITY_HOST_URI
5-
from ._auth import Auth
5+
from .config import APP_REGISTRATION, TENANT_ID
6+
from ._new_auth import NewAuth
67
from ._request_error import AuthenticationError, TransientError, PermanentError
78

8-
logging.basicConfig(
9-
format="%(asctime)s %(levelname)-8s %(message)s",
10-
datefmt="%Y-%m-%d %H:%M:%S"
11-
)
12-
13-
149
class SumoClient:
1510
def __init__(
1611
self,
1712
env,
18-
access_token=None,
19-
logging_level='INFO',
20-
write_back=False
13+
token=None,
14+
interactive=False
2115
):
22-
self.env = env
23-
self.user_provided_access_token = access_token
24-
25-
self.logger = logging.getLogger(__name__)
26-
self.logger.setLevel(level=logging_level)
27-
2816
if env not in APP_REGISTRATION:
2917
raise ValueError(f"Invalid environment: {env}")
3018

31-
self.client_id = APP_REGISTRATION[env]['CLIENT_ID']
32-
self.resource_id = APP_REGISTRATION[env]['RESOURCE_ID']
33-
self.authority_uri = AUTHORITY_HOST_URI + '/' + TENANT_ID
34-
35-
if not self.user_provided_access_token:
36-
self.auth = Auth(
37-
client_id=self.client_id,
38-
resource_id=self.resource_id,
39-
authority=self.authority_uri,
40-
writeback=write_back
41-
)
42-
43-
self.access_token = self.auth.get_token()
19+
self.access_token = None
20+
self.access_token_expires = None
21+
self.refresh_token = None
22+
23+
if token:
24+
payload = self.__decode_token(token)
25+
26+
if payload:
27+
self.access_token = token
28+
self.access_token_expires = payload["exp"]
29+
else:
30+
self.refresh_token = token
31+
32+
self.auth = NewAuth(
33+
client_id=APP_REGISTRATION[env]['CLIENT_ID'],
34+
resource_id=APP_REGISTRATION[env]['RESOURCE_ID'],
35+
tenant_id=TENANT_ID,
36+
interactive=interactive,
37+
refresh_token=self.refresh_token
38+
)
4439

4540
if env == "localhost":
4641
self.base_url = f"http://localhost:8084/api/v1"
4742
else:
4843
self.base_url = f"https://main-sumo-{env}.radix.equinor.com/api/v1"
4944

45+
46+
def __decode_token(self, token):
47+
try:
48+
payload = jwt.decode(token, options={"verify_signature": False})
49+
return payload
50+
except:
51+
return None
52+
53+
5054
def _retrieve_token(self):
51-
if self.user_provided_access_token:
52-
self.logger.debug("User provided token exists, returning token")
53-
return self.user_provided_access_token
54-
else:
55-
if self.auth.is_token_expired():
56-
self.logger.debug("Token is expired, regenerating")
57-
self.access_token = self.auth.get_token()
55+
if self.access_token:
56+
if self.access_token_expires <= int(time.time()):
57+
raise ValueError("Access_token has expired")
58+
else:
59+
return self.access_token
60+
61+
return self.auth.get_token()
5862

59-
self.logger.debug("returning self.access_token from _retrieve_token")
60-
return self.access_token
6163

6264
def _process_params(self, params_dict):
6365
prefixed_params = {}
@@ -174,9 +176,6 @@ def _raise_request_error_exception(self, code, message):
174176
Raise the proper authentication error according to the code received from sumo.
175177
"""
176178

177-
self.logger.debug("code: %s", code)
178-
self.logger.debug("message: %s", message)
179-
180179
if 503 <= code <= 504 or code == 404 or code == 500:
181180
raise TransientError(code, message)
182181
elif 401 <= code <= 403:

tests/test_sumo_thin_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
class Connection:
1515
def __init__(self):
16-
self.api = SumoClient(env="dev", logging_level="DEBUG")
16+
self.api = SumoClient(env="dev")
1717

1818

1919
def _upload_parent_object(C, json):

0 commit comments

Comments
 (0)