|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | | -from typing import Any, Dict, Mapping, Optional, Type, cast |
| 5 | +from typing import Any, Dict, List, Mapping, Optional, Type, cast |
6 | 6 |
|
7 | 7 |
|
8 | 8 | class WorkOSError(Exception): |
@@ -142,21 +142,196 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: |
142 | 142 | super().__init__(*args, **kwargs) |
143 | 143 |
|
144 | 144 |
|
145 | | -class EmailVerificationRequiredError(AuthorizationError): |
| 145 | +class AuthenticationFlowError(AuthorizationError): |
| 146 | + """Raised when authentication requires an additional step. |
| 147 | +
|
| 148 | + All auth-flow 403 errors carry a pending_authentication_token that |
| 149 | + must be passed to the next step in the authentication flow. |
| 150 | + """ |
| 151 | + |
| 152 | + pending_authentication_token: Optional[str] |
| 153 | + |
| 154 | + def __init__( |
| 155 | + self, |
| 156 | + *args: Any, |
| 157 | + pending_authentication_token: Optional[str] = None, |
| 158 | + **kwargs: Any, |
| 159 | + ) -> None: |
| 160 | + response_json = cast(Optional[Mapping[str, Any]], kwargs.get("response_json")) |
| 161 | + if pending_authentication_token is None and response_json is not None: |
| 162 | + pending_authentication_token = cast( |
| 163 | + Optional[str], response_json.get("pending_authentication_token") |
| 164 | + ) |
| 165 | + super().__init__(*args, **kwargs) |
| 166 | + self.pending_authentication_token = pending_authentication_token |
| 167 | + |
| 168 | + |
| 169 | +class EmailVerificationRequiredError(AuthenticationFlowError): |
146 | 170 | """Raised when email verification is required before authentication.""" |
147 | 171 |
|
148 | 172 | email_verification_id: Optional[str] |
| 173 | + email: Optional[str] |
149 | 174 |
|
150 | 175 | def __init__( |
151 | | - self, *args: Any, email_verification_id: Optional[str] = None, **kwargs: Any |
| 176 | + self, |
| 177 | + *args: Any, |
| 178 | + email_verification_id: Optional[str] = None, |
| 179 | + email: Optional[str] = None, |
| 180 | + **kwargs: Any, |
152 | 181 | ) -> None: |
153 | 182 | response_json = cast(Optional[Mapping[str, Any]], kwargs.get("response_json")) |
154 | 183 | if email_verification_id is None and response_json is not None: |
155 | 184 | email_verification_id = cast( |
156 | 185 | Optional[str], response_json.get("email_verification_id") |
157 | 186 | ) |
| 187 | + if email is None and response_json is not None: |
| 188 | + email = cast(Optional[str], response_json.get("email")) |
158 | 189 | super().__init__(*args, **kwargs) |
159 | 190 | self.email_verification_id = email_verification_id |
| 191 | + self.email = email |
| 192 | + |
| 193 | + |
| 194 | +class MfaEnrollmentError(AuthenticationFlowError): |
| 195 | + """Raised when MFA enrollment is required.""" |
| 196 | + |
| 197 | + user: Optional[Dict[str, Any]] |
| 198 | + |
| 199 | + def __init__( |
| 200 | + self, *args: Any, user: Optional[Dict[str, Any]] = None, **kwargs: Any |
| 201 | + ) -> None: |
| 202 | + response_json = cast(Optional[Mapping[str, Any]], kwargs.get("response_json")) |
| 203 | + if user is None and response_json is not None: |
| 204 | + user = cast(Optional[Dict[str, Any]], response_json.get("user")) |
| 205 | + super().__init__(*args, **kwargs) |
| 206 | + self.user = user |
| 207 | + |
| 208 | + |
| 209 | +class MfaChallengeError(AuthenticationFlowError): |
| 210 | + """Raised when an MFA challenge must be completed.""" |
| 211 | + |
| 212 | + user: Optional[Dict[str, Any]] |
| 213 | + authentication_factors: Optional[List[Dict[str, Any]]] |
| 214 | + |
| 215 | + def __init__( |
| 216 | + self, |
| 217 | + *args: Any, |
| 218 | + user: Optional[Dict[str, Any]] = None, |
| 219 | + authentication_factors: Optional[List[Dict[str, Any]]] = None, |
| 220 | + **kwargs: Any, |
| 221 | + ) -> None: |
| 222 | + response_json = cast(Optional[Mapping[str, Any]], kwargs.get("response_json")) |
| 223 | + if user is None and response_json is not None: |
| 224 | + user = cast(Optional[Dict[str, Any]], response_json.get("user")) |
| 225 | + if authentication_factors is None and response_json is not None: |
| 226 | + authentication_factors = cast( |
| 227 | + Optional[List[Dict[str, Any]]], |
| 228 | + response_json.get("authentication_factors"), |
| 229 | + ) |
| 230 | + super().__init__(*args, **kwargs) |
| 231 | + self.user = user |
| 232 | + self.authentication_factors = authentication_factors |
| 233 | + |
| 234 | + |
| 235 | +class OrganizationSelectionRequiredError(AuthenticationFlowError): |
| 236 | + """Raised when the user must select an organization.""" |
| 237 | + |
| 238 | + user: Optional[Dict[str, Any]] |
| 239 | + organizations: Optional[List[Dict[str, Any]]] |
| 240 | + |
| 241 | + def __init__( |
| 242 | + self, |
| 243 | + *args: Any, |
| 244 | + user: Optional[Dict[str, Any]] = None, |
| 245 | + organizations: Optional[List[Dict[str, Any]]] = None, |
| 246 | + **kwargs: Any, |
| 247 | + ) -> None: |
| 248 | + response_json = cast(Optional[Mapping[str, Any]], kwargs.get("response_json")) |
| 249 | + if user is None and response_json is not None: |
| 250 | + user = cast(Optional[Dict[str, Any]], response_json.get("user")) |
| 251 | + if organizations is None and response_json is not None: |
| 252 | + organizations = cast( |
| 253 | + Optional[List[Dict[str, Any]]], response_json.get("organizations") |
| 254 | + ) |
| 255 | + super().__init__(*args, **kwargs) |
| 256 | + self.user = user |
| 257 | + self.organizations = organizations |
| 258 | + |
| 259 | + |
| 260 | +class SsoRequiredError(AuthenticationFlowError): |
| 261 | + """Raised when SSO authentication is required.""" |
| 262 | + |
| 263 | + email: Optional[str] |
| 264 | + connection_ids: Optional[List[str]] |
| 265 | + |
| 266 | + def __init__( |
| 267 | + self, |
| 268 | + *args: Any, |
| 269 | + email: Optional[str] = None, |
| 270 | + connection_ids: Optional[List[str]] = None, |
| 271 | + **kwargs: Any, |
| 272 | + ) -> None: |
| 273 | + response_json = cast(Optional[Mapping[str, Any]], kwargs.get("response_json")) |
| 274 | + if email is None and response_json is not None: |
| 275 | + email = cast(Optional[str], response_json.get("email")) |
| 276 | + if connection_ids is None and response_json is not None: |
| 277 | + connection_ids = cast( |
| 278 | + Optional[List[str]], response_json.get("connection_ids") |
| 279 | + ) |
| 280 | + super().__init__(*args, **kwargs) |
| 281 | + self.email = email |
| 282 | + self.connection_ids = connection_ids |
| 283 | + |
| 284 | + |
| 285 | +class OrganizationAuthMethodsRequiredError(AuthenticationFlowError): |
| 286 | + """Raised when organization-specific authentication methods are required.""" |
| 287 | + |
| 288 | + email: Optional[str] |
| 289 | + sso_connection_ids: Optional[List[str]] |
| 290 | + auth_methods: Optional[Dict[str, bool]] |
| 291 | + |
| 292 | + def __init__( |
| 293 | + self, |
| 294 | + *args: Any, |
| 295 | + email: Optional[str] = None, |
| 296 | + sso_connection_ids: Optional[List[str]] = None, |
| 297 | + auth_methods: Optional[Dict[str, bool]] = None, |
| 298 | + **kwargs: Any, |
| 299 | + ) -> None: |
| 300 | + response_json = cast(Optional[Mapping[str, Any]], kwargs.get("response_json")) |
| 301 | + if email is None and response_json is not None: |
| 302 | + email = cast(Optional[str], response_json.get("email")) |
| 303 | + if sso_connection_ids is None and response_json is not None: |
| 304 | + sso_connection_ids = cast( |
| 305 | + Optional[List[str]], response_json.get("sso_connection_ids") |
| 306 | + ) |
| 307 | + if auth_methods is None and response_json is not None: |
| 308 | + auth_methods = cast( |
| 309 | + Optional[Dict[str, bool]], response_json.get("auth_methods") |
| 310 | + ) |
| 311 | + super().__init__(*args, **kwargs) |
| 312 | + self.email = email |
| 313 | + self.sso_connection_ids = sso_connection_ids |
| 314 | + self.auth_methods = auth_methods |
| 315 | + |
| 316 | + |
| 317 | +class AuthenticationMethodNotAllowedError(AuthenticationFlowError): |
| 318 | + """Raised when the authentication method is not allowed.""" |
| 319 | + |
| 320 | + |
| 321 | +class EmailPasswordAuthDisabledError(AuthenticationFlowError): |
| 322 | + """Raised when email/password authentication is disabled.""" |
| 323 | + |
| 324 | + |
| 325 | +class PasskeyProgressiveEnrollmentError(AuthenticationFlowError): |
| 326 | + """Raised when passkey progressive enrollment is required.""" |
| 327 | + |
| 328 | + |
| 329 | +class RadarChallengeError(AuthenticationFlowError): |
| 330 | + """Raised when a Radar challenge is required.""" |
| 331 | + |
| 332 | + |
| 333 | +class RadarSignUpChallengeError(AuthenticationFlowError): |
| 334 | + """Raised when a Radar sign-up challenge is required.""" |
160 | 335 |
|
161 | 336 |
|
162 | 337 | class NotFoundError(APIError): |
@@ -234,3 +409,19 @@ def __init__(self, message: str = "Request timed out") -> None: |
234 | 409 | 422: UnprocessableEntityError, |
235 | 410 | 429: RateLimitExceededError, |
236 | 411 | } |
| 412 | + |
| 413 | +# Maps authentication error code/error values to specific error classes. |
| 414 | +# Checked by _raise_error() for 403 responses before falling through to AuthorizationError. |
| 415 | +_AUTH_CODE_TO_ERROR: Dict[str, Type[AuthenticationFlowError]] = { |
| 416 | + "email_verification_required": EmailVerificationRequiredError, |
| 417 | + "mfa_enrollment": MfaEnrollmentError, |
| 418 | + "mfa_challenge": MfaChallengeError, |
| 419 | + "organization_selection_required": OrganizationSelectionRequiredError, |
| 420 | + "sso_required": SsoRequiredError, |
| 421 | + "organization_authentication_methods_required": OrganizationAuthMethodsRequiredError, |
| 422 | + "authentication_method_not_allowed": AuthenticationMethodNotAllowedError, |
| 423 | + "email_password_auth_disabled": EmailPasswordAuthDisabledError, |
| 424 | + "passkey_progressive_enrollment": PasskeyProgressiveEnrollmentError, |
| 425 | + "radar_challenge": RadarChallengeError, |
| 426 | + "radar_sign_up_challenge": RadarSignUpChallengeError, |
| 427 | +} |
0 commit comments