Skip to content

Commit a404436

Browse files
feat: astrbot http api (#5280)
* feat: astrbot http api * Potential fix for code scanning alert no. 34: Use of a broken or weak cryptographic hashing algorithm on sensitive data Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> * fix: improve error handling for missing attachment path in file upload * feat: implement paginated retrieval of platform sessions for creators * feat: refactor attachment directory handling in ChatRoute * feat: update API endpoint paths for file and message handling * feat: add documentation link to API key management section in settings * feat: update API key scopes and related configurations in API routes and tests * feat: enhance API key expiration options and add warning for permanent keys * feat: add UTC normalization and serialization for API key timestamps * feat: implement chat session management and validation for usernames * feat: ignore session_id type chunks in message processing --------- Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
1 parent bcb12a0 commit a404436

File tree

14 files changed

+2315
-59
lines changed

14 files changed

+2315
-59
lines changed

astrbot/core/db/__init__.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
99

1010
from astrbot.core.db.po import (
11+
ApiKey,
1112
Attachment,
1213
ChatUIProject,
1314
CommandConfig,
@@ -248,6 +249,55 @@ async def delete_attachments(self, attachment_ids: list[str]) -> int:
248249
"""
249250
...
250251

252+
@abc.abstractmethod
253+
async def create_api_key(
254+
self,
255+
name: str,
256+
key_hash: str,
257+
key_prefix: str,
258+
scopes: list[str] | None,
259+
created_by: str,
260+
expires_at: datetime.datetime | None = None,
261+
) -> ApiKey:
262+
"""Create a new API key record."""
263+
...
264+
265+
@abc.abstractmethod
266+
async def list_api_keys(self) -> list[ApiKey]:
267+
"""List all API keys."""
268+
...
269+
270+
@abc.abstractmethod
271+
async def get_api_key_by_id(self, key_id: str) -> ApiKey | None:
272+
"""Get an API key by key_id."""
273+
...
274+
275+
@abc.abstractmethod
276+
async def get_active_api_key_by_hash(self, key_hash: str) -> ApiKey | None:
277+
"""Get an active API key by hash (not revoked, not expired)."""
278+
...
279+
280+
@abc.abstractmethod
281+
async def touch_api_key(self, key_id: str) -> None:
282+
"""Update last_used_at of an API key."""
283+
...
284+
285+
@abc.abstractmethod
286+
async def revoke_api_key(self, key_id: str) -> bool:
287+
"""Revoke an API key.
288+
289+
Returns True when the key exists and is updated.
290+
"""
291+
...
292+
293+
@abc.abstractmethod
294+
async def delete_api_key(self, key_id: str) -> bool:
295+
"""Delete an API key.
296+
297+
Returns True when the key exists and is deleted.
298+
"""
299+
...
300+
251301
@abc.abstractmethod
252302
async def insert_persona(
253303
self,
@@ -608,6 +658,22 @@ async def get_platform_sessions_by_creator(
608658
"""
609659
...
610660

661+
@abc.abstractmethod
662+
async def get_platform_sessions_by_creator_paginated(
663+
self,
664+
creator: str,
665+
platform_id: str | None = None,
666+
page: int = 1,
667+
page_size: int = 20,
668+
exclude_project_sessions: bool = False,
669+
) -> tuple[list[dict], int]:
670+
"""Get paginated platform sessions and total count for a creator.
671+
672+
Returns:
673+
tuple[list[dict], int]: (sessions_with_project_info, total_count)
674+
"""
675+
...
676+
611677
@abc.abstractmethod
612678
async def update_platform_session(
613679
self,

astrbot/core/db/po.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,43 @@ class Attachment(TimestampMixin, SQLModel, table=True):
288288
)
289289

290290

291+
class ApiKey(TimestampMixin, SQLModel, table=True):
292+
"""API keys used by external developers to access Open APIs."""
293+
294+
__tablename__: str = "api_keys"
295+
296+
inner_id: int | None = Field(
297+
primary_key=True,
298+
sa_column_kwargs={"autoincrement": True},
299+
default=None,
300+
)
301+
key_id: str = Field(
302+
max_length=36,
303+
nullable=False,
304+
unique=True,
305+
default_factory=lambda: str(uuid.uuid4()),
306+
)
307+
name: str = Field(max_length=255, nullable=False)
308+
key_hash: str = Field(max_length=128, nullable=False, unique=True)
309+
key_prefix: str = Field(max_length=24, nullable=False)
310+
scopes: list | None = Field(default=None, sa_type=JSON)
311+
created_by: str = Field(max_length=255, nullable=False)
312+
last_used_at: datetime | None = Field(default=None)
313+
expires_at: datetime | None = Field(default=None)
314+
revoked_at: datetime | None = Field(default=None)
315+
316+
__table_args__ = (
317+
UniqueConstraint(
318+
"key_id",
319+
name="uix_api_key_id",
320+
),
321+
UniqueConstraint(
322+
"key_hash",
323+
name="uix_api_key_hash",
324+
),
325+
)
326+
327+
291328
class ChatUIProject(TimestampMixin, SQLModel, table=True):
292329
"""This class represents projects for organizing ChatUI conversations.
293330

astrbot/core/db/sqlite.py

Lines changed: 180 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from astrbot.core.db import BaseDatabase
1212
from astrbot.core.db.po import (
13+
ApiKey,
1314
Attachment,
1415
ChatUIProject,
1516
CommandConfig,
@@ -573,6 +574,100 @@ async def delete_attachments(self, attachment_ids: list[str]) -> int:
573574
result = T.cast(CursorResult, await session.execute(query))
574575
return result.rowcount
575576

577+
async def create_api_key(
578+
self,
579+
name: str,
580+
key_hash: str,
581+
key_prefix: str,
582+
scopes: list[str] | None,
583+
created_by: str,
584+
expires_at: datetime | None = None,
585+
) -> ApiKey:
586+
"""Create a new API key record."""
587+
async with self.get_db() as session:
588+
session: AsyncSession
589+
async with session.begin():
590+
api_key = ApiKey(
591+
name=name,
592+
key_hash=key_hash,
593+
key_prefix=key_prefix,
594+
scopes=scopes,
595+
created_by=created_by,
596+
expires_at=expires_at,
597+
)
598+
session.add(api_key)
599+
await session.flush()
600+
await session.refresh(api_key)
601+
return api_key
602+
603+
async def list_api_keys(self) -> list[ApiKey]:
604+
"""List all API keys."""
605+
async with self.get_db() as session:
606+
session: AsyncSession
607+
result = await session.execute(
608+
select(ApiKey).order_by(desc(ApiKey.created_at))
609+
)
610+
return list(result.scalars().all())
611+
612+
async def get_api_key_by_id(self, key_id: str) -> ApiKey | None:
613+
"""Get an API key by key_id."""
614+
async with self.get_db() as session:
615+
session: AsyncSession
616+
result = await session.execute(
617+
select(ApiKey).where(ApiKey.key_id == key_id)
618+
)
619+
return result.scalar_one_or_none()
620+
621+
async def get_active_api_key_by_hash(self, key_hash: str) -> ApiKey | None:
622+
"""Get an active API key by hash (not revoked, not expired)."""
623+
async with self.get_db() as session:
624+
session: AsyncSession
625+
now = datetime.now(timezone.utc)
626+
query = select(ApiKey).where(
627+
ApiKey.key_hash == key_hash,
628+
col(ApiKey.revoked_at).is_(None),
629+
or_(col(ApiKey.expires_at).is_(None), ApiKey.expires_at > now),
630+
)
631+
result = await session.execute(query)
632+
return result.scalar_one_or_none()
633+
634+
async def touch_api_key(self, key_id: str) -> None:
635+
"""Update last_used_at of an API key."""
636+
async with self.get_db() as session:
637+
session: AsyncSession
638+
async with session.begin():
639+
await session.execute(
640+
update(ApiKey)
641+
.where(ApiKey.key_id == key_id)
642+
.values(last_used_at=datetime.now(timezone.utc)),
643+
)
644+
645+
async def revoke_api_key(self, key_id: str) -> bool:
646+
"""Revoke an API key."""
647+
async with self.get_db() as session:
648+
session: AsyncSession
649+
async with session.begin():
650+
query = (
651+
update(ApiKey)
652+
.where(ApiKey.key_id == key_id)
653+
.values(revoked_at=datetime.now(timezone.utc))
654+
)
655+
result = T.cast(CursorResult, await session.execute(query))
656+
return result.rowcount > 0
657+
658+
async def delete_api_key(self, key_id: str) -> bool:
659+
"""Delete an API key."""
660+
async with self.get_db() as session:
661+
session: AsyncSession
662+
async with session.begin():
663+
result = T.cast(
664+
CursorResult,
665+
await session.execute(
666+
delete(ApiKey).where(ApiKey.key_id == key_id)
667+
),
668+
)
669+
return result.rowcount > 0
670+
576671
async def insert_persona(
577672
self,
578673
persona_id,
@@ -1317,58 +1412,102 @@ async def get_platform_sessions_by_creator(
13171412
13181413
Returns a list of dicts containing session info and project info (if session belongs to a project).
13191414
"""
1415+
(
1416+
sessions_with_projects,
1417+
_,
1418+
) = await self.get_platform_sessions_by_creator_paginated(
1419+
creator=creator,
1420+
platform_id=platform_id,
1421+
page=page,
1422+
page_size=page_size,
1423+
exclude_project_sessions=False,
1424+
)
1425+
return sessions_with_projects
1426+
1427+
@staticmethod
1428+
def _build_platform_sessions_query(
1429+
creator: str,
1430+
platform_id: str | None = None,
1431+
exclude_project_sessions: bool = False,
1432+
):
1433+
query = (
1434+
select(
1435+
PlatformSession,
1436+
col(ChatUIProject.project_id),
1437+
col(ChatUIProject.title).label("project_title"),
1438+
col(ChatUIProject.emoji).label("project_emoji"),
1439+
)
1440+
.outerjoin(
1441+
SessionProjectRelation,
1442+
col(PlatformSession.session_id)
1443+
== col(SessionProjectRelation.session_id),
1444+
)
1445+
.outerjoin(
1446+
ChatUIProject,
1447+
col(SessionProjectRelation.project_id) == col(ChatUIProject.project_id),
1448+
)
1449+
.where(col(PlatformSession.creator) == creator)
1450+
)
1451+
1452+
if platform_id:
1453+
query = query.where(PlatformSession.platform_id == platform_id)
1454+
if exclude_project_sessions:
1455+
query = query.where(col(ChatUIProject.project_id).is_(None))
1456+
1457+
return query
1458+
1459+
@staticmethod
1460+
def _rows_to_session_dicts(rows: list[tuple]) -> list[dict]:
1461+
sessions_with_projects = []
1462+
for row in rows:
1463+
platform_session = row[0]
1464+
project_id = row[1]
1465+
project_title = row[2]
1466+
project_emoji = row[3]
1467+
1468+
session_dict = {
1469+
"session": platform_session,
1470+
"project_id": project_id,
1471+
"project_title": project_title,
1472+
"project_emoji": project_emoji,
1473+
}
1474+
sessions_with_projects.append(session_dict)
1475+
1476+
return sessions_with_projects
1477+
1478+
async def get_platform_sessions_by_creator_paginated(
1479+
self,
1480+
creator: str,
1481+
platform_id: str | None = None,
1482+
page: int = 1,
1483+
page_size: int = 20,
1484+
exclude_project_sessions: bool = False,
1485+
) -> tuple[list[dict], int]:
1486+
"""Get paginated Platform sessions for a creator with total count."""
13201487
async with self.get_db() as session:
13211488
session: AsyncSession
13221489
offset = (page - 1) * page_size
13231490

1324-
# LEFT JOIN with SessionProjectRelation and ChatUIProject to get project info
1325-
query = (
1326-
select(
1327-
PlatformSession,
1328-
col(ChatUIProject.project_id),
1329-
col(ChatUIProject.title).label("project_title"),
1330-
col(ChatUIProject.emoji).label("project_emoji"),
1331-
)
1332-
.outerjoin(
1333-
SessionProjectRelation,
1334-
col(PlatformSession.session_id)
1335-
== col(SessionProjectRelation.session_id),
1336-
)
1337-
.outerjoin(
1338-
ChatUIProject,
1339-
col(SessionProjectRelation.project_id)
1340-
== col(ChatUIProject.project_id),
1341-
)
1342-
.where(col(PlatformSession.creator) == creator)
1491+
base_query = self._build_platform_sessions_query(
1492+
creator=creator,
1493+
platform_id=platform_id,
1494+
exclude_project_sessions=exclude_project_sessions,
13431495
)
13441496

1345-
if platform_id:
1346-
query = query.where(PlatformSession.platform_id == platform_id)
1497+
total_result = await session.execute(
1498+
select(func.count()).select_from(base_query.subquery())
1499+
)
1500+
total = int(total_result.scalar_one() or 0)
13471501

1348-
query = (
1349-
query.order_by(desc(PlatformSession.updated_at))
1502+
result_query = (
1503+
base_query.order_by(desc(PlatformSession.updated_at))
13501504
.offset(offset)
13511505
.limit(page_size)
13521506
)
1353-
result = await session.execute(query)
1354-
1355-
# Convert to list of dicts with session and project info
1356-
sessions_with_projects = []
1357-
for row in result.all():
1358-
platform_session = row[0]
1359-
project_id = row[1]
1360-
project_title = row[2]
1361-
project_emoji = row[3]
1362-
1363-
session_dict = {
1364-
"session": platform_session,
1365-
"project_id": project_id,
1366-
"project_title": project_title,
1367-
"project_emoji": project_emoji,
1368-
}
1369-
sessions_with_projects.append(session_dict)
1507+
result = await session.execute(result_query)
13701508

1371-
return sessions_with_projects
1509+
sessions_with_projects = self._rows_to_session_dicts(result.all())
1510+
return sessions_with_projects, total
13721511

13731512
async def update_platform_session(
13741513
self,

0 commit comments

Comments
 (0)