|
45 | 45 | get_account_from_recovery_token, |
46 | 46 | get_account_from_credentials, |
47 | 47 | require_unauthenticated_client, |
| 48 | + require_unauthenticated_unless_invitation_warning, |
48 | 49 | get_verified_account, |
49 | 50 | ) |
50 | 51 | from exceptions.http_exceptions import ( |
|
54 | 55 | EmailNotVerifiedError, |
55 | 56 | MaxEmailsReachedError, |
56 | 57 | PasswordValidationError, |
57 | | - InvalidInvitationTokenError, |
58 | 58 | InvitationEmailMismatchError, |
59 | 59 | InvitationProcessingError, |
60 | 60 | ) |
61 | 61 | from routers.core.dashboard import router as dashboard_router |
62 | 62 | from routers.core.user import router as user_router |
63 | 63 | from routers.core.organization import router as org_router |
64 | | -from utils.core.invitations import process_invitation |
| 64 | +from utils.core.invitations import ( |
| 65 | + process_invitation, |
| 66 | + require_active_invitation_by_token, |
| 67 | + get_invitation_token_warning, |
| 68 | +) |
65 | 69 | from utils.core.rate_limit import ( |
66 | 70 | check_login_ip_rate_limit, |
67 | 71 | check_login_email_rate_limit, |
@@ -147,37 +151,56 @@ def logout( |
147 | 151 | @router.get("/login") |
148 | 152 | async def read_login( |
149 | 153 | request: Request, |
150 | | - _: None = Depends(require_unauthenticated_client), |
| 154 | + _: None = Depends(require_unauthenticated_unless_invitation_warning), |
151 | 155 | invitation_token: Optional[str] = Query(None), |
| 156 | + user: Optional[User] = Depends(get_optional_user), |
| 157 | + session: Session = Depends(get_session), |
152 | 158 | ): |
153 | 159 | """ |
154 | 160 | Render login page or redirect to dashboard if already logged in. |
155 | 161 | """ |
| 162 | + invitation_token_warning = ( |
| 163 | + get_invitation_token_warning(session, invitation_token) |
| 164 | + if invitation_token |
| 165 | + else None |
| 166 | + ) |
156 | 167 | return templates.TemplateResponse( |
157 | 168 | request, |
158 | 169 | "account/login.html", |
159 | | - {"user": None, "invitation_token": invitation_token}, |
| 170 | + { |
| 171 | + "user": user, |
| 172 | + "invitation_token": invitation_token, |
| 173 | + "invitation_token_warning": invitation_token_warning, |
| 174 | + }, |
160 | 175 | ) |
161 | 176 |
|
162 | 177 |
|
163 | 178 | @router.get("/register") |
164 | 179 | async def read_register( |
165 | 180 | request: Request, |
166 | | - _: None = Depends(require_unauthenticated_client), |
| 181 | + _: None = Depends(require_unauthenticated_unless_invitation_warning), |
167 | 182 | email: Optional[EmailStr] = Query(None), |
168 | 183 | invitation_token: Optional[str] = Query(None), |
| 184 | + user: Optional[User] = Depends(get_optional_user), |
| 185 | + session: Session = Depends(get_session), |
169 | 186 | ): |
170 | 187 | """ |
171 | 188 | Render registration page or redirect to dashboard if already logged in. |
172 | 189 | """ |
| 190 | + invitation_token_warning = ( |
| 191 | + get_invitation_token_warning(session, invitation_token) |
| 192 | + if invitation_token |
| 193 | + else None |
| 194 | + ) |
173 | 195 | return templates.TemplateResponse( |
174 | 196 | request, |
175 | 197 | "account/register.html", |
176 | 198 | { |
177 | | - "user": None, |
| 199 | + "user": user, |
178 | 200 | "password_pattern": HTML_PASSWORD_PATTERN, |
179 | 201 | "email": email, |
180 | 202 | "invitation_token": invitation_token, |
| 203 | + "invitation_token_warning": invitation_token_warning, |
181 | 204 | }, |
182 | 205 | ) |
183 | 206 |
|
@@ -270,6 +293,18 @@ async def register( |
270 | 293 | """ |
271 | 294 | Register a new user account, optionally processing an invitation. |
272 | 295 | """ |
| 296 | + pending_invitation: Optional[Invitation] = None |
| 297 | + if invitation_token: |
| 298 | + pending_invitation = require_active_invitation_by_token( |
| 299 | + session, invitation_token |
| 300 | + ) |
| 301 | + if email != pending_invitation.invitee_email: |
| 302 | + logger.warning( |
| 303 | + f"Invitation email mismatch for token {invitation_token} during registration. " |
| 304 | + f"Account: {email}, Invitation: {pending_invitation.invitee_email}" |
| 305 | + ) |
| 306 | + raise InvitationEmailMismatchError() |
| 307 | + |
273 | 308 | # Check if the email is already registered |
274 | 309 | existing_account: Optional[Account] = session.exec( |
275 | 310 | select(Account).where(Account.email == email) |
@@ -313,46 +348,27 @@ async def register( |
313 | 348 | redirect_url = dashboard_router.url_path_for("read_dashboard") |
314 | 349 |
|
315 | 350 | # Process invitation if token is provided (BEFORE final commit) |
316 | | - if invitation_token: |
| 351 | + if pending_invitation: |
317 | 352 | logger.info( |
318 | 353 | f"Registration attempt with invitation token: {invitation_token} for email {email}" |
319 | 354 | ) |
320 | | - # Fetch the invitation |
321 | | - statement = select(Invitation).where(Invitation.token == invitation_token) |
322 | | - invitation = session.exec(statement).first() |
323 | | - |
324 | | - if not invitation or not invitation.is_active(): |
325 | | - logger.warning( |
326 | | - f"Invalid or inactive invitation token provided during registration: {invitation_token}" |
327 | | - ) |
328 | | - # Consider raising a more generic error to avoid exposing token validity |
329 | | - raise InvalidInvitationTokenError() |
330 | | - |
331 | | - # Verify email matches |
332 | | - if email != invitation.invitee_email: |
333 | | - logger.warning( |
334 | | - f"Invitation email mismatch for token {invitation_token} during registration. " |
335 | | - f"Account: {email}, Invitation: {invitation.invitee_email}" |
336 | | - ) |
337 | | - # Consider raising a more generic error to avoid confirming email existence |
338 | | - raise InvitationEmailMismatchError() |
339 | 355 |
|
340 | 356 | # Process the invitation (adds changes to the session) |
341 | 357 | try: |
342 | 358 | logger.info( |
343 | | - f"Processing invitation {invitation.id} for new user {new_user.name} ({email}) during registration." |
| 359 | + f"Processing invitation {pending_invitation.id} for new user {new_user.name} ({email}) during registration." |
344 | 360 | ) |
345 | | - process_invitation(invitation, new_user, session) |
| 361 | + process_invitation(pending_invitation, new_user, session) |
346 | 362 | # Set redirect to the organization page |
347 | 363 | redirect_url = org_router.url_path_for( |
348 | | - "read_organization", org_id=invitation.organization_id |
| 364 | + "read_organization", org_id=pending_invitation.organization_id |
349 | 365 | ) |
350 | 366 | logger.info( |
351 | | - f"Redirecting new user {new_user.name} to organization {invitation.organization_id} after accepting invitation {invitation.id}." |
| 367 | + f"Redirecting new user {new_user.name} to organization {pending_invitation.organization_id} after accepting invitation {pending_invitation.id}." |
352 | 368 | ) |
353 | 369 | except Exception as e: |
354 | 370 | logger.error( |
355 | | - f"Error processing invitation {invitation.id} for new user {new_user.name} ({email}) during registration: {e}", |
| 371 | + f"Error processing invitation {pending_invitation.id} for new user {new_user.name} ({email}) during registration: {e}", |
356 | 372 | exc_info=True, |
357 | 373 | ) |
358 | 374 | session.rollback() |
@@ -434,15 +450,7 @@ async def login( |
434 | 450 | logger.info( |
435 | 451 | f"Login attempt with invitation token: {invitation_token} for account {account.email}" |
436 | 452 | ) |
437 | | - # Fetch the invitation |
438 | | - statement = select(Invitation).where(Invitation.token == invitation_token) |
439 | | - invitation = session.exec(statement).first() |
440 | | - |
441 | | - if not invitation or not invitation.is_active(): |
442 | | - logger.warning( |
443 | | - f"Invalid or inactive invitation token provided during login: {invitation_token}" |
444 | | - ) |
445 | | - raise InvalidInvitationTokenError() |
| 453 | + invitation = require_active_invitation_by_token(session, invitation_token) |
446 | 454 |
|
447 | 455 | # Verify email matches (check primary and any verified secondary emails) |
448 | 456 | account_emails = session.exec( |
|
0 commit comments