Skip to content

Commit b81e527

Browse files
committed
add extra endpoints and develop security with adding permissions to endpoints
1 parent cdcd355 commit b81e527

13 files changed

Lines changed: 319 additions & 338 deletions

File tree

src/api/v1/user.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from fastapi import APIRouter, Depends
22

3+
from db.dependencies.auth import get_current_user, require_roles
4+
from db.models.users import User, UserRole
35
from schemas.auth import (
46
LoginOutScheme,
57
LoginScheme,
@@ -12,6 +14,26 @@
1214
router = APIRouter()
1315

1416

17+
@router.get("/me")
18+
async def get_me(current_user: User = Depends(get_current_user)):
19+
return {
20+
"id": current_user.id,
21+
"email": current_user.email,
22+
"username": current_user.username,
23+
"role": current_user.role,
24+
}
25+
26+
27+
@router.get("/dashboard")
28+
async def dashboard(current_user: User = Depends(get_current_user)):
29+
return {"message": f"Welcome, {current_user.username}"}
30+
31+
32+
@router.get("/admin/stats")
33+
async def admin_stats(admin_user: User = Depends(require_roles(UserRole.admin.value))):
34+
return {"status": "ok", "admin": admin_user.username}
35+
36+
1537
@router.post("/register", response_model=UserCreateScheme)
1638
async def register_user(
1739
data: RegistrationScheme,

src/core/security.py

Lines changed: 84 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,109 @@
1+
import uuid
12
from datetime import datetime, timedelta
3+
from typing import Any, Dict, Optional, Union
24

3-
from jose import jwt
5+
from jose import ExpiredSignatureError, JWTError, jwt
46
from passlib.context import CryptContext
57

68
from core.config import settings
79

810
pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
911

1012

13+
# Hashing helpers
1114
def hash_password(password: str) -> str:
1215
return pwd_context.hash(password)
1316

1417

18+
# Verifying helpers
1519
def verify_password(plain_password: str, hashed_password: str) -> bool:
1620
return pwd_context.verify(plain_password, hashed_password)
1721

1822

19-
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
20-
to_encode = data.copy()
23+
# helpers
24+
def _now() -> datetime:
25+
return datetime.utcnow()
26+
27+
28+
def _jti() -> str:
29+
return str(uuid.uuid4())
30+
31+
32+
# Token factories
33+
def create_access_token(
34+
subject: Union[str, int, Dict[str, Any]],
35+
extra: Optional[Dict[str, Any]] = None,
36+
expires_delta: timedelta | None = None,
37+
) -> str:
38+
"""
39+
Create access token.
40+
"""
41+
42+
if isinstance(subject, dict):
43+
payload_access: Dict[str, Any] = subject.copy()
44+
else:
45+
payload_access: Dict[str, Any] = {"sub": str(subject)} # type: ignore
46+
47+
payload_access.setdefault("type", "access")
48+
payload_access.setdefault("jti", _jti())
49+
payload_access.setdefault("iat", int(_now().timestamp()))
50+
51+
if extra:
52+
payload_access.update(extra)
53+
2154
if expires_delta:
22-
expire = datetime.utcnow() + expires_delta
55+
exp = _now() + expires_delta
2356
else:
24-
expire = datetime.utcnow() + timedelta(
25-
minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES
26-
)
27-
to_encode.update({"exp": expire})
28-
encoded_jwt = jwt.encode(
29-
to_encode, settings.SECRET_KEY, algorithm=settings.JWT_ALGORITHM
57+
exp = _now() + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES or 15)
58+
59+
payload_access["exp"] = int(exp.timestamp())
60+
61+
return jwt.encode(
62+
payload_access, settings.SECRET_KEY, algorithm=settings.JWT_ALGORITHM
3063
)
31-
return encoded_jwt
3264

3365

34-
def create_refresh_token(data: dict, expires_delta: timedelta | None = None) -> str:
35-
to_encode = data.copy()
66+
def create_refresh_token(
67+
subject: Union[str, int, Dict[str, Any]],
68+
extra: Optional[Dict[str, Any]] = None,
69+
expires_delta: timedelta | None = None,
70+
) -> str:
71+
"""
72+
Create refresh token.
73+
"""
74+
75+
if isinstance(subject, dict):
76+
payload_refresh: Dict[str, Any] = subject.copy()
77+
else:
78+
payload_refresh: Dict[str, Any] = {"sub": str(subject)} # type: ignore
79+
80+
payload_refresh.setdefault("type", "refresh")
81+
payload_refresh.setdefault("jti", _jti())
82+
payload_refresh.setdefault("iat", int(_now().timestamp()))
83+
84+
if extra:
85+
payload_refresh.update(extra)
86+
3687
if expires_delta:
37-
expire = datetime.utcnow() + expires_delta
88+
exp = _now() + expires_delta
3889
else:
39-
expire = datetime.utcnow() + timedelta(
40-
days=settings.JWT_REFRESH_TOKEN_EXPIRES_DAYS
41-
)
42-
to_encode.update({"exp": expire})
43-
encoded_jwt = jwt.encode(
44-
to_encode, settings.SECRET_KEY, algorithm=settings.JWT_ALGORITHM
90+
exp = _now() + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRES_DAYS or 30)
91+
92+
payload_refresh["exp"] = int(exp.timestamp())
93+
94+
return jwt.encode(
95+
payload_refresh, settings.SECRET_KEY, algorithm=settings.JWT_ALGORITHM
4596
)
46-
return encoded_jwt
97+
98+
99+
def decode_token(token: str) -> dict:
100+
try:
101+
payload = jwt.decode(
102+
token, settings.SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
103+
)
104+
return payload
105+
except ExpiredSignatureError:
106+
raise ExpiredSignatureError("Token has expired")
107+
108+
except JWTError:
109+
raise JWTError("Invalid token")

src/db/crud/category.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from sqlalchemy import select
55
from sqlalchemy.ext.asyncio import AsyncSession
66

7-
from src.db.dependencies import get_db_session
7+
from src.db.dependencies.sessions import get_db_session
88
from src.db.models.category import Category
99
from src.schemas.category import CategoryCreateScheme, CategoryUpdateScheme
1010

src/db/crud/token.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from datetime import datetime
2+
from typing import Optional
3+
4+
from sqlalchemy import delete, select
5+
from sqlalchemy.ext.asyncio import AsyncSession
6+
7+
from db.models.revoked_token import RevokedToken
8+
9+
10+
class TokenCRUD:
11+
def __init__(self, sessino: AsyncSession) -> None:
12+
self.sessino = sessino
13+
14+
async def add(
15+
self,
16+
jti: str,
17+
token_type: str,
18+
expires_at: datetime,
19+
user_id: Optional[int] = None,
20+
) -> None:
21+
revoked = RevokedToken(
22+
jti=jti, token_type=token_type, expires_at=expires_at, user_id=user_id
23+
)
24+
self.sessino.add(revoked)
25+
await self.sessino.commit()
26+
27+
async def exists(self, jti: str) -> bool:
28+
stmt = await self.sessino.execute(
29+
select(RevokedToken).where(RevokedToken.jti == jti)
30+
)
31+
return stmt.scalar_one_or_none() is not None
32+
33+
async def cleanup_expired(self) -> int:
34+
now = datetime.utcnow()
35+
stmt = delete(RevokedToken).where(RevokedToken.expires_at < now)
36+
res = await self.sessino.execute(stmt)
37+
await self.sessino.commit()
38+
# The result.rowcount may be driver-dependent; we return 0/1+ best-effort
39+
return getattr(res, "rowcount", 0)

src/db/crud/user.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from sqlalchemy import select
33
from sqlalchemy.ext.asyncio import AsyncSession
44

5-
from db.dependencies import get_db_session
5+
from db.dependencies.sessions import get_db_session
66
from db.models.users import User
77
from schemas.auth import RegistrationScheme
88

src/db/dependencies/__init__.py

Whitespace-only changes.

src/db/dependencies/auth.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# src/dependencies/auth.py
2+
from fastapi import Depends, HTTPException, status
3+
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
4+
from sqlalchemy.ext.asyncio import AsyncSession
5+
6+
from core.security import decode_token
7+
from db.crud.token import TokenCRUD
8+
from db.crud.user import UserCRUD
9+
from db.dependencies.sessions import get_db_session
10+
11+
http_bearer = HTTPBearer(auto_error=False)
12+
13+
14+
async def get_current_user(
15+
credentials: HTTPAuthorizationCredentials = Depends(http_bearer),
16+
session: AsyncSession = Depends(get_db_session),
17+
):
18+
if credentials is None:
19+
raise HTTPException(
20+
status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated"
21+
)
22+
23+
token = credentials.credentials
24+
25+
try:
26+
payload = decode_token(token)
27+
except ValueError as exc:
28+
reason = str(exc)
29+
if reason == "token_expired":
30+
raise HTTPException(
31+
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
32+
)
33+
raise HTTPException(
34+
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
35+
)
36+
37+
if payload.get("type") != "access":
38+
raise HTTPException(
39+
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token is not access token"
40+
)
41+
42+
jti = payload.get("jti")
43+
token_crud = TokenCRUD(session)
44+
if await token_crud.exists(jti):
45+
raise HTTPException(
46+
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token revoked"
47+
)
48+
49+
user_crud = UserCRUD(session)
50+
user = await user_crud.get_by_id(int(payload.get("sub")))
51+
if not user:
52+
raise HTTPException(
53+
status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found"
54+
)
55+
56+
return user
57+
58+
59+
def require_roles(*roles: str):
60+
async def _inner(user=Depends(get_current_user)):
61+
if getattr(user, "role", None) not in roles:
62+
raise HTTPException(
63+
status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden"
64+
)
65+
return user
66+
67+
return _inner

src/db/dependencies/sessions.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from collections.abc import AsyncGenerator
2+
3+
from sqlalchemy.ext.asyncio import AsyncSession
4+
from starlette.requests import Request
5+
6+
from core.database import get_async_session_maker
7+
8+
9+
async def get_db_session(request: Request) -> AsyncGenerator[AsyncSession, None]:
10+
"""Dependency that provides a database session."""
11+
async_session_maker = get_async_session_maker()
12+
session = async_session_maker()
13+
14+
try:
15+
yield session
16+
await session.commit()
17+
except Exception:
18+
await session.rollback()
19+
raise
20+
finally:
21+
await session.close()

0 commit comments

Comments
 (0)