Skip to content

Commit 2f98e9c

Browse files
committed
chore: api authentication
1 parent d5e74b3 commit 2f98e9c

File tree

1 file changed

+68
-0
lines changed
  • src/quant_research_starter/api

1 file changed

+68
-0
lines changed
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""JWT authentication utilities and dependencies."""
2+
from __future__ import annotations
3+
import os
4+
from datetime import datetime, timedelta
5+
from typing import Optional
6+
7+
from passlib.context import CryptContext
8+
from jose import JWTError, jwt
9+
from fastapi import Depends, HTTPException, status
10+
from fastapi.security import OAuth2PasswordBearer
11+
from sqlalchemy.ext.asyncio import AsyncSession
12+
13+
from . import models, db
14+
15+
SECRET_KEY = os.getenv("JWT_SECRET", "dev-secret-change-me")
16+
ALGORITHM = "HS256"
17+
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("JWT_EXPIRE_MINUTES", "60"))
18+
19+
pwd_ctx = CryptContext(schemes=["bcrypt"], deprecated="auto")
20+
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token")
21+
22+
23+
def verify_password(plain_password: str, hashed_password: str) -> bool:
24+
return pwd_ctx.verify(plain_password, hashed_password)
25+
26+
27+
def get_password_hash(password: str) -> str:
28+
return pwd_ctx.hash(password)
29+
30+
31+
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
32+
to_encode = data.copy()
33+
if expires_delta:
34+
expire = datetime.utcnow() + expires_delta
35+
else:
36+
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
37+
to_encode.update({"exp": expire})
38+
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
39+
40+
41+
async def get_current_user(token: str = Depends(oauth2_scheme), session: AsyncSession = Depends(db.get_session)):
42+
credentials_exception = HTTPException(
43+
status_code=status.HTTP_401_UNAUTHORIZED,
44+
detail="Could not validate credentials",
45+
headers={"WWW-Authenticate": "Bearer"},
46+
)
47+
try:
48+
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
49+
username: str = payload.get("sub")
50+
if username is None:
51+
raise credentials_exception
52+
except JWTError:
53+
raise credentials_exception
54+
55+
q = await session.execute(
56+
models.User.__table__.select().where(models.User.username == username)
57+
)
58+
row = q.first()
59+
if not row:
60+
raise credentials_exception
61+
user = row[0]
62+
return user
63+
64+
65+
async def require_active_user(current_user=Depends(get_current_user)):
66+
if not current_user.is_active:
67+
raise HTTPException(status_code=400, detail="Inactive user")
68+
return current_user

0 commit comments

Comments
 (0)