Skip to content

Commit b7af106

Browse files
committed
feat: add TOTP two-factor authentication for dashboard login
1 parent 7d72e3a commit b7af106

31 files changed

Lines changed: 2201 additions & 123 deletions

astrbot/core/config/default.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,11 @@
252252
"host": "0.0.0.0",
253253
"port": 6185,
254254
"disable_access_log": True,
255+
"totp": {
256+
"enable": False,
257+
"secret": "",
258+
"recovery_code_hash": "",
259+
},
255260
"ssl": {
256261
"enable": False,
257262
"cert_file": "",
@@ -4180,6 +4185,12 @@
41804185
"type": "bool",
41814186
"hint": "启用后,WebUI 将直接使用 HTTPS 提供服务。",
41824187
},
4188+
"dashboard.totp.enable": {
4189+
"description": "启用 WebUI TOTP 双因素认证",
4190+
"type": "bool",
4191+
"hint": "启用后,登录 WebUI 需要额外输入验证码。",
4192+
"_special": "dashboard_totp_manager",
4193+
},
41834194
"dashboard.ssl.cert_file": {
41844195
"description": "SSL 证书文件路径",
41854196
"type": "string",

astrbot/core/db/po.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,21 @@ class ApiKey(TimestampMixin, SQLModel, table=True):
382382
)
383383

384384

385+
class DashboardTrustedDevice(TimestampMixin, SQLModel, table=True):
386+
"""Trusted dashboard device token used to skip TOTP for a limited time."""
387+
388+
__tablename__: str = "dashboard_trusted_devices"
389+
390+
id: int | None = Field(
391+
default=None,
392+
primary_key=True,
393+
sa_column_kwargs={"autoincrement": True},
394+
)
395+
token_hash: str = Field(max_length=64, nullable=False, unique=True, index=True)
396+
totp_secret_hash: str = Field(max_length=64, nullable=False, index=True)
397+
expires_at: datetime = Field(nullable=False, index=True)
398+
399+
385400
class ChatUIProject(TimestampMixin, SQLModel, table=True):
386401
"""This class represents projects for organizing ChatUI conversations.
387402

astrbot/core/utils/totp.py

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import base64
5+
import datetime
6+
import hashlib
7+
import hmac
8+
import secrets
9+
10+
import pyotp
11+
from sqlmodel import col, delete, select
12+
13+
from astrbot.core.db.po import DashboardTrustedDevice
14+
15+
TOTP_TRUSTED_DEVICE_COOKIE_NAME = "astrbot_totp_trusted_device"
16+
TOTP_TRUSTED_DEVICE_MAX_AGE = 30 * 24 * 60 * 60
17+
RECOVERY_CODE_GROUP_COUNT = 4
18+
RECOVERY_CODE_GROUP_LENGTH = 8
19+
RECOVERY_CODE_LENGTH = RECOVERY_CODE_GROUP_COUNT * RECOVERY_CODE_GROUP_LENGTH
20+
_RECOVERY_CODE_KDF_ITERATIONS = 600_000
21+
_RECOVERY_CODE_KDF_SALT_BYTES = 16
22+
_RECOVERY_CODE_KDF_ALGORITHM = "pbkdf2_sha256"
23+
24+
_last_totp_timecode: dict[str, int] = {}
25+
_totp_replay_lock = asyncio.Lock()
26+
27+
28+
def _get_totp_config(config) -> dict:
29+
totp_config = config.get("dashboard", {}).get("totp", {})
30+
return totp_config if isinstance(totp_config, dict) else {}
31+
32+
33+
def is_totp_enabled(config) -> bool:
34+
"""TOTP is fully configured and operational (enable + secret + recovery hash all present)."""
35+
totp_config = _get_totp_config(config)
36+
if not totp_config.get("enable", False):
37+
return False
38+
secret = totp_config.get("secret", "")
39+
if not isinstance(secret, str) or not secret.strip():
40+
return False
41+
recovery_code_hash = totp_config.get("recovery_code_hash", "")
42+
if not isinstance(recovery_code_hash, str) or not recovery_code_hash.strip():
43+
return False
44+
return True
45+
46+
47+
def _get_verified_totp_timecode(secret: str, code: str) -> int | None:
48+
code = code.strip()
49+
try:
50+
totp = pyotp.TOTP(secret.strip())
51+
now = datetime.datetime.now()
52+
for offset in (-1, 0, 1):
53+
candidate_time = now + datetime.timedelta(seconds=offset * totp.interval)
54+
if hmac.compare_digest(str(totp.at(candidate_time)), code):
55+
return int(totp.timecode(candidate_time))
56+
except Exception:
57+
return None
58+
return None
59+
60+
61+
async def consume_totp_code(secret: str, code: str) -> bool:
62+
global _last_totp_timecode
63+
timecode = _get_verified_totp_timecode(secret, code)
64+
if timecode is None:
65+
return False
66+
secret = secret.strip()
67+
async with _totp_replay_lock:
68+
if _last_totp_timecode.get(secret, -1) >= timecode:
69+
return False
70+
_last_totp_timecode[secret] = timecode
71+
return True
72+
73+
74+
async def consume_configured_totp_code(config, code: str) -> bool:
75+
if not is_totp_enabled(config):
76+
return False
77+
secret = _get_totp_config(config).get("secret", "")
78+
return await consume_totp_code(secret, code)
79+
80+
81+
def _hash_totp_trusted_device_token(config, token: str) -> str:
82+
jwt_secret = config["dashboard"].get("jwt_secret", "")
83+
if not isinstance(jwt_secret, str) or not jwt_secret:
84+
return ""
85+
return hmac.new(
86+
jwt_secret.encode("utf-8"),
87+
token.encode("utf-8"),
88+
hashlib.sha256,
89+
).hexdigest()
90+
91+
92+
def _hash_totp_secret(config) -> str:
93+
secret = _get_totp_config(config).get("secret", "")
94+
if not isinstance(secret, str) or not secret.strip():
95+
return ""
96+
return hashlib.sha256(secret.strip().encode("utf-8")).hexdigest()
97+
98+
99+
async def is_totp_trusted_device_valid(config, db, cookie_token: str) -> bool:
100+
if not cookie_token:
101+
return False
102+
token_hash = _hash_totp_trusted_device_token(config, cookie_token)
103+
totp_secret_hash = _hash_totp_secret(config)
104+
if not token_hash or not totp_secret_hash:
105+
return False
106+
107+
await _cleanup_expired_totp_trusted_devices(db)
108+
async with db.get_db() as session:
109+
result = await session.execute(
110+
select(DashboardTrustedDevice).where(
111+
col(DashboardTrustedDevice.token_hash) == token_hash,
112+
col(DashboardTrustedDevice.totp_secret_hash) == totp_secret_hash,
113+
col(DashboardTrustedDevice.expires_at)
114+
> datetime.datetime.now(datetime.timezone.utc),
115+
)
116+
)
117+
return result.scalar_one_or_none() is not None
118+
119+
120+
async def issue_totp_trusted_device(config, db) -> str | None:
121+
"""Issue a trusted device token, save to DB, and return the raw token for cookie."""
122+
raw_token = secrets.token_urlsafe(48)
123+
token_hash = _hash_totp_trusted_device_token(config, raw_token)
124+
totp_secret_hash = _hash_totp_secret(config)
125+
if not token_hash or not totp_secret_hash:
126+
return None
127+
128+
expires_at = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
129+
seconds=TOTP_TRUSTED_DEVICE_MAX_AGE
130+
)
131+
async with db.get_db() as session:
132+
async with session.begin():
133+
await session.execute(
134+
delete(DashboardTrustedDevice).where(
135+
col(DashboardTrustedDevice.token_hash) == token_hash
136+
)
137+
)
138+
trusted_device = DashboardTrustedDevice.model_validate(
139+
{
140+
"token_hash": token_hash,
141+
"totp_secret_hash": totp_secret_hash,
142+
"expires_at": expires_at,
143+
}
144+
)
145+
session.add(trusted_device)
146+
return raw_token
147+
148+
149+
async def _cleanup_expired_totp_trusted_devices(db) -> None:
150+
async with db.get_db() as session:
151+
async with session.begin():
152+
await session.execute(
153+
delete(DashboardTrustedDevice).where(
154+
col(DashboardTrustedDevice.expires_at)
155+
<= datetime.datetime.now(datetime.timezone.utc)
156+
)
157+
)
158+
159+
160+
async def revoke_user_trusted_devices(db) -> None:
161+
async with db.get_db() as session:
162+
async with session.begin():
163+
await session.execute(delete(DashboardTrustedDevice))
164+
165+
166+
def generate_recovery_code() -> tuple[str, str]:
167+
raw = secrets.token_bytes(20)
168+
recovery_code = base64.b32encode(raw).decode("ascii").rstrip("=")
169+
salt = secrets.token_hex(_RECOVERY_CODE_KDF_SALT_BYTES)
170+
digest = hashlib.pbkdf2_hmac(
171+
"sha256",
172+
recovery_code.encode("utf-8"),
173+
bytes.fromhex(salt),
174+
_RECOVERY_CODE_KDF_ITERATIONS,
175+
).hex()
176+
kdf_hash = f"{_RECOVERY_CODE_KDF_ALGORITHM}${_RECOVERY_CODE_KDF_ITERATIONS}${salt}${digest}"
177+
parts = [
178+
recovery_code[i : i + RECOVERY_CODE_GROUP_LENGTH]
179+
for i in range(0, len(recovery_code), RECOVERY_CODE_GROUP_LENGTH)
180+
]
181+
return "-".join(parts), kdf_hash
182+
183+
184+
def verify_recovery_code(config, code: str) -> bool:
185+
"""Verify a recovery code against configured recovery_code_hash (PBKDF2)."""
186+
cleaned = "".join(char for char in code.upper() if char.isalnum())
187+
if len(cleaned) != RECOVERY_CODE_LENGTH:
188+
return False
189+
totp_config = _get_totp_config(config)
190+
stored_hash = totp_config.get("recovery_code_hash", "")
191+
if not isinstance(stored_hash, str) or not stored_hash:
192+
return False
193+
194+
parts = stored_hash.split("$")
195+
if len(parts) != 4 or parts[0] != _RECOVERY_CODE_KDF_ALGORITHM:
196+
return False
197+
try:
198+
iterations = int(parts[1])
199+
salt = parts[2]
200+
expected_digest = parts[3]
201+
except (ValueError, IndexError):
202+
return False
203+
204+
candidate = hashlib.pbkdf2_hmac(
205+
"sha256",
206+
cleaned.encode("utf-8"),
207+
bytes.fromhex(salt),
208+
iterations,
209+
).hex()
210+
return hmac.compare_digest(candidate, expected_digest)

0 commit comments

Comments
 (0)