-
-
Notifications
You must be signed in to change notification settings - Fork 8.4k
Expand file tree
/
Copy pathwebsocket.py
More file actions
81 lines (69 loc) · 2.28 KB
/
websocket.py
File metadata and controls
81 lines (69 loc) · 2.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import asyncio
import logging
from typing import Annotated

import jwt
from fastapi import (
    APIRouter,
    HTTPException,
    Query,
    WebSocket,
    WebSocketDisconnect,
    status,
)
from jwt.exceptions import InvalidTokenError
from pydantic import ValidationError
from sqlmodel import Session

from app.core import security
from app.core.config import settings
from app.core.db import engine
from app.core.websocket import manager
from app.models import TokenPayload, User
# Module-level logger named after this module; used by the WebSocket endpoint
# below to record disconnects and errors.
logger = logging.getLogger(__name__)
# Router holding this module's WebSocket route(s), tagged "websocket".
router = APIRouter(tags=["websocket"])
def get_user_from_token(token: str) -> User:
    """Resolve the active ``User`` identified by an encoded JWT.

    The token is decoded with ``settings.SECRET_KEY`` restricted to
    ``security.ALGORITHM``, its claims are validated as a ``TokenPayload``,
    and the ``sub`` claim is used as the primary key for the user lookup.

    Args:
        token: The encoded JWT string supplied by the client.

    Returns:
        The matching ``User`` row, which is guaranteed to be active.

    Raises:
        HTTPException: 403 when the token cannot be decoded or its claims
            fail validation; 404 when no user matches ``sub``; 400 when the
            user exists but is not active.
    """
    accepted_algorithms = [security.ALGORITHM]
    try:
        claims = jwt.decode(token, settings.SECRET_KEY, algorithms=accepted_algorithms)
        token_data = TokenPayload(**claims)
    except (InvalidTokenError, ValidationError):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials",
        )

    # A short-lived session scoped to this lookup only.
    with Session(engine) as session:
        found = session.get(User, token_data.sub)
        if found is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
            )
        if not found.is_active:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user"
            )
        return found
@router.websocket("/ws/notifications")
async def websocket_notifications(
    websocket: WebSocket,
    token: Annotated[str | None, Query()] = None,
) -> None:
    """
    WebSocket endpoint for real-time notifications.

    Clients should connect with a valid JWT token as a query parameter.

    The connection is closed with policy-violation code 1008 when the token
    is missing or ``get_user_from_token`` rejects it. Otherwise the socket is
    registered with ``manager`` under the user's id. The endpoint then answers
    every "ping" text message with "pong" until the client disconnects. Other
    text messages are read and ignored.

    The socket is always unregistered from ``manager`` on exit, whether the
    exit comes from a normal disconnect, an unexpected error, or task
    cancellation.
    """
    if not token:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    try:
        # get_user_from_token runs a synchronous DB query; execute it in a
        # worker thread so the event loop is not blocked while it runs.
        user = await asyncio.to_thread(get_user_from_token, token)
    except HTTPException:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    user_id = user.id
    await manager.connect(user_id, websocket)
    try:
        while True:
            data = await websocket.receive_text()
            if data == "ping":
                await websocket.send_text("pong")
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected for user: %s", user_id)
    except Exception as e:
        # Top-level boundary for this connection: record the full traceback
        # rather than only str(e), then fall through to cleanup.
        logger.exception("WebSocket error for user %s: %s", user_id, e)
    finally:
        # Runs on disconnect, on errors, and on cancellation (CancelledError is
        # a BaseException and bypasses the handlers above), so the manager
        # never keeps a stale socket registered.
        manager.disconnect(user_id, websocket)