forked from benavlabs/FastAPI-boilerplate
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdependencies.py
More file actions
106 lines (83 loc) · 4.08 KB
/
dependencies.py
File metadata and controls
106 lines (83 loc) · 4.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from typing import Annotated, Any, cast
from fastapi import Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from ..core.config import settings
from ..core.db.database import async_get_db
from ..core.exceptions.http_exceptions import ForbiddenException, RateLimitException, UnauthorizedException
from ..core.logger import logging
from ..core.security import TokenType, oauth2_scheme, verify_token
from ..core.utils.rate_limit import rate_limiter
from ..crud.crud_rate_limit import crud_rate_limits
from ..crud.crud_tier import crud_tiers
from ..crud.crud_users import crud_users
from ..schemas.rate_limit import RateLimitRead, sanitize_path
from ..schemas.tier import TierRead
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Fallback rate-limit values read from the application settings. rate_limiter_dependency
# applies them to anonymous requests and whenever no tier-specific rule is found.
DEFAULT_LIMIT = settings.DEFAULT_RATE_LIMIT_LIMIT
DEFAULT_PERIOD = settings.DEFAULT_RATE_LIMIT_PERIOD
async def get_current_user(
    token: Annotated[str, Depends(oauth2_scheme)], db: Annotated[AsyncSession, Depends(async_get_db)]
) -> dict[str, Any]:
    """Resolve the user record that belongs to an access token.

    The token is checked with ``verify_token`` as an ACCESS token. The
    identifier carried by the token is treated as an email address when it
    contains ``"@"`` and as a username otherwise, and the matching user is
    fetched with ``is_deleted=False``.

    Args:
        token: Token string supplied by the ``oauth2_scheme`` dependency.
        db: Database session supplied by the ``async_get_db`` dependency.

    Returns:
        The user record returned by ``crud_users.get``.

    Raises:
        UnauthorizedException: If the token does not verify or no matching
            user is found.
    """
    payload = await verify_token(token, TokenType.ACCESS, db)
    if payload is None:
        raise UnauthorizedException("User not authenticated.")

    # Pick the lookup column from the shape of the identifier.
    lookup_field = "email" if "@" in payload.username_or_email else "username"
    account = await crud_users.get(db=db, **{lookup_field: payload.username_or_email}, is_deleted=False)

    if not account:
        raise UnauthorizedException("User not authenticated.")
    return account
async def get_optional_user(request: Request, db: AsyncSession = Depends(async_get_db)) -> dict | None:
    """Return the authenticated user for a request, or ``None`` if there is none.

    Unlike ``get_current_user`` this never raises: a missing header, a
    non-bearer scheme, an empty or unverifiable token, and any exception
    raised along the way all result in ``None``.

    Args:
        request: Incoming request; only its ``Authorization`` header is read.
        db: Database session supplied by the ``async_get_db`` dependency.

    Returns:
        The user record from ``get_current_user``, or ``None``.
    """
    auth_header = request.headers.get("Authorization")
    if not auth_header:
        return None

    try:
        scheme, _sep, credentials = auth_header.partition(" ")
        is_bearer = scheme.lower() == "bearer"
        if not (is_bearer and credentials):
            return None

        # Verify first so an invalid token returns None without going through get_current_user.
        if await verify_token(credentials, TokenType.ACCESS, db) is None:
            return None

        return await get_current_user(credentials, db=db)
    except HTTPException as http_exc:
        # A 401 is the expected "not authenticated" outcome; anything else is logged.
        if http_exc.status_code != 401:
            logger.error(f"Unexpected HTTPException in get_optional_user: {http_exc.detail}")
    except Exception as exc:
        logger.error(f"Unexpected error in get_optional_user: {exc}")

    # Reached only after one of the handlers above ran.
    return None
async def get_current_superuser(current_user: Annotated[dict, Depends(get_current_user)]) -> dict:
    """Allow the request through only when the current user is a superuser.

    Args:
        current_user: User record supplied by the ``get_current_user`` dependency.

    Returns:
        The same user record, unchanged.

    Raises:
        ForbiddenException: If the record's ``"is_superuser"`` value is falsy.
    """
    if current_user["is_superuser"]:
        return current_user
    raise ForbiddenException("You do not have enough privileges.")
async def rate_limiter_dependency(
    request: Request, db: Annotated[AsyncSession, Depends(async_get_db)], user: dict | None = Depends(get_optional_user)
) -> None:
    """Enforce the rate limit that applies to the current request.

    For an authenticated user the limit is looked up by the user's tier and
    the sanitized request path. The defaults (``DEFAULT_LIMIT`` /
    ``DEFAULT_PERIOD``) are used when the user has no tier, when the tier has
    no rule for the path, and for anonymous requests. Anonymous requests are
    keyed by the client host, or ``"unknown"`` when the request carries no
    client information.

    Args:
        request: Incoming request; its path, client and app state are read.
        db: Database session supplied by the ``async_get_db`` dependency.
        user: User record from ``get_optional_user``, or ``None``.

    Raises:
        RateLimitException: If ``rate_limiter.is_rate_limited`` reports that
            the limit has been exceeded.
    """
    # If the app exposes an initialization event, wait for it before doing any work.
    if hasattr(request.app.state, "initialization_complete"):
        await request.app.state.initialization_complete.wait()

    path = sanitize_path(request.url.path)

    # Start from the defaults; only a tier-specific rule overrides them.
    limit, period = DEFAULT_LIMIT, DEFAULT_PERIOD

    if user:
        user_id = user["id"]
        tier = await crud_tiers.get(db, id=user["tier_id"], schema_to_select=TierRead)
        if tier:
            tier = cast(TierRead, tier)
            rate_limit = await crud_rate_limits.get(db=db, tier_id=tier.id, path=path, schema_to_select=RateLimitRead)
            if rate_limit:
                rate_limit = cast(RateLimitRead, rate_limit)
                limit, period = rate_limit.limit, rate_limit.period
            else:
                # Adjacent literals instead of a backslash continuation inside the
                # string, so source indentation can never leak into the message.
                logger.warning(
                    f"User {user_id} with tier '{tier.name}' has no specific rate limit for path '{path}'. "
                    "Applying default rate limit."
                )
        else:
            logger.warning(f"User {user_id} has no assigned tier. Applying default rate limit.")
    else:
        user_id = request.client.host if request.client else "unknown"

    is_limited = await rate_limiter.is_rate_limited(db=db, user_id=user_id, path=path, limit=limit, period=period)
    if is_limited:
        raise RateLimitException("Rate limit exceeded.")