Skip to content

Commit 08797b6

Browse files
committed
expand error classes
1 parent 8917423 commit 08797b6

File tree

3 files changed

+517
-3
lines changed

3 files changed

+517
-3
lines changed

src/workos/_base_client.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
WorkOSConnectionError,
2323
WorkOSTimeoutError,
2424
STATUS_CODE_TO_ERROR,
25+
_AUTH_CODE_TO_ERROR,
2526
)
2627
from ._pagination import AsyncPage, ListMetadata, SyncPage
2728
from ._types import D, Deserializable, RequestOptions
@@ -245,6 +246,28 @@ def _raise_error(response: httpx.Response) -> None:
245246
request_url=request_url,
246247
request_method=request_method,
247248
)
249+
# Auth-flow dispatch for 403 responses: check code/error field
250+
# for specific auth-flow errors before falling through to generic class.
251+
if response.status_code == 403 and response_json is not None:
252+
auth_code = code or error
253+
if auth_code is not None:
254+
auth_error_class = _AUTH_CODE_TO_ERROR.get(auth_code)
255+
if auth_error_class is not None:
256+
raise auth_error_class(
257+
message,
258+
request_id=request_id,
259+
code=code,
260+
param=param,
261+
response=response,
262+
response_json=response_json,
263+
error=error,
264+
errors=errors,
265+
error_description=error_description,
266+
raw_body=raw_body,
267+
request_url=request_url,
268+
request_method=request_method,
269+
)
270+
248271
raise error_class(
249272
message,
250273
request_id=request_id,

src/workos/_errors.py

Lines changed: 194 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import Any, Dict, Mapping, Optional, Type, cast
5+
from typing import Any, Dict, List, Mapping, Optional, Type, cast
66

77

88
class WorkOSError(Exception):
@@ -142,21 +142,196 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
142142
super().__init__(*args, **kwargs)
143143

144144

145-
class EmailVerificationRequiredError(AuthorizationError):
145+
class AuthenticationFlowError(AuthorizationError):
146+
"""Raised when authentication requires an additional step.
147+
148+
All auth-flow 403 errors carry a pending_authentication_token that
149+
must be passed to the next step in the authentication flow.
150+
"""
151+
152+
pending_authentication_token: Optional[str]
153+
154+
def __init__(
155+
self,
156+
*args: Any,
157+
pending_authentication_token: Optional[str] = None,
158+
**kwargs: Any,
159+
) -> None:
160+
response_json = cast(Optional[Mapping[str, Any]], kwargs.get("response_json"))
161+
if pending_authentication_token is None and response_json is not None:
162+
pending_authentication_token = cast(
163+
Optional[str], response_json.get("pending_authentication_token")
164+
)
165+
super().__init__(*args, **kwargs)
166+
self.pending_authentication_token = pending_authentication_token
167+
168+
169+
class EmailVerificationRequiredError(AuthenticationFlowError):
146170
"""Raised when email verification is required before authentication."""
147171

148172
email_verification_id: Optional[str]
173+
email: Optional[str]
149174

150175
def __init__(
151-
self, *args: Any, email_verification_id: Optional[str] = None, **kwargs: Any
176+
self,
177+
*args: Any,
178+
email_verification_id: Optional[str] = None,
179+
email: Optional[str] = None,
180+
**kwargs: Any,
152181
) -> None:
153182
response_json = cast(Optional[Mapping[str, Any]], kwargs.get("response_json"))
154183
if email_verification_id is None and response_json is not None:
155184
email_verification_id = cast(
156185
Optional[str], response_json.get("email_verification_id")
157186
)
187+
if email is None and response_json is not None:
188+
email = cast(Optional[str], response_json.get("email"))
158189
super().__init__(*args, **kwargs)
159190
self.email_verification_id = email_verification_id
191+
self.email = email
192+
193+
194+
class MfaEnrollmentError(AuthenticationFlowError):
195+
"""Raised when MFA enrollment is required."""
196+
197+
user: Optional[Dict[str, Any]]
198+
199+
def __init__(
200+
self, *args: Any, user: Optional[Dict[str, Any]] = None, **kwargs: Any
201+
) -> None:
202+
response_json = cast(Optional[Mapping[str, Any]], kwargs.get("response_json"))
203+
if user is None and response_json is not None:
204+
user = cast(Optional[Dict[str, Any]], response_json.get("user"))
205+
super().__init__(*args, **kwargs)
206+
self.user = user
207+
208+
209+
class MfaChallengeError(AuthenticationFlowError):
210+
"""Raised when an MFA challenge must be completed."""
211+
212+
user: Optional[Dict[str, Any]]
213+
authentication_factors: Optional[List[Dict[str, Any]]]
214+
215+
def __init__(
216+
self,
217+
*args: Any,
218+
user: Optional[Dict[str, Any]] = None,
219+
authentication_factors: Optional[List[Dict[str, Any]]] = None,
220+
**kwargs: Any,
221+
) -> None:
222+
response_json = cast(Optional[Mapping[str, Any]], kwargs.get("response_json"))
223+
if user is None and response_json is not None:
224+
user = cast(Optional[Dict[str, Any]], response_json.get("user"))
225+
if authentication_factors is None and response_json is not None:
226+
authentication_factors = cast(
227+
Optional[List[Dict[str, Any]]],
228+
response_json.get("authentication_factors"),
229+
)
230+
super().__init__(*args, **kwargs)
231+
self.user = user
232+
self.authentication_factors = authentication_factors
233+
234+
235+
class OrganizationSelectionRequiredError(AuthenticationFlowError):
236+
"""Raised when the user must select an organization."""
237+
238+
user: Optional[Dict[str, Any]]
239+
organizations: Optional[List[Dict[str, Any]]]
240+
241+
def __init__(
242+
self,
243+
*args: Any,
244+
user: Optional[Dict[str, Any]] = None,
245+
organizations: Optional[List[Dict[str, Any]]] = None,
246+
**kwargs: Any,
247+
) -> None:
248+
response_json = cast(Optional[Mapping[str, Any]], kwargs.get("response_json"))
249+
if user is None and response_json is not None:
250+
user = cast(Optional[Dict[str, Any]], response_json.get("user"))
251+
if organizations is None and response_json is not None:
252+
organizations = cast(
253+
Optional[List[Dict[str, Any]]], response_json.get("organizations")
254+
)
255+
super().__init__(*args, **kwargs)
256+
self.user = user
257+
self.organizations = organizations
258+
259+
260+
class SsoRequiredError(AuthenticationFlowError):
261+
"""Raised when SSO authentication is required."""
262+
263+
email: Optional[str]
264+
connection_ids: Optional[List[str]]
265+
266+
def __init__(
267+
self,
268+
*args: Any,
269+
email: Optional[str] = None,
270+
connection_ids: Optional[List[str]] = None,
271+
**kwargs: Any,
272+
) -> None:
273+
response_json = cast(Optional[Mapping[str, Any]], kwargs.get("response_json"))
274+
if email is None and response_json is not None:
275+
email = cast(Optional[str], response_json.get("email"))
276+
if connection_ids is None and response_json is not None:
277+
connection_ids = cast(
278+
Optional[List[str]], response_json.get("connection_ids")
279+
)
280+
super().__init__(*args, **kwargs)
281+
self.email = email
282+
self.connection_ids = connection_ids
283+
284+
285+
class OrganizationAuthMethodsRequiredError(AuthenticationFlowError):
286+
"""Raised when organization-specific authentication methods are required."""
287+
288+
email: Optional[str]
289+
sso_connection_ids: Optional[List[str]]
290+
auth_methods: Optional[Dict[str, bool]]
291+
292+
def __init__(
293+
self,
294+
*args: Any,
295+
email: Optional[str] = None,
296+
sso_connection_ids: Optional[List[str]] = None,
297+
auth_methods: Optional[Dict[str, bool]] = None,
298+
**kwargs: Any,
299+
) -> None:
300+
response_json = cast(Optional[Mapping[str, Any]], kwargs.get("response_json"))
301+
if email is None and response_json is not None:
302+
email = cast(Optional[str], response_json.get("email"))
303+
if sso_connection_ids is None and response_json is not None:
304+
sso_connection_ids = cast(
305+
Optional[List[str]], response_json.get("sso_connection_ids")
306+
)
307+
if auth_methods is None and response_json is not None:
308+
auth_methods = cast(
309+
Optional[Dict[str, bool]], response_json.get("auth_methods")
310+
)
311+
super().__init__(*args, **kwargs)
312+
self.email = email
313+
self.sso_connection_ids = sso_connection_ids
314+
self.auth_methods = auth_methods
315+
316+
317+
class AuthenticationMethodNotAllowedError(AuthenticationFlowError):
318+
"""Raised when the authentication method is not allowed."""
319+
320+
321+
class EmailPasswordAuthDisabledError(AuthenticationFlowError):
322+
"""Raised when email/password authentication is disabled."""
323+
324+
325+
class PasskeyProgressiveEnrollmentError(AuthenticationFlowError):
326+
"""Raised when passkey progressive enrollment is required."""
327+
328+
329+
class RadarChallengeError(AuthenticationFlowError):
330+
"""Raised when a Radar challenge is required."""
331+
332+
333+
class RadarSignUpChallengeError(AuthenticationFlowError):
334+
"""Raised when a Radar sign-up challenge is required."""
160335

161336

162337
class NotFoundError(APIError):
@@ -234,3 +409,19 @@ def __init__(self, message: str = "Request timed out") -> None:
234409
422: UnprocessableEntityError,
235410
429: RateLimitExceededError,
236411
}
412+
413+
# Maps authentication error code/error values to specific error classes.
414+
# Checked by _raise_error() for 403 responses before falling through to AuthorizationError.
415+
_AUTH_CODE_TO_ERROR: Dict[str, Type[AuthenticationFlowError]] = {
416+
"email_verification_required": EmailVerificationRequiredError,
417+
"mfa_enrollment": MfaEnrollmentError,
418+
"mfa_challenge": MfaChallengeError,
419+
"organization_selection_required": OrganizationSelectionRequiredError,
420+
"sso_required": SsoRequiredError,
421+
"organization_authentication_methods_required": OrganizationAuthMethodsRequiredError,
422+
"authentication_method_not_allowed": AuthenticationMethodNotAllowedError,
423+
"email_password_auth_disabled": EmailPasswordAuthDisabledError,
424+
"passkey_progressive_enrollment": PasskeyProgressiveEnrollmentError,
425+
"radar_challenge": RadarChallengeError,
426+
"radar_sign_up_challenge": RadarSignUpChallengeError,
427+
}

0 commit comments

Comments
 (0)