Skip to content

Commit c5b23d1

Browse files
authored
fix: 修复Pyright静态类型检查报错 (#5437)
* refactor: 修正 Sqlite 查询、下载回调、接口重构与类型调整 * feat: 为 OneBotClient 增加 CallAction 协议与异步调用支持
1 parent 69f2fb2 commit c5b23d1

File tree

6 files changed

+44
-22
lines changed

6 files changed

+44
-22
lines changed

astrbot/core/db/sqlite.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import Awaitable, Callable
55
from datetime import datetime, timedelta, timezone
66

7-
from sqlalchemy import CursorResult
7+
from sqlalchemy import CursorResult, Row
88
from sqlalchemy.ext.asyncio import AsyncSession
99
from sqlmodel import col, delete, desc, func, or_, select, text, update
1010

@@ -626,7 +626,7 @@ async def get_active_api_key_by_hash(self, key_hash: str) -> ApiKey | None:
626626
query = select(ApiKey).where(
627627
ApiKey.key_hash == key_hash,
628628
col(ApiKey.revoked_at).is_(None),
629-
or_(col(ApiKey.expires_at).is_(None), ApiKey.expires_at > now),
629+
or_(col(ApiKey.expires_at).is_(None), col(ApiKey.expires_at) > now),
630630
)
631631
result = await session.execute(query)
632632
return result.scalar_one_or_none()
@@ -638,7 +638,7 @@ async def touch_api_key(self, key_id: str) -> None:
638638
async with session.begin():
639639
await session.execute(
640640
update(ApiKey)
641-
.where(ApiKey.key_id == key_id)
641+
.where(col(ApiKey.key_id) == key_id)
642642
.values(last_used_at=datetime.now(timezone.utc)),
643643
)
644644

@@ -649,7 +649,7 @@ async def revoke_api_key(self, key_id: str) -> bool:
649649
async with session.begin():
650650
query = (
651651
update(ApiKey)
652-
.where(ApiKey.key_id == key_id)
652+
.where(col(ApiKey.key_id) == key_id)
653653
.values(revoked_at=datetime.now(timezone.utc))
654654
)
655655
result = T.cast(CursorResult, await session.execute(query))
@@ -663,7 +663,7 @@ async def delete_api_key(self, key_id: str) -> bool:
663663
result = T.cast(
664664
CursorResult,
665665
await session.execute(
666-
delete(ApiKey).where(ApiKey.key_id == key_id)
666+
delete(ApiKey).where(col(ApiKey.key_id) == key_id)
667667
),
668668
)
669669
return result.rowcount > 0
@@ -1457,7 +1457,7 @@ def _build_platform_sessions_query(
14571457
return query
14581458

14591459
@staticmethod
1460-
def _rows_to_session_dicts(rows: list[tuple]) -> list[dict]:
1460+
def _rows_to_session_dicts(rows: T.Sequence[Row[tuple]]) -> list[dict]:
14611461
sessions_with_projects = []
14621462
for row in rows:
14631463
platform_session = row[0]

astrbot/core/platform/sources/telegram/tg_adapter.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import re
44
import sys
55
import uuid
6+
from typing import cast
67

78
from apscheduler.schedulers.asyncio import AsyncIOScheduler
89
from telegram import BotCommand, Update
@@ -27,7 +28,7 @@
2728
from astrbot.core.star.star import star_map
2829
from astrbot.core.star.star_handler import star_handlers_registry
2930
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
30-
from astrbot.core.utils.io import download_image_by_url
31+
from astrbot.core.utils.io import download_file
3132
from astrbot.core.utils.media_utils import convert_audio_to_wav
3233

3334
from .tg_event import TelegramPlatformEvent
@@ -380,10 +381,10 @@ async def convert_message(
380381
elif update.message.voice:
381382
file = await update.message.voice.get_file()
382383

383-
file_basename = os.path.basename(file.file_path)
384+
file_basename = os.path.basename(cast(str, file.file_path))
384385
temp_dir = get_astrbot_temp_path()
385386
temp_path = os.path.join(temp_dir, file_basename)
386-
temp_path = await download_image_by_url(file.file_path, path=temp_path)
387+
await download_file(cast(str, file.file_path), path=temp_path)
387388
path_wav = os.path.join(
388389
temp_dir,
389390
f"{file_basename}.wav",

astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import sys
44
import time
55
import uuid
6-
from collections.abc import Awaitable, Callable
6+
from collections.abc import Callable, Coroutine
77
from typing import Any, cast
88

99
import quart
@@ -65,7 +65,9 @@ def __init__(
6565

6666
self.event_queue = event_queue
6767

68-
self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
68+
self.callback: (
69+
Callable[[BaseMessage], Coroutine[Any, Any, str | None]] | None
70+
) = None
6971
self.shutdown_event = asyncio.Event()
7072

7173
self._wx_msg_time_out = 4.0 # 微信服务器要求 5 秒内回复

astrbot/core/star/star_handler.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,22 @@ def get_handlers_by_event_type(
105105
plugins_name: list[str] | None = None,
106106
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
107107

108+
@overload
109+
def get_handlers_by_event_type(
110+
self,
111+
event_type: Literal[EventType.OnPluginLoadedEvent],
112+
only_activated=True,
113+
plugins_name: list[str] | None = None,
114+
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
115+
116+
@overload
117+
def get_handlers_by_event_type(
118+
self,
119+
event_type: Literal[EventType.OnPluginUnloadedEvent],
120+
only_activated=True,
121+
plugins_name: list[str] | None = None,
122+
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
123+
108124
@overload
109125
def get_handlers_by_event_type(
110126
self,

astrbot/core/utils/quoted_message/chain_parser.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from astrbot.core.platform.astr_message_event import AstrMessageEvent
2020
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
2121

22-
from .image_refs import looks_like_image_file_name, normalize_file_like_url
22+
from .image_refs import looks_like_image_file_name
2323
from .settings import SETTINGS, QuotedMessageParserSettings
2424

2525
_FORWARD_PLACEHOLDER_PATTERN = re.compile(
@@ -296,23 +296,19 @@ def _parse_onebot_segments(
296296
or "file"
297297
)
298298
text_parts.append(f"[File:{file_name}]")
299-
candidate_url = seg_data.get("url")
299+
candidate_url = seg_data.get("url", "")
300300
if (
301301
isinstance(candidate_url, str)
302302
and candidate_url.strip()
303-
and looks_like_image_file_name(normalize_file_like_url(candidate_url))
303+
and looks_like_image_file_name(candidate_url)
304304
):
305305
image_refs.append(candidate_url.strip())
306306
candidate_file = seg_data.get("file")
307307
if (
308308
isinstance(candidate_file, str)
309309
and candidate_file.strip()
310310
and looks_like_image_file_name(
311-
normalize_file_like_url(
312-
seg_data.get("name")
313-
or seg_data.get("file_name")
314-
or candidate_file
315-
)
311+
seg_data.get("name") or seg_data.get("file_name") or candidate_file
316312
)
317313
):
318314
image_refs.append(candidate_file.strip())
@@ -368,7 +364,9 @@ def _extract_text_forward_ids_and_images_from_forward_nodes(
368364
if not isinstance(node, dict):
369365
continue
370366

371-
sender = node.get("sender") if isinstance(node.get("sender"), dict) else {}
367+
sender = node.get("sender")
368+
if not isinstance(sender, dict):
369+
sender = {}
372370
sender_name = (
373371
sender.get("nickname")
374372
or sender.get("card")

astrbot/core/utils/quoted_message/onebot_client.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

3-
from typing import Any
3+
from collections.abc import Awaitable
4+
from typing import Any, Protocol
45

56
from astrbot import logger
67
from astrbot.core.platform.astr_message_event import AstrMessageEvent
@@ -17,6 +18,10 @@ def _unwrap_action_response(ret: dict[str, Any] | None) -> dict[str, Any]:
1718
return ret
1819

1920

21+
class CallAction(Protocol):
22+
def __call__(self, action: str, **params: Any) -> Awaitable[Any] | Any: ...
23+
24+
2025
class OneBotClient:
2126
def __init__(
2227
self,
@@ -27,7 +32,7 @@ def __init__(
2732
self._settings = settings
2833

2934
@staticmethod
30-
def _resolve_call_action(event: AstrMessageEvent):
35+
def _resolve_call_action(event: AstrMessageEvent) -> CallAction | None:
3136
bot = getattr(event, "bot", None)
3237
api = getattr(bot, "api", None)
3338
call_action = getattr(api, "call_action", None)

0 commit comments

Comments
 (0)