|
1 | | -import atexit |
2 | 1 | import msal |
3 | 2 | import os |
| 3 | +import stat |
4 | 4 | import sys |
5 | 5 | import json |
6 | 6 | import logging |
7 | 7 | from .config import AUTHORITY_HOST_URI |
| 8 | +from msal_extensions.persistence import FilePersistence |
| 9 | +from msal_extensions.token_cache import PersistedTokenCache |
| 10 | + |
| 11 | +if not sys.platform.startswith("linux"): |
| 12 | + from msal_extensions import build_encrypted_persistence |
8 | 13 |
|
9 | 14 | HOME_DIR = os.path.expanduser("~") |
10 | 15 |
|
@@ -40,20 +45,42 @@ def __init__( |
40 | 45 | self.scope = resource_id + "/.default" |
41 | 46 | self.refresh_token = refresh_token |
42 | 47 |
|
43 | | - self.token_path = os.path.join( |
| 48 | + token_path = os.path.join( |
44 | 49 | HOME_DIR, ".sumo", str(resource_id) + ".token" |
45 | 50 | ) |
46 | | - |
47 | | - self.cache = None |
48 | | - |
49 | | - if not self.refresh_token: |
50 | | - self.cache = self.__load_cache() |
51 | | - atexit.register(self.__save_cache) |
| 51 | + self.token_path = token_path |
| 52 | + |
| 53 | + # https://github.com/AzureAD/microsoft-authentication-extensions-\ |
| 54 | + # for-python |
| 55 | + # Encryption not supported on linux servers like rgs, and |
| 56 | + # neither is common usage from many cluster nodes. |
| 57 | + # Encryption is supported on Windows and Mac. |
| 58 | + |
| 59 | + if sys.platform.startswith("linux"): |
| 60 | + persistence = FilePersistence(token_path) |
| 61 | + cache = PersistedTokenCache(persistence) |
| 62 | + else: |
| 63 | + if os.path.exists(token_path): |
| 64 | + encrypted_persistence = build_encrypted_persistence(token_path) |
| 65 | + try: |
| 66 | + token = encrypted_persistence.load() |
| 67 | + except Exception: |
| 68 | + # This code will encrypt an unencrypted existing file |
| 69 | + token = FilePersistence(token_path).load() |
| 70 | + with open(token_path, "w") as f: |
| 71 | + f.truncate() |
| 72 | + pass |
| 73 | + encrypted_persistence.save(token) |
| 74 | + pass |
| 75 | + pass |
| 76 | + |
| 77 | + persistence = build_encrypted_persistence(token_path) |
| 78 | + cache = PersistedTokenCache(persistence) |
52 | 79 |
|
53 | 80 | self.msal = msal.PublicClientApplication( |
54 | 81 | client_id=client_id, |
55 | 82 | authority=f"{AUTHORITY_HOST_URI}/{tenant_id}", |
56 | | - token_cache=self.cache, |
| 83 | + token_cache=cache, |
57 | 84 | ) |
58 | 85 |
|
59 | 86 | def get_token(self): |
@@ -118,39 +145,23 @@ def get_token(self): |
118 | 145 | % json.dumps(result, indent=4) |
119 | 146 | ) |
120 | 147 |
|
121 | | - self.__save_cache() |
| 148 | + if sys.platform.startswith("linux"): |
| 149 | + filemode = stat.filemode(os.stat(self.token_path).st_mode) |
| 150 | + if filemode != "-rw-------": |
| 151 | + os.chmod(self.token_path, 0o600) |
| 152 | + folder = os.path.dirname(self.token_path) |
| 153 | + foldermode = stat.filemode(os.stat(folder).st_mode) |
| 154 | + if foldermode != "drwx------": |
| 155 | + os.chmod(os.path.dirname(self.token_path), 0o700) |
122 | 156 |
|
123 | 157 | return result["access_token"] |
124 | 158 |
|
125 | | - def __load_cache(self): |
126 | | - """Load token cache from file. |
127 | | -
|
128 | | - Returns: |
129 | | - A msal friendly token cache object |
130 | | - """ |
131 | | - |
132 | | - cache = msal.SerializableTokenCache() |
133 | | - |
134 | | - if os.path.isfile(self.token_path): |
135 | | - with open(self.token_path, "r") as file: |
136 | | - cache.deserialize(file.read()) |
137 | | - |
138 | | - return cache |
139 | | - |
140 | | - def __save_cache(self): |
141 | | - """Write token cache to file.""" |
142 | | - |
143 | | - if self.cache.has_state_changed: |
144 | | - old_mask = os.umask(0o077) |
145 | | - |
146 | | - dir_path = os.path.dirname(self.token_path) |
147 | | - os.makedirs(dir_path, exist_ok=True) |
148 | | - |
149 | | - with open(self.token_path, "w") as file: |
150 | | - file.write(self.cache.serialize()) |
151 | | - |
152 | | - if not sys.platform.lower().startswith("win"): |
153 | | - os.chmod(self.token_path, 0o600) |
154 | | - os.chmod(dir_path, 0o700) |
155 | 159 |
|
156 | | - os.umask(old_mask) |
| 160 | +if __name__ == "__main__": |
| 161 | + auth = NewAuth( |
| 162 | + "1826bd7c-582f-4838-880d-5b4da5c3eea2", |
| 163 | + "88d2b022-3539-4dda-9e66-853801334a86", |
| 164 | + "3aa4a235-b6e2-48d5-9195-7fcf05b459b0", |
| 165 | + interactive=True, |
| 166 | + ) |
| 167 | + print(auth.get_token()) |
0 commit comments