Skip to content

Commit bd3c4e3

Browse files
authored
Encrypt token cache3 (#111)
Use msal-extensions 1.0 for secure storage of cached tokens on supported platforms.
1 parent 1dcf321 commit bd3c4e3

3 files changed

Lines changed: 56 additions & 43 deletions

File tree

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
msal~=1.17.0
2+
msal-extensions~=1.0.0
23
requests~=2.24.0
34
pytest~=6.1.1
45
PyYAML>=5.4
56
setuptools~=49.2.1
6-
pyjwt>=2.4.0
7+
pyjwt>=2.4.0

src/sumo/wrapper/_new_auth.py

Lines changed: 52 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1-
import atexit
21
import msal
32
import os
3+
import stat
44
import sys
55
import json
66
import logging
77
from .config import AUTHORITY_HOST_URI
8+
from msal_extensions.persistence import FilePersistence
9+
from msal_extensions.token_cache import PersistedTokenCache
10+
11+
if not sys.platform.startswith("linux"):
12+
from msal_extensions import build_encrypted_persistence
813

914
HOME_DIR = os.path.expanduser("~")
1015

@@ -40,20 +45,42 @@ def __init__(
4045
self.scope = resource_id + "/.default"
4146
self.refresh_token = refresh_token
4247

43-
self.token_path = os.path.join(
48+
token_path = os.path.join(
4449
HOME_DIR, ".sumo", str(resource_id) + ".token"
4550
)
46-
47-
self.cache = None
48-
49-
if not self.refresh_token:
50-
self.cache = self.__load_cache()
51-
atexit.register(self.__save_cache)
51+
self.token_path = token_path
52+
53+
# https://github.com/AzureAD/microsoft-authentication-extensions-\
54+
# for-python
55+
# Encryption not supported on linux servers like rgs, and
56+
# neither is common usage from many cluster nodes.
57+
# Encryption is supported on Windows and Mac.
58+
59+
if sys.platform.startswith("linux"):
60+
persistence = FilePersistence(token_path)
61+
cache = PersistedTokenCache(persistence)
62+
else:
63+
if os.path.exists(token_path):
64+
encrypted_persistence = build_encrypted_persistence(token_path)
65+
try:
66+
token = encrypted_persistence.load()
67+
except Exception:
68+
# This code will encrypt an unencrypted existing file
69+
token = FilePersistence(token_path).load()
70+
with open(token_path, "w") as f:
71+
f.truncate()
72+
pass
73+
encrypted_persistence.save(token)
74+
pass
75+
pass
76+
77+
persistence = build_encrypted_persistence(token_path)
78+
cache = PersistedTokenCache(persistence)
5279

5380
self.msal = msal.PublicClientApplication(
5481
client_id=client_id,
5582
authority=f"{AUTHORITY_HOST_URI}/{tenant_id}",
56-
token_cache=self.cache,
83+
token_cache=cache,
5784
)
5885

5986
def get_token(self):
@@ -118,39 +145,23 @@ def get_token(self):
118145
% json.dumps(result, indent=4)
119146
)
120147

121-
self.__save_cache()
148+
if sys.platform.startswith("linux"):
149+
filemode = stat.filemode(os.stat(self.token_path).st_mode)
150+
if filemode != "-rw-------":
151+
os.chmod(self.token_path, 0o600)
152+
folder = os.path.dirname(self.token_path)
153+
foldermode = stat.filemode(os.stat(folder).st_mode)
154+
if foldermode != "drwx------":
155+
os.chmod(os.path.dirname(self.token_path), 0o700)
122156

123157
return result["access_token"]
124158

125-
def __load_cache(self):
126-
"""Load token cache from file.
127-
128-
Returns:
129-
A msal friendly token cache object
130-
"""
131-
132-
cache = msal.SerializableTokenCache()
133-
134-
if os.path.isfile(self.token_path):
135-
with open(self.token_path, "r") as file:
136-
cache.deserialize(file.read())
137-
138-
return cache
139-
140-
def __save_cache(self):
141-
"""Write token cache to file."""
142-
143-
if self.cache.has_state_changed:
144-
old_mask = os.umask(0o077)
145-
146-
dir_path = os.path.dirname(self.token_path)
147-
os.makedirs(dir_path, exist_ok=True)
148-
149-
with open(self.token_path, "w") as file:
150-
file.write(self.cache.serialize())
151-
152-
if not sys.platform.lower().startswith("win"):
153-
os.chmod(self.token_path, 0o600)
154-
os.chmod(dir_path, 0o700)
155159

156-
os.umask(old_mask)
160+
if __name__ == "__main__":
161+
auth = NewAuth(
162+
"1826bd7c-582f-4838-880d-5b4da5c3eea2",
163+
"88d2b022-3539-4dda-9e66-853801334a86",
164+
"3aa4a235-b6e2-48d5-9195-7fcf05b459b0",
165+
interactive=True,
166+
)
167+
print(auth.get_token())

tests/test_call_sumo_api.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ def test_upload_search_delete_ensemble_child():
109109
surface_id = response_surface.json().get("objectid")
110110

111111
# Upload BLOB
112-
response_blob = _upload_blob(C=C, blob=B, object_id=surface_id)
112+
url = response_surface.json().get("blob_url")
113+
response_blob = _upload_blob(C=C, blob=B, object_id=surface_id, url=url)
113114
assert 200 <= response_blob.status_code <= 202
114115

115116
sleep(4)

0 commit comments

Comments
 (0)