Skip to content

Commit e0c9b0a

Browse files
committed
feat: 添加实时通知功能
- 后端实现通知模块,包括数据库表、API路由和WebSocket支持 - 前端添加通知中心组件,支持实时接收和显示通知 - 新增通知类型枚举,支持信息、成功、警告和错误四种类型 - 实现通知的创建、读取、更新和删除功能 - 添加WebSocket连接管理,实现实时通知推送 - 优化数据库配置,支持SQLite和PostgreSQL - 添加通知标记已读和全部标记已读功能
1 parent bba8d07 commit e0c9b0a

File tree

15 files changed

+1249
-37
lines changed

15 files changed

+1249
-37
lines changed

backend/app.db

32 KB
Binary file not shown.
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""Add notification table
2+
3+
Revision ID: add_notification_table
4+
Revises: 1a31ce608336
5+
Create Date: 2026-04-15 00:00:00.000000
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
import sqlmodel.sql.sqltypes
11+
from sqlalchemy.dialects import postgresql
12+
13+
14+
# revision identifiers, used by Alembic.
15+
revision = 'add_notification_table'
16+
down_revision = '1a31ce608336'
17+
branch_labels = None
18+
depends_on = None
19+
20+
21+
def upgrade():
22+
# Create notification table
23+
op.create_table(
24+
'notification',
25+
sa.Column('title', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
26+
sa.Column('message', sqlmodel.sql.sqltypes.AutoString(length=1000), nullable=True),
27+
sa.Column('notification_type', sa.Enum('INFO', 'SUCCESS', 'WARNING', 'ERROR', name='notificationtype'), nullable=False),
28+
sa.Column('is_read', sa.Boolean(), nullable=False),
29+
sa.Column('action_url', sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True),
30+
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
31+
sa.Column('created_at', sa.DateTime(timezone=True), nullable=True),
32+
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True),
33+
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
34+
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'),
35+
sa.PrimaryKeyConstraint('id')
36+
)
37+
# Create index on user_id for faster queries
38+
op.create_index(op.f('ix_notification_user_id'), 'notification', ['user_id'], unique=False)
39+
# Create index on created_at for ordering
40+
op.create_index(op.f('ix_notification_created_at'), 'notification', ['created_at'], unique=False)
41+
# Create index on is_read for filtering
42+
op.create_index(op.f('ix_notification_is_read'), 'notification', ['is_read'], unique=False)
43+
44+
45+
def downgrade():
46+
# Drop indexes first
47+
op.drop_index(op.f('ix_notification_is_read'), table_name='notification')
48+
op.drop_index(op.f('ix_notification_created_at'), table_name='notification')
49+
op.drop_index(op.f('ix_notification_user_id'), table_name='notification')
50+
# Drop table
51+
op.drop_table('notification')
52+
# Drop enum type
53+
op.execute('DROP TYPE IF EXISTS notificationtype')

backend/app/api/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from fastapi import APIRouter
22

3-
from app.api.routes import items, login, private, users, utils
3+
from app.api.routes import items, login, notifications, private, users, utils
44
from app.core.config import settings
55

66
api_router = APIRouter()
77
api_router.include_router(login.router)
88
api_router.include_router(users.router)
99
api_router.include_router(utils.router)
1010
api_router.include_router(items.router)
11+
api_router.include_router(notifications.router)
1112

1213

1314
if settings.ENVIRONMENT == "local":
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
import uuid
2+
from typing import Any
3+
4+
from fastapi import APIRouter, HTTPException
5+
from sqlmodel import col, func, select
6+
7+
from app.api.deps import CurrentUser, SessionDep
8+
from app.core.websocket import manager
9+
from app.models import (
10+
Message,
11+
Notification,
12+
NotificationCreate,
13+
NotificationPublic,
14+
NotificationsPublic,
15+
NotificationUpdate,
16+
)
17+
18+
router = APIRouter(prefix="/notifications", tags=["notifications"])
19+
20+
21+
@router.get("/", response_model=NotificationsPublic)
22+
def read_notifications(
23+
session: SessionDep,
24+
current_user: CurrentUser,
25+
skip: int = 0,
26+
limit: int = 100,
27+
unread_only: bool = False,
28+
) -> Any:
29+
"""
30+
Retrieve notifications for the current user.
31+
"""
32+
if current_user.is_superuser:
33+
count_statement = select(func.count()).select_from(Notification)
34+
if unread_only:
35+
count_statement = count_statement.where(Notification.is_read == False)
36+
count = session.exec(count_statement).one()
37+
38+
unread_count_statement = select(func.count()).select_from(Notification).where(
39+
Notification.is_read == False
40+
)
41+
unread_count = session.exec(unread_count_statement).one()
42+
43+
statement = (
44+
select(Notification)
45+
.order_by(col(Notification.created_at).desc())
46+
.offset(skip)
47+
.limit(limit)
48+
)
49+
if unread_only:
50+
statement = statement.where(Notification.is_read == False)
51+
notifications = session.exec(statement).all()
52+
else:
53+
count_statement = (
54+
select(func.count())
55+
.select_from(Notification)
56+
.where(Notification.user_id == current_user.id)
57+
)
58+
if unread_only:
59+
count_statement = count_statement.where(Notification.is_read == False)
60+
count = session.exec(count_statement).one()
61+
62+
unread_count_statement = (
63+
select(func.count())
64+
.select_from(Notification)
65+
.where(Notification.user_id == current_user.id)
66+
.where(Notification.is_read == False)
67+
)
68+
unread_count = session.exec(unread_count_statement).one()
69+
70+
statement = (
71+
select(Notification)
72+
.where(Notification.user_id == current_user.id)
73+
.order_by(col(Notification.created_at).desc())
74+
.offset(skip)
75+
.limit(limit)
76+
)
77+
if unread_only:
78+
statement = statement.where(Notification.is_read == False)
79+
notifications = session.exec(statement).all()
80+
81+
notifications_public = [
82+
NotificationPublic.model_validate(notification)
83+
for notification in notifications
84+
]
85+
return NotificationsPublic(
86+
data=notifications_public, count=count, unread_count=unread_count
87+
)
88+
89+
90+
@router.get("/{id}", response_model=NotificationPublic)
91+
def read_notification(
92+
session: SessionDep, current_user: CurrentUser, id: uuid.UUID
93+
) -> Any:
94+
"""
95+
Get notification by ID.
96+
"""
97+
notification = session.get(Notification, id)
98+
if not notification:
99+
raise HTTPException(status_code=404, detail="Notification not found")
100+
if not current_user.is_superuser and (notification.user_id != current_user.id):
101+
raise HTTPException(status_code=403, detail="Not enough permissions")
102+
return notification
103+
104+
105+
@router.post("/", response_model=NotificationPublic)
106+
async def create_notification(
107+
*, session: SessionDep, current_user: CurrentUser, notification_in: NotificationCreate
108+
) -> Any:
109+
"""
110+
Create new notification.
111+
Only superusers can create notifications for other users.
112+
Regular users can only create notifications for themselves.
113+
"""
114+
if not current_user.is_superuser and notification_in.user_id != current_user.id:
115+
raise HTTPException(
116+
status_code=403, detail="Not enough permissions to create notifications for other users"
117+
)
118+
119+
notification = Notification.model_validate(notification_in)
120+
session.add(notification)
121+
session.commit()
122+
session.refresh(notification)
123+
124+
notification_public = NotificationPublic.model_validate(notification)
125+
await manager.send_notification(notification_public, notification.user_id)
126+
127+
return notification
128+
129+
130+
@router.put("/{id}", response_model=NotificationPublic)
131+
def update_notification(
132+
*,
133+
session: SessionDep,
134+
current_user: CurrentUser,
135+
id: uuid.UUID,
136+
notification_in: NotificationUpdate,
137+
) -> Any:
138+
"""
139+
Update a notification.
140+
"""
141+
notification = session.get(Notification, id)
142+
if not notification:
143+
raise HTTPException(status_code=404, detail="Notification not found")
144+
if not current_user.is_superuser and (notification.user_id != current_user.id):
145+
raise HTTPException(status_code=403, detail="Not enough permissions")
146+
147+
update_dict = notification_in.model_dump(exclude_unset=True)
148+
notification.sqlmodel_update(update_dict)
149+
session.add(notification)
150+
session.commit()
151+
session.refresh(notification)
152+
return notification
153+
154+
155+
@router.delete("/{id}")
156+
def delete_notification(
157+
session: SessionDep, current_user: CurrentUser, id: uuid.UUID
158+
) -> Message:
159+
"""
160+
Delete a notification.
161+
"""
162+
notification = session.get(Notification, id)
163+
if not notification:
164+
raise HTTPException(status_code=404, detail="Notification not found")
165+
if not current_user.is_superuser and (notification.user_id != current_user.id):
166+
raise HTTPException(status_code=403, detail="Not enough permissions")
167+
session.delete(notification)
168+
session.commit()
169+
return Message(message="Notification deleted successfully")
170+
171+
172+
@router.post("/{id}/mark-read", response_model=NotificationPublic)
173+
def mark_notification_read(
174+
session: SessionDep, current_user: CurrentUser, id: uuid.UUID
175+
) -> Any:
176+
"""
177+
Mark a notification as read.
178+
"""
179+
notification = session.get(Notification, id)
180+
if not notification:
181+
raise HTTPException(status_code=404, detail="Notification not found")
182+
if not current_user.is_superuser and (notification.user_id != current_user.id):
183+
raise HTTPException(status_code=403, detail="Not enough permissions")
184+
185+
notification.is_read = True
186+
session.add(notification)
187+
session.commit()
188+
session.refresh(notification)
189+
return notification
190+
191+
192+
@router.post("/mark-all-read", response_model=Message)
193+
def mark_all_notifications_read(
194+
session: SessionDep, current_user: CurrentUser
195+
) -> Message:
196+
"""
197+
Mark all notifications as read for the current user.
198+
"""
199+
statement = select(Notification).where(
200+
Notification.user_id == current_user.id,
201+
Notification.is_read == False,
202+
)
203+
notifications = session.exec(statement).all()
204+
205+
for notification in notifications:
206+
notification.is_read = True
207+
session.add(notification)
208+
209+
session.commit()
210+
return Message(message=f"Marked {len(notifications)} notifications as read")
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import logging
2+
from typing import Annotated
3+
from uuid import UUID
4+
5+
import jwt
6+
from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, WebSocketDisconnect, status
7+
from jwt.exceptions import InvalidTokenError
8+
from pydantic import ValidationError
9+
from sqlmodel import Session
10+
11+
from app.core import security
12+
from app.core.config import settings
13+
from app.core.db import engine
14+
from app.core.websocket import manager
15+
from app.models import TokenPayload, User
16+
17+
logger = logging.getLogger(__name__)
18+
19+
router = APIRouter(tags=["websocket"])
20+
21+
22+
def get_user_from_token(token: str) -> User:
23+
try:
24+
payload = jwt.decode(
25+
token, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
26+
)
27+
token_data = TokenPayload(**payload)
28+
except (InvalidTokenError, ValidationError):
29+
raise HTTPException(
30+
status_code=status.HTTP_403_FORBIDDEN,
31+
detail="Could not validate credentials",
32+
)
33+
34+
with Session(engine) as session:
35+
user = session.get(User, token_data.sub)
36+
if not user:
37+
raise HTTPException(status_code=404, detail="User not found")
38+
if not user.is_active:
39+
raise HTTPException(status_code=400, detail="Inactive user")
40+
return user
41+
42+
43+
@router.websocket("/ws/notifications")
44+
async def websocket_notifications(
45+
websocket: WebSocket,
46+
token: Annotated[str | None, Query()] = None,
47+
) -> None:
48+
"""
49+
WebSocket endpoint for real-time notifications.
50+
Clients should connect with a valid JWT token as a query parameter.
51+
"""
52+
if not token:
53+
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
54+
return
55+
56+
try:
57+
user = get_user_from_token(token)
58+
except HTTPException:
59+
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
60+
return
61+
62+
user_id = user.id
63+
await manager.connect(user_id, websocket)
64+
65+
try:
66+
while True:
67+
data = await websocket.receive_text()
68+
if data == "ping":
69+
await websocket.send_text("pong")
70+
except WebSocketDisconnect:
71+
manager.disconnect(user_id, websocket)
72+
logger.info(f"WebSocket disconnected for user: {user_id}")
73+
except Exception as e:
74+
manager.disconnect(user_id, websocket)
75+
logger.error(f"WebSocket error for user {user_id}: {e}")

0 commit comments

Comments
 (0)