Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import botpy
import botpy.message
from botpy import Client
from botpy.gateway import BotWebSocket

from astrbot import logger
from astrbot.api.event import MessageChain
Expand All @@ -37,11 +38,37 @@
logging.root.removeHandler(handler)


class ManagedBotWebSocket(BotWebSocket):
def __init__(self, session, connection: Any, client: botClient):
super().__init__(session, connection)
self._client = client

async def on_closed(self, close_status_code, close_msg):
if self._client.is_shutting_down:
logger.debug("[QQOfficial] Ignore websocket reconnect during shutdown.")
return
await super().on_closed(close_status_code, close_msg)

async def close(self) -> None:
self._can_reconnect = False
if self._conn is not None and not self._conn.closed:
await self._conn.close()


# QQ 机器人官方框架
class botClient(Client):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._shutting_down = False
self._active_websockets: set[ManagedBotWebSocket] = set()

def set_platform(self, platform: QQOfficialPlatformAdapter) -> None:
self.platform = platform

@property
def is_shutting_down(self) -> bool:
return self._shutting_down or self.is_closed()

# 收到群消息
async def on_group_at_message_create(
self, message: botpy.message.GroupMessage
Expand Down Expand Up @@ -100,6 +127,30 @@ def _commit(self, abm: AstrBotMessage) -> None:
),
)

async def bot_connect(self, session) -> None:
logger.info("[QQOfficial] Websocket session starting.")

websocket = ManagedBotWebSocket(session, self._connection, self)
self._active_websockets.add(websocket)
try:
await websocket.ws_connect()
except Exception as e:
if not self.is_shutting_down:
await websocket.on_error(e)
finally:
self._active_websockets.discard(websocket)

async def shutdown(self) -> None:
if self.is_shutting_down:
return

self._shutting_down = True
await asyncio.gather(
*(websocket.close() for websocket in list(self._active_websockets)),
return_exceptions=True,
)
await self.close()


@register_platform_adapter("qq_official", "QQ 机器人官方 API 适配器")
class QQOfficialPlatformAdapter(Platform):
Expand Down Expand Up @@ -542,5 +593,5 @@ def get_client(self) -> botClient:
return self.client

async def terminate(self) -> None:
await self.client.close()
logger.info("QQ 官方机器人接口 适配器已被优雅地关闭")
await self.client.shutdown()
logger.info("QQ 官方机器人接口 适配器已被关闭")
Original file line number Diff line number Diff line change
Expand Up @@ -193,4 +193,4 @@ async def terminate(self) -> None:
f"Exception occurred during QQOfficialWebhook server shutdown: {exc}",
exc_info=True,
)
logger.info("QQ 机器人官方 API 适配器已经被优雅地关闭")
logger.info("QQ 机器人官方 API 适配器已经被关闭")
2 changes: 1 addition & 1 deletion astrbot/dashboard/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,4 +451,4 @@ def run(self):

async def shutdown_trigger(self) -> None:
await self.shutdown_event.wait()
logger.info("AstrBot WebUI 已经被优雅地关闭")
logger.info("AstrBot WebUI 已经被关闭")
Loading