|
| 1 | +"""对外直连召回 SSE(LINK-40)的会话鉴权与并发治理。 |
| 2 | +
|
| 3 | +外部用户态 Recall 入口归属 Java:Java 用 Sa-Token 鉴权并校验 dataset 归属后,签发 |
| 4 | +**短期 session token**;前端凭该 token 直连 Python ``POST /api/v1/recall/stream``。 |
| 5 | +本模块提供: |
| 6 | +
|
| 7 | +- ``SessionAuthContext``:从可信 claims 解析出的请求上下文; |
| 8 | +- ``verify_session_token``:FastAPI 依赖,用**独立密钥**验签 + 校验 iss/aud/scope/exp; |
| 9 | +- ``acquire_stream_slot`` / ``release_stream_slot``:按 ``user_id`` 的并发流计数。 |
| 10 | +
|
| 11 | +与内部端点(``internal_auth.py``)的关键差异:面向浏览器、密钥/受众独立;token |
| 12 | +**短期可复用**——只校验 ``exp``,不做一次性消费 / 防重放 / 撤销,资源滥用由并发上限封顶。 |
| 13 | +设计依据见 .specs/recall-direct-sse/{brief,technical_design}.md。 |
| 14 | +""" |
| 15 | + |
| 16 | +from __future__ import annotations |
| 17 | + |
| 18 | +from dataclasses import dataclass |
| 19 | + |
| 20 | +import jwt |
| 21 | +from fastapi import Request |
| 22 | +from loguru import logger |
| 23 | + |
| 24 | +from src.api.internal_auth import ( |
| 25 | + CODE_SESSION_UNAUTHORIZED, |
| 26 | + RecallApiError, |
| 27 | + _request_id, |
| 28 | +) |
| 29 | +from src.cache.redis_client import redis_client |
| 30 | +from src.config import settings |
| 31 | + |
# Key prefix of the concurrent-stream counter; one bucket per user_id, shared
# across workers / instances.
_CONCURRENT_KEY_PREFIX = "recall:concurrent:"
| 34 | + |
| 35 | + |
@dataclass(frozen=True)
class SessionAuthContext:
    """Trusted request context parsed from session-token claims.

    Attributes:
        user_id: Authoritative user identity taken from the ``sub`` claim
            (a positive integer).
        dataset_ids: Dataset scope granted by the claims; ``None`` or an
            empty list means the whole library is authorized.
        request_id: Identifier of the current request; taken from
            ``X-Request-Id`` and generated when the header is absent.
    """

    user_id: int
    dataset_ids: list[int] | None
    request_id: str
| 49 | + |
| 50 | + |
| 51 | +def _extract_session_token(request: Request) -> str: |
| 52 | + """从 ``Authorization: Bearer`` 提取 session token。 |
| 53 | +
|
| 54 | + 缺失或格式不符抛 ``RECALL_SESSION_UNAUTHORIZED``(区别于内部端点的错误码)。 |
| 55 | + """ |
| 56 | + header = request.headers.get("Authorization") |
| 57 | + if not header or not header.startswith("Bearer "): |
| 58 | + raise RecallApiError(401, CODE_SESSION_UNAUTHORIZED, "missing session credential") |
| 59 | + token = header[len("Bearer ") :].strip() |
| 60 | + if not token: |
| 61 | + raise RecallApiError(401, CODE_SESSION_UNAUTHORIZED, "missing session credential") |
| 62 | + return token |
| 63 | + |
| 64 | + |
def _context_from_session_claims(claims: dict, request_id: str) -> SessionAuthContext:
    """Build the request context from trusted claims.

    Identity comes from the claims only; nothing self-reported by the client
    is trusted.

    Args:
        claims: Decoded token payload.
        request_id: Identifier of the current request, stored on the context.

    Returns:
        A ``SessionAuthContext`` carrying the user id, dataset scope and
        request id.

    Raises:
        RecallApiError: 401 with ``CODE_SESSION_UNAUTHORIZED`` when ``sub`` is
            not a positive integer, or ``dataset_ids`` is present but not a
            list.
    """
    raw_sub = claims.get("sub")
    # ``bool`` is an ``int`` subclass and ``int()`` truncates floats, so without
    # this guard ``sub: true`` and ``sub: 1.9`` would both resolve to user 1.
    if isinstance(raw_sub, bool) or (
        isinstance(raw_sub, float) and not raw_sub.is_integer()
    ):
        raise RecallApiError(401, CODE_SESSION_UNAUTHORIZED, "invalid subject in credential")
    try:
        user_id = int(raw_sub)
    except (TypeError, ValueError):
        # The conversion error adds nothing for the caller; drop the chain.
        raise RecallApiError(
            401, CODE_SESSION_UNAUTHORIZED, "invalid subject in credential"
        ) from None
    if user_id <= 0:
        raise RecallApiError(401, CODE_SESSION_UNAUTHORIZED, "invalid subject in credential")

    dataset_ids = claims.get("dataset_ids")
    if dataset_ids is not None and not isinstance(dataset_ids, list):
        raise RecallApiError(
            401, CODE_SESSION_UNAUTHORIZED, "invalid dataset_ids in credential"
        )

    return SessionAuthContext(
        user_id=user_id, dataset_ids=dataset_ids, request_id=request_id
    )
| 84 | + |
| 85 | + |
async def verify_session_token(request: Request) -> SessionAuthContext:
    """FastAPI dependency: validate the Java-issued session token.

    Validation chain (any failure raises
    ``RecallApiError(401, CODE_SESSION_UNAUTHORIZED)``):
    Bearer token -> HS256 signature check with the dedicated session secret,
    plus iss / aud / exp (handled by PyJWT) -> scope (checked manually)
    -> ``sub`` -> ``user_id``. The token is short-lived but reusable: there is
    no one-time consumption step, so repeated connections within the validity
    window are all admitted (a reconnect may reuse an unexpired token).

    With ``RECALL_SESSION_AUTH_ENABLED=False`` (local integration only) the
    signature and registered-claim checks are skipped, but the claims are still
    parsed to obtain the identity. A token that cannot be parsed at all is
    rejected with 401 in this mode too, rather than surfacing as an unhandled
    error.

    Args:
        request: Incoming request carrying the ``Authorization`` header.

    Returns:
        The ``SessionAuthContext`` built from the token claims.

    Raises:
        RecallApiError: 401 on any validation failure.
    """
    request_id = _request_id(request)
    token = _extract_session_token(request)

    if not settings.RECALL_SESSION_AUTH_ENABLED:
        # Local integration: no signature check, only parse claims for identity.
        logger.warning(
            "[recall-session] auth disabled; skipping JWT verification request_id={}",
            request_id,
        )
        try:
            claims = jwt.decode(token, options={"verify_signature": False})
        except jwt.PyJWTError as exc:
            # A malformed token still fails to decode without verification.
            logger.info("[recall-session] JWT rejected request_id={}: {}", request_id, exc)
            raise RecallApiError(
                401, CODE_SESSION_UNAUTHORIZED, "invalid or expired credential"
            ) from exc
        return _context_from_session_claims(claims, request_id)

    try:
        claims = jwt.decode(
            token,
            settings.RECALL_SESSION_JWT_SECRET,
            algorithms=["HS256"],
            audience=settings.RECALL_SESSION_JWT_AUDIENCE,
            issuer=settings.RECALL_SESSION_JWT_ISSUER,
            options={"require": ["exp"]},
        )
    except jwt.PyJWTError as exc:
        logger.info("[recall-session] JWT rejected request_id={}: {}", request_id, exc)
        raise RecallApiError(
            401, CODE_SESSION_UNAUTHORIZED, "invalid or expired credential"
        ) from exc

    # PyJWT has no built-in scope validation; compare against the configured value.
    if claims.get("scope") != settings.RECALL_SESSION_JWT_SCOPE:
        raise RecallApiError(401, CODE_SESSION_UNAUTHORIZED, "credential scope not permitted")

    return _context_from_session_claims(claims, request_id)
| 126 | + |
| 127 | + |
def _concurrent_key(user_id: int) -> str:
    """Return the counter key for ``user_id``: the shared prefix followed by the id."""
    return "{}{}".format(_CONCURRENT_KEY_PREFIX, user_id)
| 130 | + |
| 131 | + |
async def acquire_stream_slot(user_id: int) -> bool:
    """Try to take one concurrent-stream slot for ``user_id``.

    INCR reserves the slot first and the limit is checked afterwards, so
    multiple workers cannot oversell; when the limit is exceeded the
    reservation is rolled back with DECR.

    The safety TTL (2x the stream timeout, rounded up to whole seconds) is
    refreshed only when a slot is actually granted. Refreshing it on rejected
    attempts would let a client that keeps retrying against leaked slots (for
    example after a process died without releasing) extend the leak
    indefinitely and never recover.

    If Redis is unavailable the call fails open (admit and warn): Redis only
    provides resource protection here, and briefly losing the limit is
    preferable to blocking every recall.

    Args:
        user_id: User whose counter is incremented.

    Returns:
        ``True`` if the stream may proceed; ``False`` if the per-user limit is
        reached (the caller should answer 429).
    """
    key = _concurrent_key(user_id)
    # Ceil division: flooring before doubling could yield a TTL below 2x timeout.
    safety_ttl = max(1, (settings.RECALL_STREAM_TIMEOUT_MS * 2 + 999) // 1000)
    try:
        count = await redis_client.incr(key)
    except Exception:  # noqa: BLE001 - Redis failure must not block recall (fail-open)
        logger.warning(
            "[recall-session] redis unavailable on acquire, fail-open user_id={}", user_id
        )
        return True

    if count > settings.RECALL_SESSION_MAX_CONCURRENT:
        # Over the limit: roll back the reservation and reject. The TTL is left
        # untouched on purpose so rejected retries cannot keep a leak alive.
        try:
            await redis_client.decr(key)
        except Exception:  # noqa: BLE001 - a failed rollback is reclaimed by the TTL
            logger.warning("[recall-session] redis decr failed on rollback user_id={}", user_id)
        return False

    try:
        await redis_client.expire(key, safety_ttl)
    except Exception:  # noqa: BLE001 - slot is already held; keep failing open
        logger.warning(
            "[recall-session] redis unavailable on acquire, fail-open user_id={}", user_id
        )
    return True
| 160 | + |
| 161 | + |
async def release_stream_slot(user_id: int) -> None:
    """Release one concurrent-stream slot for ``user_id``.

    Intended to be called from the ``finally`` of the stream (normal end or
    disconnect).

    If DECR drives the counter below zero (a duplicate release on an error
    path), the extra decrement is undone with a single INCR. Unlike
    overwriting the key with ``SET key 0``, the compensating INCR is atomic
    with respect to a concurrent INCR issued by another acquire, so it cannot
    erase a slot that was just taken, and it leaves the key's TTL in place.

    Redis failures are logged and otherwise ignored; the key's safety TTL
    reclaims the slot.

    Args:
        user_id: User whose counter is decremented.
    """
    key = _concurrent_key(user_id)
    try:
        remaining = await redis_client.decr(key)
        if remaining < 0:
            # Undo exactly this caller's surplus decrement.
            await redis_client.incr(key)
    except Exception:  # noqa: BLE001 - a failed release is reclaimed by the TTL
        logger.warning("[recall-session] redis unavailable on release user_id={}", user_id)
0 commit comments