-
Notifications
You must be signed in to change notification settings - Fork 280
Expand file tree
/
Copy pathdependencies.py
More file actions
106 lines (83 loc) · 4.02 KB
/
dependencies.py
File metadata and controls
106 lines (83 loc) · 4.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from typing import Annotated, Any
from fastapi import Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from ..core.config import settings
from ..core.db.database import async_get_db
from ..core.exceptions.http_exceptions import ForbiddenException, RateLimitException, UnauthorizedException
from ..core.logger import logging
from ..core.security import TokenType, oauth2_scheme, verify_token
from ..core.utils.rate_limit import rate_limiter
from ..crud.crud_rate_limit import crud_rate_limits
from ..crud.crud_tier import crud_tiers
from ..crud.crud_users import crud_users
from ..schemas.rate_limit import RateLimitRead, sanitize_path
from ..schemas.tier import TierRead
# Module-level logger; `logging` here is the name imported from `..core.logger` above.
logger = logging.getLogger(__name__)

# Fallback rate-limit values read from settings. `rate_limiter_dependency` applies them when the
# caller is not authenticated, has no tier, or the tier has no rule for the requested path.
# NOTE(review): the unit of the period is not visible in this file — presumably seconds; confirm in settings.
DEFAULT_LIMIT = settings.DEFAULT_RATE_LIMIT_LIMIT
DEFAULT_PERIOD = settings.DEFAULT_RATE_LIMIT_PERIOD
async def get_current_user(
    token: Annotated[str, Depends(oauth2_scheme)], db: Annotated[AsyncSession, Depends(async_get_db)]
) -> dict[str, Any]:
    """Resolve an access token to the matching, non-deleted user record.

    Args:
        token: Token string supplied by the ``oauth2_scheme`` dependency.
        db: Database session supplied by the ``async_get_db`` dependency.

    Returns:
        The user record returned by ``crud_users.get``.

    Raises:
        UnauthorizedException: If ``verify_token`` returns ``None`` or no
            non-deleted user matches the identifier carried by the token.
    """
    payload = await verify_token(token, TokenType.ACCESS, db)
    if payload is None:
        raise UnauthorizedException("User not authenticated.")

    identifier = payload.username_or_email
    # An "@" marks the identifier as an e-mail address; anything else is looked up as a username.
    lookup_field = "email" if "@" in identifier else "username"
    account = await crud_users.get(db=db, **{lookup_field: identifier}, is_deleted=False)

    if not account:
        raise UnauthorizedException("User not authenticated.")
    return account
async def get_optional_user(request: Request, db: AsyncSession = Depends(async_get_db)) -> dict | None:
    """Return the authenticated user for the request, or ``None`` if there is none.

    Unlike ``get_current_user`` this never raises for a missing or invalid
    token; it is meant for endpoints that work with or without a user.

    Args:
        request: Incoming request; only its ``Authorization`` header is read.
        db: Database session supplied by the ``async_get_db`` dependency.

    Returns:
        The result of ``get_current_user`` for a verified bearer token, otherwise
        ``None`` (header absent, scheme is not ``bearer``, empty token, token
        fails verification, or any exception is raised while resolving it).
    """
    auth_header = request.headers.get("Authorization")
    if not auth_header:
        return None

    try:
        # Header is expected as "<scheme> <credentials>"; the scheme is matched case-insensitively.
        scheme, _, credentials = auth_header.partition(" ")
        if scheme.lower() == "bearer" and credentials:
            verified = await verify_token(credentials, TokenType.ACCESS, db)
            if verified is not None:
                return await get_current_user(credentials, db=db)
    except HTTPException as http_exc:
        # A 401 is the expected "not authenticated" outcome; only other statuses are worth logging.
        if http_exc.status_code != 401:
            logger.error(f"Unexpected HTTPException in get_optional_user: {http_exc.detail}")
    except Exception as exc:
        # Best-effort dependency: log and fall through so the request continues without a user.
        logger.error(f"Unexpected error in get_optional_user: {exc}")

    return None
async def get_current_superuser(current_user: Annotated[dict, Depends(get_current_user)]) -> dict:
    """Allow the request through only when the authenticated user is a superuser.

    Args:
        current_user: User record supplied by the ``get_current_user`` dependency.

    Returns:
        The same ``current_user`` mapping, unchanged.

    Raises:
        ForbiddenException: If ``current_user["is_superuser"]`` is falsy.
    """
    if current_user["is_superuser"]:
        return current_user
    raise ForbiddenException("You do not have enough privileges.")
async def rate_limiter_dependency(
    request: Request, db: Annotated[AsyncSession, Depends(async_get_db)], user: dict | None = Depends(get_optional_user)
) -> None:
    """Enforce the rate limit that applies to the caller for the requested path.

    The limit and period come from the rate-limit rule of the user's tier for
    the sanitized request path. ``DEFAULT_LIMIT`` / ``DEFAULT_PERIOD`` are used
    when there is no authenticated user, the user's tier is not found, or the
    tier has no rule for this path.

    Args:
        request: Incoming request; supplies the URL path, the client host and the app state.
        db: Database session supplied by the ``async_get_db`` dependency.
        user: User record from ``get_optional_user``, or ``None`` for anonymous callers.

    Raises:
        RateLimitException: If ``rate_limiter.is_rate_limited`` reports the caller as limited.
    """
    # If the app state carries `initialization_complete`, wait on it before doing any work.
    # NOTE(review): presumably an asyncio.Event set once startup has finished — confirm where it is created.
    app_state = request.app.state
    if hasattr(app_state, "initialization_complete"):
        await app_state.initialization_complete.wait()

    path = sanitize_path(request.url.path)

    # Start from the configured defaults; they are overridden only when a tier-specific rule exists.
    limit, period = DEFAULT_LIMIT, DEFAULT_PERIOD

    if user:
        user_id = user["id"]
        tier = await crud_tiers.get(db, id=user["tier_id"], schema_to_select=TierRead)
        if not tier:
            logger.warning(f"User {user_id} has no assigned tier. Applying default rate limit.")
        else:
            rate_limit = await crud_rate_limits.get(
                db=db, tier_id=tier["id"], path=path, schema_to_select=RateLimitRead
            )
            if rate_limit:
                limit, period = rate_limit["limit"], rate_limit["period"]
            else:
                # Implicit concatenation instead of a backslash continuation inside the literal,
                # so the source indentation can never leak into the logged message.
                logger.warning(
                    f"User {user_id} with tier '{tier['name']}' has no specific rate limit for path '{path}'. "
                    "Applying default rate limit."
                )
    else:
        # Anonymous callers are keyed by client host, or "unknown" when the request has no client info.
        user_id = request.client.host if request.client else "unknown"

    if await rate_limiter.is_rate_limited(db=db, user_id=user_id, path=path, limit=limit, period=period):
        raise RateLimitException("Rate limit exceeded.")