Skip to content

Commit ec910b8

Browse files
authored
Auth refactor (#152)
* Authorization extended and refactored. * Ignore emacs backup files. * Renamed file decorators.py to _decorators.py --------- Co-authored-by: Raymond Wiker <rayw@equinor.com>
1 parent 585285c commit ec910b8

7 files changed

Lines changed: 272 additions & 227 deletions

File tree

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
__pycache__/
22
*.py[cod]
33
.DS_Store
4+
*~
45
_venv
56
venv
67
dist
@@ -11,4 +12,4 @@ build
1112
/src/testing.py
1213
/docs/_build
1314
/docs/_static
14-
/src/sumo/wrapper/version.py
15+
/src/sumo/wrapper/version.py

requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ setuptools>=49.2.1
55
pyjwt>=2.4.0
66
httpx>=0.25.0
77
tenacity>=8.2.3
8+
azure-identity>=1.14.0

src/sumo/wrapper/_auth_provider.py

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
import msal
2+
import os
3+
import stat
4+
import sys
5+
import json
6+
import jwt
7+
import time
8+
from azure.identity import ManagedIdentityCredential
9+
10+
from msal_extensions.persistence import FilePersistence
11+
from msal_extensions.token_cache import PersistedTokenCache
12+
13+
if not sys.platform.startswith("linux"):
14+
from msal_extensions import build_encrypted_persistence
15+
16+
17+
def scope_for_resource(resource_id):
18+
return f"{resource_id}/.default offline_access"
19+
20+
21+
class AuthProvider:
22+
def __init__(self, resource_id):
23+
self._scope = scope_for_resource(resource_id)
24+
self._app = None
25+
return
26+
27+
def get_token(self):
28+
accounts = self._app.get_accounts()
29+
result = self._app.acquire_token_silent([self._scope], accounts[0])
30+
if "error" in result:
31+
raise ValueError(
32+
"Failed to silently acquire token. Err: %s"
33+
% json.dumps(result, indent=4)
34+
)
35+
# ELSE
36+
return result["access_token"]
37+
38+
pass
39+
40+
41+
class AuthProviderAccessToken(AuthProvider):
42+
def __init__(self, access_token):
43+
self._access_token = access_token
44+
payload = jwt.decode(access_token, options={"verify_signature": False})
45+
self._expires = payload["exp"]
46+
return
47+
48+
def get_token(self):
49+
if time.time() >= self._expires:
50+
raise ValueError("Access token has expired.")
51+
# ELSE
52+
return self._access_token
53+
54+
pass
55+
56+
57+
class AuthProviderRefreshToken(AuthProvider):
58+
def __init__(self, refresh_token, client_id, authority, resource_id):
59+
super().__init__(resource_id)
60+
self._app = msal.PublicClientApplication(
61+
client_id=client_id, authority=authority
62+
)
63+
self._scope = scope_for_resource(resource_id)
64+
self._app.acquire_token_by_refresh_token(refresh_token, [self._scope])
65+
return
66+
67+
pass
68+
69+
70+
def get_token_path(resource_id):
71+
return os.path.join(
72+
os.path.expanduser("~"), ".sumo", str(resource_id) + ".token"
73+
)
74+
75+
76+
def get_token_cache(resource_id):
77+
# https://github.com/AzureAD/microsoft-authentication-extensions-\
78+
# for-python
79+
# Encryption not supported on linux servers like rgs, and
80+
# neither is common usage from many cluster nodes.
81+
# Encryption is supported on Windows and Mac.
82+
83+
cache = None
84+
token_path = get_token_path(resource_id)
85+
if sys.platform.startswith("linux"):
86+
persistence = FilePersistence(token_path)
87+
cache = PersistedTokenCache(persistence)
88+
else:
89+
if os.path.exists(token_path):
90+
encrypted_persistence = build_encrypted_persistence(token_path)
91+
try:
92+
token = encrypted_persistence.load()
93+
except Exception:
94+
# This code will encrypt an unencrypted existing file
95+
token = FilePersistence(token_path).load()
96+
with open(token_path, "w") as f:
97+
f.truncate()
98+
pass
99+
encrypted_persistence.save(token)
100+
pass
101+
pass
102+
103+
persistence = build_encrypted_persistence(token_path)
104+
cache = PersistedTokenCache(persistence)
105+
pass
106+
return cache
107+
108+
109+
def protect_token_cache(resource_id):
110+
token_path = get_token_path(resource_id)
111+
112+
if sys.platform.startswith("linux"):
113+
filemode = stat.filemode(os.stat(token_path).st_mode)
114+
if filemode != "-rw-------":
115+
os.chmod(token_path, 0o600)
116+
folder = os.path.dirname(token_path)
117+
foldermode = stat.filemode(os.stat(folder).st_mode)
118+
if foldermode != "drwx------":
119+
os.chmod(os.path.dirname(token_path), 0o700)
120+
pass
121+
pass
122+
return
123+
pass
124+
125+
126+
class AuthProviderInteractive(AuthProvider):
127+
def __init__(self, client_id, authority, resource_id):
128+
super().__init__(resource_id)
129+
cache = get_token_cache(resource_id)
130+
self._app = msal.PublicClientApplication(
131+
client_id=client_id, authority=authority, token_cache=cache
132+
)
133+
134+
self._scope = scope_for_resource(resource_id)
135+
136+
if self.get_token() is None:
137+
self.login()
138+
pass
139+
return
140+
141+
def login(self):
142+
result = self._app.acquire_token_interactive([self._scope])
143+
144+
if "error" in result:
145+
raise ValueError(
146+
"Failed to acquire token interactively. Err: %s"
147+
% json.dumps(result, indent=4)
148+
)
149+
150+
return
151+
152+
pass
153+
154+
155+
class AuthProviderDeviceCode(AuthProvider):
156+
def __init__(self, client_id, authority, resource_id):
157+
super().__init__(resource_id)
158+
cache = get_token_cache(resource_id)
159+
self._app = msal.PublicClientApplication(
160+
client_id=client_id, authority=authority, token_cache=cache
161+
)
162+
self._resource_id = resource_id
163+
self._scope = scope_for_resource(resource_id)
164+
if self.get_token() is None:
165+
self.login()
166+
pass
167+
return
168+
169+
def login(self):
170+
flow = self._app.initiate_device_flow([self._scope])
171+
172+
if "error" in flow:
173+
raise ValueError(
174+
"Failed to create device flow. Err: %s"
175+
% json.dumps(flow, indent=4)
176+
)
177+
178+
print(flow["message"])
179+
result = self._app.acquire_token_by_device_flow(flow)
180+
181+
if "error" in result:
182+
raise ValueError(
183+
"Failed to acquire token by device flow. Err: %s"
184+
% json.dumps(result, indent=4)
185+
)
186+
187+
protect_token_cache(self._resource_id)
188+
189+
return
190+
191+
pass
192+
193+
194+
class AuthProviderManaged(AuthProvider):
195+
def __init__(self, resource_id):
196+
super().__init__(resource_id)
197+
self._app = ManagedIdentityCredential()
198+
self._scope = scope_for_resource(resource_id)
199+
return
200+
201+
def get_token(self):
202+
return self._app.get_token(self._scope)
203+
204+
pass
205+
206+
207+
def get_auth_provider(
208+
client_id,
209+
authority,
210+
resource_id,
211+
interactive=False,
212+
access_token=None,
213+
refresh_token=None,
214+
):
215+
if all(
216+
[
217+
os.getenv(x)
218+
for x in [
219+
"AZURE_FEDERATED_TOKEN_FILE",
220+
"AZURE_TENANT_ID",
221+
"AZURE_CLIENT_ID",
222+
"AZURE_AUTHORITY_HOST",
223+
]
224+
]
225+
):
226+
return AuthProviderManaged(resource_id)
227+
# ELSE
228+
if refresh_token:
229+
return AuthProviderRefreshToken(
230+
refresh_token, client_id, authority, resource_id
231+
)
232+
# ELSE
233+
if access_token:
234+
return AuthProviderAccessToken(access_token)
235+
# ELSE
236+
if interactive:
237+
return AuthProviderInteractive(client_id, authority, resource_id)
238+
# ELSE
239+
return AuthProviderDeviceCode(client_id, authority, resource_id)

src/sumo/wrapper/_blob_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import httpx
22

3-
from .decorators import raise_for_status, http_retry
3+
from ._decorators import raise_for_status, http_retry
44

55

66
class BlobClient:

0 commit comments

Comments
 (0)