-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathauth.py
More file actions
323 lines (272 loc) · 11.1 KB
/
auth.py
File metadata and controls
323 lines (272 loc) · 11.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
"""
Authentication module for Azure API access
Handles auth file parsing and token refresh for Graph (group names).
"""
import base64
import json
import urllib.error
import urllib.parse
import urllib.request
from pathlib import Path
from typing import Optional
# Graph scope for token refresh (to resolve group display names)
GRAPH_SCOPE = "https://graph.microsoft.com/.default"
# ARM scope requested when get_access_token() has to perform a brokered refresh
ARM_SCOPE = "https://management.azure.com/.default"
# Token endpoint; "{tenant}" is filled with the tenant id or "common"
TOKEN_ENDPOINT_TEMPLATE = "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token"

# Client IDs for brokered authentication
# BROKER_CLIENT_ID is sent as "brk_client_id" and embedded in the redirect URI;
# TARGET_CLIENT_ID is the default "client_id" of a brokered refresh.
BROKER_CLIENT_ID = "c44b4083-3bb0-49c1-b47d-974e53cbdf3c"
TARGET_CLIENT_ID = "74658136-14ec-4630-ad9b-26e160ff0fc6"
# Client ID for PIM API
PIM_CLIENT_ID = "50aaa389-5a33-4f1a-91d7-2c45ecd8dac8"
# Client ID used for a (non-brokered) Graph refresh when the auth file names none
DEFAULT_CLIENT_ID = "d3590ed6-52b3-4102-aeff-aad2292ab01c"


class AuthManager:
    """Manages authentication for Azure API and optional Graph token via refresh.

    Tokens come from one of two places: an access token handed in directly, or
    a JSON auth file. The auth file may carry an access token, a refresh
    token, a client id and a tenant id, each under several alternative key
    spellings (see the ``*_KEYS`` tuples below). Every token obtained is cached
    on the instance, so each getter hits the file/network at most once.
    """

    # Alternative spellings accepted for each auth-file field, in lookup order.
    _CLIENT_ID_KEYS = ("_clientId", "clientId", "client_id")
    _REFRESH_TOKEN_KEYS = ("refreshToken", "refresh_token")
    _ACCESS_TOKEN_KEYS = ("accessToken", "access_token")
    _TENANT_KEYS = ("tenantId", "tenant_id")
    _GRAPH_TOKEN_KEYS = ("graph_access_token", "graphAccessToken")

    # Seconds to wait for the token endpoint before giving up.
    _REQUEST_TIMEOUT = 30

    def __init__(
        self,
        auth_file: Optional[Path] = None,
        access_token: Optional[str] = None,
    ):
        """Store the credential sources; nothing is read until a token is requested.

        Args:
            auth_file: Path to the JSON auth file. A non-empty ``str`` is
                accepted as well and converted to ``Path``.
            access_token: ARM access token to use as-is; takes precedence over
                ``auth_file`` in ``get_access_token``.
        """
        # The methods below call ``.exists()`` on the path, so a plain string
        # would otherwise crash with AttributeError.
        if isinstance(auth_file, str) and auth_file:
            auth_file = Path(auth_file)
        self.auth_file = auth_file
        self.access_token = access_token
        self._access_token_cache: Optional[str] = None
        self._auth_data: Optional[dict] = None  # Parsed auth file contents
        self._graph_token_cache: Optional[str] = None
        self._pim_graph_token_cache: Optional[str] = None

    def get_access_token(self) -> str:
        """
        Get Azure (ARM) access token from auth file or direct input.

        If the auth file was obtained with the broker client ID, this performs
        a brokered refresh to get an ARM access token.

        Returns:
            str: Access token for Azure Resource Manager API.

        Raises:
            FileNotFoundError: If auth file doesn't exist
            ValueError: If auth file format is invalid or no token provided
        """
        if self._access_token_cache is not None:
            return self._access_token_cache

        if self.access_token:
            self._access_token_cache = self.access_token
            return self._access_token_cache

        if not self.auth_file:
            raise ValueError(
                "No authentication method provided. Use --auth-file or --access-token"
            )

        data = self._load_auth_data()
        if not data:
            # Tell a missing file apart from an unreadable/empty one so the
            # documented FileNotFoundError is actually raised.
            if not self.auth_file.exists():
                raise self._missing_auth_file_error()
            raise ValueError("Invalid or empty auth file")

        # If using broker, we must refresh to get a usable ARM token
        if self._client_id_from(data) == BROKER_CLIENT_ID:
            refresh_token = self._first_value(data, *self._REFRESH_TOKEN_KEYS)
            if not refresh_token:
                raise ValueError("Brokered auth requires a refreshToken in the auth file")
            token = self._refresh_brokered_for_scope(
                refresh_token, self._get_tenant(data), ARM_SCOPE
            )
            if not token:
                raise ValueError("Brokered refresh for ARM token failed")
            self._access_token_cache = token
            return token

        # Default: read token directly from file
        self._access_token_cache = self._read_access_token()
        return self._access_token_cache

    def get_graph_token(self) -> Optional[str]:
        """
        Get Microsoft Graph API token, using refresh token if available.

        An explicit Graph token stored in the auth file wins; otherwise the
        refresh token is exchanged via the brokered flow (auth file client id
        equals BROKER_CLIENT_ID) or the standard flow (any other client id,
        DEFAULT_CLIENT_ID when none is stored).

        Returns:
            The Graph access token, or None when there is no usable auth file,
            no refresh token, or the refresh request fails.
        """
        if self._graph_token_cache is not None:
            return self._graph_token_cache

        if not self.auth_file or not self.auth_file.exists():
            return None
        data = self._load_auth_data()
        if not data:
            return None

        # Explicit Graph token in file (optional)
        explicit = self._first_value(data, *self._GRAPH_TOKEN_KEYS)
        if explicit:
            self._graph_token_cache = explicit
            return self._graph_token_cache

        refresh_token = self._first_value(data, *self._REFRESH_TOKEN_KEYS)
        if not refresh_token:
            return None

        client_id = self._client_id_from(data, default=DEFAULT_CLIENT_ID)
        tenant = self._get_tenant(data)
        if client_id == BROKER_CLIENT_ID:
            token = self._refresh_brokered_for_scope(refresh_token, tenant, GRAPH_SCOPE)
        else:
            token = self._refresh_non_brokered_for_scope(
                refresh_token, client_id, tenant, GRAPH_SCOPE
            )

        # Only successful results are cached, so a failed refresh is retried
        # on the next call.
        if token:
            self._graph_token_cache = token
        return token

    def get_pim_graph_token(self) -> Optional[str]:
        """
        Get a separate Microsoft Graph API token for PIM operations.

        Same refresh logic as ``get_graph_token`` but the token is always
        requested for PIM_CLIENT_ID, and an explicit Graph token in the auth
        file is not consulted.

        Returns:
            The access token, or None when there is no usable auth file, no
            refresh token, or the refresh request fails.
        """
        if self._pim_graph_token_cache is not None:
            return self._pim_graph_token_cache

        if not self.auth_file or not self.auth_file.exists():
            return None
        data = self._load_auth_data()
        if not data:
            return None

        refresh_token = self._first_value(data, *self._REFRESH_TOKEN_KEYS)
        if not refresh_token:
            return None

        tenant = self._get_tenant(data)
        if self._client_id_from(data) == BROKER_CLIENT_ID:
            token = self._refresh_brokered_for_scope(
                refresh_token, tenant, GRAPH_SCOPE, PIM_CLIENT_ID
            )
        else:
            token = self._refresh_non_brokered_for_scope(
                refresh_token, PIM_CLIENT_ID, tenant, GRAPH_SCOPE
            )

        if token:
            self._pim_graph_token_cache = token
        return token

    @staticmethod
    def _first_value(data: dict, *keys: str):
        """Return the first truthy value stored in *data* under any of *keys*, else None."""
        for key in keys:
            value = data.get(key)
            if value:
                return value
        return None

    def _client_id_from(self, data: dict, default: Optional[str] = None) -> Optional[str]:
        """Return the client id recorded in the auth data, or *default* if absent."""
        return self._first_value(data, *self._CLIENT_ID_KEYS) or default

    def _missing_auth_file_error(self) -> FileNotFoundError:
        """Build the error reported when the configured auth file does not exist."""
        return FileNotFoundError(
            f"Auth file not found: {self.auth_file}\n"
            "Create it or provide --access-token directly"
        )

    def _get_tenant(self, data: dict) -> str:
        """Extract tenant from auth data, falling back to access token.

        Returns the stripped tenant id from the auth data, else the ``tid``
        claim of the stored access token, else ``"common"``.
        """
        stored = self._first_value(data, *self._TENANT_KEYS)
        # A non-string value (e.g. a number in hand-edited JSON) is ignored
        # rather than crashing on .strip().
        tenant = stored.strip() if isinstance(stored, str) else ""
        if not tenant:
            tenant = self._tenant_from_access_token(
                self._first_value(data, *self._ACCESS_TOKEN_KEYS)
            )
        return tenant or "common"

    def _load_auth_data(self) -> Optional[dict]:
        """Load and return raw auth file JSON, or None on error.

        None is returned when no auth file is configured, it does not exist,
        it cannot be read/decoded/parsed, or its top-level JSON value is not
        an object. A successfully parsed object is cached for later calls.
        """
        if self._auth_data is not None:
            return self._auth_data
        if not self.auth_file or not self.auth_file.exists():
            return None
        try:
            # utf-8-sig also reads plain UTF-8 and tolerates a leading BOM.
            with open(self.auth_file, "r", encoding="utf-8-sig") as f:
                loaded = json.load(f)
        except (ValueError, TypeError, OSError):
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            return None
        if not isinstance(loaded, dict):
            # Every caller uses .get(); anything but an object is unusable.
            return None
        self._auth_data = loaded
        return self._auth_data

    def _read_access_token(self) -> str:
        """Read ARM access token from auth file. Raises on invalid format.

        Raises:
            FileNotFoundError: If the auth file is not configured or missing.
            ValueError: If the file is unreadable/empty or holds no token.
        """
        if not self.auth_file or not self.auth_file.exists():
            raise self._missing_auth_file_error()

        data = self._load_auth_data()
        if not data:
            raise ValueError("Invalid or empty auth file")

        # Flat format (e.g. ROADtools / Azure CLI style)
        token = self._first_value(data, *self._ACCESS_TOKEN_KEYS)
        if token:
            return token

        # Legacy nested format
        if "tokendata" in data:
            tokendata = data["tokendata"]
            if isinstance(tokendata, list) and tokendata:
                entry = tokendata[0]
            elif isinstance(tokendata, dict):
                entry = tokendata
            else:
                entry = None
            if not isinstance(entry, dict):
                raise ValueError("Invalid tokendata format in auth file")
            token = entry.get("access_token")
            if token:
                return token

        raise ValueError(
            "Auth file must contain 'accessToken' (or 'access_token' / tokendata)"
        )

    @staticmethod
    def _tenant_from_access_token(access_token: Optional[str]) -> str:
        """Extract tenant id (tid) from JWT payload without verifying signature.

        Returns an empty string for anything that is not a decodable JWT with
        a string ``tid`` claim.
        """
        if not isinstance(access_token, str) or "." not in access_token:
            return ""
        payload = access_token.split(".")[1]
        # Base64url: add padding if needed
        payload += "=" * (-len(payload) % 4)
        try:
            claims = json.loads(base64.urlsafe_b64decode(payload))
        except (ValueError, TypeError):
            # binascii.Error, JSONDecodeError and UnicodeDecodeError are all
            # ValueError subclasses.
            return ""
        if not isinstance(claims, dict):
            return ""
        tid = claims.get("tid")
        return tid if isinstance(tid, str) else ""

    @staticmethod
    def _token_url(tenant: str) -> str:
        """Return the token endpoint URL for *tenant* (``"common"`` when empty)."""
        # Quote so an odd tenant value cannot alter the URL path or query.
        return TOKEN_ENDPOINT_TEMPLATE.format(
            tenant=urllib.parse.quote(tenant or "common", safe="")
        )

    @staticmethod
    def _request_token(url: str, body: bytes, headers: dict) -> Optional[str]:
        """POST a form-encoded token request and return its ``access_token``.

        Best-effort by design: any transport, HTTP or parsing failure yields
        None so callers can treat the token as simply unavailable.
        """
        request = urllib.request.Request(url, data=body, method="POST", headers=headers)
        try:
            with urllib.request.urlopen(
                request, timeout=AuthManager._REQUEST_TIMEOUT
            ) as resp:
                parsed = json.loads(resp.read().decode("utf-8"))
        except Exception:
            # Deliberately broad: this is the network boundary and every
            # failure mode maps to "no token".
            return None
        if not isinstance(parsed, dict):
            return None
        token = parsed.get("access_token")
        return token if isinstance(token, str) and token else None

    @staticmethod
    def _refresh_non_brokered_for_scope(
        refresh_token: str,
        client_id: str,
        tenant: str,
        scope: str,
    ) -> Optional[str]:
        """Exchange refresh_token for an access token with the given scope.

        Returns:
            The access token, or None if the request fails for any reason.
        """
        body = urllib.parse.urlencode(
            {
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": client_id,
                "scope": scope,
            },
            # Percent-encode everything, matching quote(..., safe="").
            quote_via=urllib.parse.quote,
            safe="",
        ).encode("utf-8")
        return AuthManager._request_token(
            AuthManager._token_url(tenant),
            body,
            {"Content-Type": "application/x-www-form-urlencoded"},
        )

    @staticmethod
    def _refresh_brokered_for_scope(
        refresh_token: str,
        tenant: str,
        scope: str,
        target_client_id: str = TARGET_CLIENT_ID,
    ) -> Optional[str]:
        """Exchange refresh_token for an access token with the given scope using brokered auth.

        The request names ``target_client_id`` as the client, BROKER_CLIENT_ID
        as ``brk_client_id``, and a ``brk-<broker id>://portal.azure.com``
        redirect URI.

        Returns:
            The access token, or None if the request fails for any reason.
        """
        payload = {
            "grant_type": "refresh_token",
            "client_id": target_client_id,
            "refresh_token": refresh_token,
            "brk_client_id": BROKER_CLIENT_ID,
            "scope": scope,
            "redirect_uri": f"brk-{BROKER_CLIENT_ID}://portal.azure.com",
        }
        # NOTE(review): Origin/User-Agent presumably make the request look like
        # the portal's own browser call; confirm whether the endpoint needs them.
        headers = {
            "Content-Type": "application/x-www-form-urlencoded",
            "Origin": "https://portal.azure.com",
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0",
        }
        return AuthManager._request_token(
            AuthManager._token_url(tenant),
            urllib.parse.urlencode(payload).encode("utf-8"),
            headers,
        )