diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 523d758a0a..651721ec20 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -144,6 +144,7 @@ async def process( follow_up_capture: FollowUpCapture | None = None follow_up_consumed_marked = False follow_up_activated = False + typing_requested = False try: streaming_response = self.streaming_response if (enable_streaming := event.get_extra("enable_streaming")) is not None: @@ -178,7 +179,11 @@ async def process( ) return - await event.send_typing() + try: + typing_requested = True + await event.send_typing() + except Exception: + logger.warning("send_typing failed", exc_info=True) await call_event_hook(event, EventType.OnWaitingLLMRequestEvent) async with session_lock_manager.acquire_lock(event.unified_msg_origin): @@ -377,6 +382,11 @@ async def process( ) await event.send(MessageChain().message(error_text)) finally: + if typing_requested and not event.platform_meta.support_streaming_message: + try: + await event.stop_typing() + except Exception: + logger.warning("stop_typing failed", exc_info=True) if follow_up_capture: await finalize_follow_up_capture( follow_up_capture, diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 82c03dbb0d..0ecd47fedc 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -293,6 +293,12 @@ async def send_typing(self) -> None: 默认实现为空,由具体平台按需重写。 """ + async def stop_typing(self) -> None: + """停止输入中状态。 + + 默认实现为空,由具体平台按需重写。 + """ + async def _pre_send(self) -> None: """调度器会在执行 send() 前调用该方法 deprecated in v3.5.18""" diff --git a/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py b/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py index c47b58087e..1e937b13c4 100644 --- a/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py +++ b/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py @@ -6,11 +6,12 @@ import io import time import uuid -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING, Any, cast from urllib.parse import quote +import aiohttp import qrcode as qrcode_lib from astrbot import logger @@ -49,6 +50,17 @@ class OpenClawLoginSession: error: str | None = None +@dataclass +class TypingSessionState: + ticket: str | None = None + ticket_context_token: str | None = None + refresh_after: float = 0.0 + keepalive_task: asyncio.Task | None = None + cancel_task: asyncio.Task | None = None + owners: set[str] = field(default_factory=set) + lock: asyncio.Lock = field(default_factory=asyncio.Lock) + + @register_platform_adapter( "weixin_oc", "个人微信", @@ -105,7 +117,16 @@ def __init__( self._sync_buf = "" self._qr_expired_count = 0 self._context_tokens: dict[str, str] = {} + self._typing_states: dict[str, TypingSessionState] = {} self._last_inbound_error = "" + self._typing_keepalive_interval_s = max( + 1, + int(platform_config.get("weixin_oc_typing_keepalive_interval", 5)), + ) + self._typing_ticket_ttl_s = max( + 5, + int(platform_config.get("weixin_oc_typing_ticket_ttl", 60)), + ) self.token = str(platform_config.get("weixin_oc_token", "")).strip() or None self.account_id = ( @@ -132,6 +153,316 @@ def _sync_client_state(self) -> None: self.client.api_timeout_ms = self.api_timeout_ms self.client.token = self.token + def _get_typing_state(self, user_id: str) -> TypingSessionState: + state = self._typing_states.get(user_id) + if state is None: + state = TypingSessionState() + self._typing_states[user_id] = state + return state + + def _typing_supported_for(self, user_id: str) -> bool: + if not self.token: + return False + return bool(self._context_tokens.get(user_id)) + + async def _cancel_task_safely( + self, + task: asyncio.Task | None, + *, + log_message: str | None = None, + log_args: tuple[Any, ...] = (), + ) -> None: + if task is None or task.done(): + return + + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + except Exception: + if log_message is not None: + logger.warning(log_message, *log_args, exc_info=True) + + async def _ensure_typing_ticket( + self, + user_id: str, + state: TypingSessionState, + ) -> str | None: + now = time.monotonic() + context_token = self._context_tokens.get(user_id) + if not context_token: + return None + + if ( + state.ticket + and state.ticket_context_token == context_token + and state.refresh_after > now + ): + return state.ticket + + payload = await self.client.get_typing_config(user_id, context_token) + if int(payload.get("ret") or 0) != 0: + logger.warning( + "weixin_oc(%s): getconfig failed for %s: %s", + self.meta().id, + user_id, + payload.get("errmsg", ""), + ) + return None + + ticket = str(payload.get("typing_ticket", "")).strip() + if not ticket: + return None + + state.ticket = ticket + state.ticket_context_token = context_token + state.refresh_after = time.monotonic() + self._typing_ticket_ttl_s + return ticket + + async def _send_typing_state( + self, + user_id: str, + ticket: str, + *, + cancel: bool, + ) -> None: + payload = await self.client.send_typing_state(user_id, ticket, cancel=cancel) + if int(payload.get("ret") or 0) != 0: + raise RuntimeError( + f"sendtyping failed for {user_id}: {payload.get('errmsg', '')}" + ) + + async def _run_typing_keepalive(self, user_id: str) -> None: + restart_needed = False + try: + await self._typing_keepalive_loop(user_id) + except asyncio.CancelledError: + raise + except Exception as e: + state = self._typing_states.get(user_id) + if state is not None: + async with state.lock: + state.refresh_after = 0.0 + restart_needed = ( + bool(state.owners) and not self._shutdown_event.is_set() + ) + logger.warning( + "weixin_oc(%s): typing keepalive failed for %s: %s", + self.meta().id, + user_id, + e, + ) + finally: + state = self._typing_states.get(user_id) + current_task = asyncio.current_task() + if state is not None and state.keepalive_task is current_task: + state.keepalive_task = None + + if not restart_needed: + return + + await asyncio.sleep(self._typing_keepalive_interval_s) + state = self._typing_states.get(user_id) + if state is None or self._shutdown_event.is_set(): + return + + async with state.lock: + if not state.owners or state.keepalive_task is not None: + return + state.keepalive_task = asyncio.create_task( + self._run_typing_keepalive(user_id) + ) + + async def _typing_keepalive_loop(self, user_id: str) -> None: + while not self._shutdown_event.is_set(): + await asyncio.sleep(self._typing_keepalive_interval_s) + state = self._typing_states.get(user_id) + if state is None: + return + + async with state.lock: + if not state.owners: + return + try: + ticket = await self._ensure_typing_ticket(user_id, state) + except Exception as e: + state.refresh_after = 0.0 + logger.warning( + "weixin_oc(%s): refresh typing ticket failed for %s: %s", + self.meta().id, + user_id, + e, + ) + continue + if not ticket: + continue + try: + await self._send_typing_state(user_id, ticket, cancel=False) + except Exception as e: + state.refresh_after = 0.0 + logger.warning( + "weixin_oc(%s): typing keepalive send failed for %s: %s", + self.meta().id, + user_id, + e, + ) + + async def _delayed_cancel_typing(self, user_id: str, ticket: str) -> None: + await asyncio.sleep(0) + state = self._typing_states.get(user_id) + if state is None: + return + + current_task = asyncio.current_task() + async with state.lock: + if state.cancel_task is not current_task: + return + if state.owners or state.keepalive_task is not None: + state.cancel_task = None + return + + try: + await self._send_typing_state(user_id, ticket, cancel=True) + except asyncio.CancelledError: + raise + except Exception as e: + logger.warning( + "weixin_oc(%s): cancel typing failed for %s: %s", + self.meta().id, + user_id, + e, + ) + finally: + state = self._typing_states.get(user_id) + if state is None: + return + async with state.lock: + if state.cancel_task is current_task: + state.cancel_task = None + + async def start_typing(self, user_id: str, owner_id: str) -> None: + state = self._get_typing_state(user_id) + cancel_task: asyncio.Task | None = None + async with state.lock: + if owner_id in state.owners: + return + if not self._typing_supported_for(user_id): + return + if state.cancel_task is not None and not state.cancel_task.done(): + cancel_task = state.cancel_task + cancel_task.cancel() + state.cancel_task = None + try: + ticket = await self._ensure_typing_ticket(user_id, state) + except Exception as e: + logger.warning( + "weixin_oc(%s): ensure typing ticket failed for %s: %s", + self.meta().id, + user_id, + e, + ) + return + if not ticket: + return + + state.ticket = ticket + state.owners.add(owner_id) + if state.keepalive_task is not None and not state.keepalive_task.done(): + return + + try: + await self._send_typing_state(user_id, ticket, cancel=False) + except Exception as e: + state.refresh_after = 0.0 + logger.warning( + "weixin_oc(%s): send typing failed for %s: %s", + self.meta().id, + user_id, + e, + ) + + task = asyncio.create_task(self._run_typing_keepalive(user_id)) + state.keepalive_task = task + + if cancel_task is not None: + await self._cancel_task_safely( + cancel_task, + log_message="weixin_oc(%s): ignored error from cancelled typing task", + log_args=(self.meta().id,), + ) + + async def stop_typing(self, user_id: str, owner_id: str) -> None: + state = self._typing_states.get(user_id) + if state is None: + return + + task: asyncio.Task | None = None + async with state.lock: + if owner_id not in state.owners: + return + state.owners.remove(owner_id) + + if state.owners: + return + + task = state.keepalive_task + state.keepalive_task = None + + await self._cancel_task_safely( + task, + log_message="weixin_oc(%s): typing keepalive stop failed for %s", + log_args=(self.meta().id, user_id), + ) + + async with state.lock: + if state.owners: + return + ticket = state.ticket + if ticket: + if state.cancel_task is None or state.cancel_task.done(): + state.cancel_task = asyncio.create_task( + self._delayed_cancel_typing(user_id, ticket) + ) + + async def _cleanup_typing_tasks(self) -> None: + tasks: list[asyncio.Task] = [] + cancels: list[tuple[str, str]] = [] + for user_id, state in list(self._typing_states.items()): + if state.ticket and ( + state.owners + or state.keepalive_task is not None + or state.cancel_task is not None + ): + cancels.append((user_id, state.ticket)) + state.owners.clear() + if state.keepalive_task is not None and not state.keepalive_task.done(): + tasks.append(state.keepalive_task) + state.keepalive_task.cancel() + state.keepalive_task = None + if state.cancel_task is not None and not state.cancel_task.done(): + tasks.append(state.cancel_task) + state.cancel_task.cancel() + state.cancel_task = None + + for task in tasks: + await self._cancel_task_safely( + task, + log_message="weixin_oc(%s): typing cleanup failed", + log_args=(self.meta().id,), + ) + + for user_id, ticket in cancels: + try: + await self._send_typing_state(user_id, ticket, cancel=True) + except Exception as e: + logger.warning( + "weixin_oc(%s): typing cleanup cancel failed for %s: %s", + self.meta().id, + user_id, + e, + ) + def _load_account_state(self) -> None: if not self.token: token = str(self.config.get("weixin_oc_token", "")).strip() @@ -902,15 +1233,26 @@ async def run(self) -> None: "weixin_oc(%s): inbound long-poll timeout", self.meta().id, ) + except aiohttp.ClientConnectionError as e: + self._last_inbound_error = str(e) or e.__class__.__name__ + logger.warning( + "weixin_oc(%s): inbound poll connection error: %s", + self.meta().id, + e, + ) + await self.client.close() + await asyncio.sleep(2) except asyncio.CancelledError: raise except Exception as e: logger.exception("weixin_oc(%s): run failed: %s", self.meta().id, e) finally: + await self._cleanup_typing_tasks() await self.client.close() async def terminate(self) -> None: self._shutdown_event.set() + await self._cleanup_typing_tasks() def get_stats(self) -> dict: stat = super().get_stats() diff --git a/astrbot/core/platform/sources/weixin_oc/weixin_oc_client.py b/astrbot/core/platform/sources/weixin_oc/weixin_oc_client.py index 5ea30d911c..51b0b6ed7c 100644 --- a/astrbot/core/platform/sources/weixin_oc/weixin_oc_client.py +++ b/astrbot/core/platform/sources/weixin_oc/weixin_oc_client.py @@ -226,3 +226,44 @@ async def request_json( if not text: return {} return cast(dict[str, Any], json.loads(text)) + + async def get_typing_config( + self, + user_id: str, + context_token: str, + ) -> dict[str, Any]: + return await self.request_json( + "POST", + "ilink/bot/getconfig", + payload={ + "ilink_user_id": user_id, + "context_token": context_token, + "base_info": { + "channel_version": "astrbot", + }, + }, + token_required=True, + timeout_ms=self.api_timeout_ms, + ) + + async def send_typing_state( + self, + user_id: str, + typing_ticket: str, + *, + cancel: bool, + ) -> dict[str, Any]: + return await self.request_json( + "POST", + "ilink/bot/sendtyping", + payload={ + "ilink_user_id": user_id, + "typing_ticket": typing_ticket, + "status": 2 if cancel else 1, + "base_info": { + "channel_version": "astrbot", + }, + }, + token_required=True, + timeout_ms=self.api_timeout_ms, + ) diff --git a/astrbot/core/platform/sources/weixin_oc/weixin_oc_event.py b/astrbot/core/platform/sources/weixin_oc/weixin_oc_event.py index abe3b5a066..84a19a9e7b 100644 --- a/astrbot/core/platform/sources/weixin_oc/weixin_oc_event.py +++ b/astrbot/core/platform/sources/weixin_oc/weixin_oc_event.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import uuid from typing import TYPE_CHECKING from astrbot.api.event import AstrMessageEvent, MessageChain @@ -29,6 +30,12 @@ def __init__( ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.platform = platform + self._typing_owner_id: str | None = None + + def _get_typing_owner_id(self) -> str: + if not self._typing_owner_id: + self._typing_owner_id = uuid.uuid4().hex + return self._typing_owner_id @staticmethod def _segment_to_text(segment: BaseMessageComponent) -> str: @@ -58,6 +65,18 @@ async def send(self, message: MessageChain) -> None: await self.platform.send_by_session(self.session, message) await super().send(message) + async def send_typing(self) -> None: + await self.platform.start_typing( + self.session.session_id, + self._get_typing_owner_id(), + ) + + async def stop_typing(self) -> None: + await self.platform.stop_typing( + self.session.session_id, + self._get_typing_owner_id(), + ) + async def send_streaming(self, generator, use_fallback: bool = False): if not use_fallback: buffer = None diff --git a/tests/unit/test_astr_message_event.py b/tests/unit/test_astr_message_event.py index ac529318fe..89087d1cab 100644 --- a/tests/unit/test_astr_message_event.py +++ b/tests/unit/test_astr_message_event.py @@ -651,6 +651,15 @@ async def test_send_typing_default_empty(self, astr_message_event): await astr_message_event.send_typing() +class TestStopTyping: + """Tests for stop_typing method.""" + + @pytest.mark.asyncio + async def test_stop_typing_default_empty(self, astr_message_event): + """Test stop_typing default implementation is empty.""" + await astr_message_event.stop_typing() + + class TestReact: """Tests for react method.""" @@ -772,10 +781,12 @@ def test_get_sender_fields_without_sender_attr(self, astr_message_event): def test_get_message_type_with_non_enum_type(self, astr_message_event): """get_message_type should handle message_obj.type that is not a MessageType.""" + class DummyMessage: def __init__(self): self.type = "not_an_enum" self.message = [] + astr_message_event.message_obj = DummyMessage() message_type = astr_message_event.get_message_type() assert isinstance(message_type, MessageType) diff --git a/tests/unit/test_internal_agent_sub_stage.py b/tests/unit/test_internal_agent_sub_stage.py new file mode 100644 index 0000000000..69c4b1c296 --- /dev/null +++ b/tests/unit/test_internal_agent_sub_stage.py @@ -0,0 +1,301 @@ +from __future__ import annotations + +from contextlib import asynccontextmanager +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from astrbot.core.message.components import Plain +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember +from astrbot.core.platform.message_type import MessageType +from astrbot.core.platform.platform_metadata import PlatformMetadata +from astrbot.core.pipeline.process_stage.method.agent_sub_stages import ( + internal as internal_module, +) +from astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal import ( + InternalAgentSubStage, +) + + +class ConcreteAstrMessageEvent(AstrMessageEvent): + async def send(self, message): + await super().send(message) + + +@pytest.fixture +def mock_ctx(): + plugin_context = MagicMock() + plugin_context.conversation_manager = MagicMock() + plugin_context.get_config.return_value = {"timezone": "UTC"} + plugin_context.get_using_tts_provider.return_value = None + + ctx = MagicMock() + ctx.astrbot_config = { + "provider_settings": { + "streaming_response": False, + "unsupported_streaming_strategy": "turn_off", + "max_context_length": 32, + "dequeue_context_length": 4, + }, + "kb_agentic_mode": False, + "subagent_orchestrator": {}, + } + ctx.plugin_manager.context = plugin_context + return ctx + + +@pytest.fixture +def stage(mock_ctx): + async def _make_stage(): + obj = InternalAgentSubStage() + await obj.initialize(mock_ctx) + obj._save_to_history = AsyncMock() + return obj + + return _make_stage + + +@pytest.fixture +def event(): + platform_meta = PlatformMetadata( + name="test_platform", + description="Test platform", + id="test_platform_id", + support_streaming_message=False, + ) + message = AstrBotMessage() + message.type = MessageType.FRIEND_MESSAGE + message.self_id = "bot123" + message.session_id = "session123" + message.message_id = "msg123" + message.sender = MessageMember(user_id="user123", nickname="TestUser") + message.message = [Plain(text="Hello world")] + message.message_str = "Hello world" + message.raw_message = None + return ConcreteAstrMessageEvent( + message_str="Hello world", + message_obj=message, + platform_meta=platform_meta, + session_id="session123", + ) + + +@asynccontextmanager +async def fake_lock(_umo): + yield + + +def make_build_result() -> SimpleNamespace: + provider = MagicMock() + provider.provider_config = {"id": "provider-1", "api_base": ""} + provider.get_model.return_value = "test-model" + provider.meta.return_value = SimpleNamespace(type="test") + + final_resp = SimpleNamespace( + completion_text="done", + result_chain=None, + role="assistant", + usage=None, + ) + agent_runner = MagicMock() + agent_runner.done.return_value = True + agent_runner.was_aborted.return_value = False + agent_runner.get_final_llm_resp.return_value = final_resp + agent_runner.run_context = SimpleNamespace(messages=[]) + agent_runner.stats = MagicMock() + agent_runner.stats.to_dict.return_value = {} + agent_runner.provider = provider + + return SimpleNamespace( + agent_runner=agent_runner, + provider_request=SimpleNamespace( + system_prompt="sys", + func_tool=None, + conversation=object(), + tool_calls_result=None, + ), + provider=provider, + reset_coro=None, + ) + + +async def empty_run_agent(*args, **kwargs): + if False: + yield None + + +@pytest.mark.asyncio +async def test_process_swallows_send_typing_error_and_still_releases(stage, event): + event.send_typing = AsyncMock(side_effect=RuntimeError("boom")) + event.stop_typing = AsyncMock() + obj = await stage() + + with ( + patch.object(internal_module.logger, "warning") as warning_mock, + patch.object(internal_module, "try_capture_follow_up", return_value=None), + patch.object(internal_module, "call_event_hook", AsyncMock(return_value=False)), + patch.object(internal_module.session_lock_manager, "acquire_lock", fake_lock), + patch.object(internal_module, "build_main_agent", AsyncMock(return_value=None)), + ): + results = [item async for item in obj.process(event, provider_wake_prefix="")] + + assert results == [] + event.stop_typing.assert_awaited_once() + warning_mock.assert_called_once_with("send_typing failed", exc_info=True) + + +@pytest.mark.asyncio +async def test_process_releases_typing_when_build_returns_none(stage, event): + event.send_typing = AsyncMock() + event.stop_typing = AsyncMock() + obj = await stage() + + with ( + patch.object(internal_module, "try_capture_follow_up", return_value=None), + patch.object(internal_module, "call_event_hook", AsyncMock(return_value=False)), + patch.object(internal_module.session_lock_manager, "acquire_lock", fake_lock), + patch.object(internal_module, "build_main_agent", AsyncMock(return_value=None)), + ): + results = [item async for item in obj.process(event, provider_wake_prefix="")] + + assert results == [] + event.send_typing.assert_awaited_once() + event.stop_typing.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_process_releases_typing_when_llm_request_hook_short_circuits( + stage, event +): + event.send_typing = AsyncMock() + event.stop_typing = AsyncMock() + obj = await stage() + build_result = make_build_result() + + with ( + patch.object(internal_module, "try_capture_follow_up", return_value=None), + patch.object( + internal_module, + "call_event_hook", + AsyncMock(side_effect=[False, True]), + ), + patch.object(internal_module.session_lock_manager, "acquire_lock", fake_lock), + patch.object( + internal_module, + "build_main_agent", + AsyncMock(return_value=build_result), + ), + ): + results = [item async for item in obj.process(event, provider_wake_prefix="")] + + assert results == [] + event.stop_typing.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_process_releases_typing_after_normal_reply(stage, event): + event.send_typing = AsyncMock() + event.stop_typing = AsyncMock() + obj = await stage() + build_result = make_build_result() + + with ( + patch.object(internal_module, "try_capture_follow_up", return_value=None), + patch.object( + internal_module, + "call_event_hook", + AsyncMock(side_effect=[False, False]), + ), + patch.object(internal_module.session_lock_manager, "acquire_lock", fake_lock), + patch.object( + internal_module, + "build_main_agent", + AsyncMock(return_value=build_result), + ), + patch.object(internal_module, "run_agent", empty_run_agent), + patch.object(internal_module, "register_active_runner"), + patch.object(internal_module, "unregister_active_runner"), + ): + results = [item async for item in obj.process(event, provider_wake_prefix="")] + + assert results == [] + event.stop_typing.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_process_does_not_stop_typing_early_for_streaming_platforms(stage, event): + event.platform_meta.support_streaming_message = True + event.send_typing = AsyncMock() + event.stop_typing = AsyncMock() + obj = await stage() + obj.streaming_response = True + build_result = make_build_result() + + with ( + patch.object(internal_module, "try_capture_follow_up", return_value=None), + patch.object( + internal_module, + "call_event_hook", + AsyncMock(side_effect=[False, False]), + ), + patch.object(internal_module.session_lock_manager, "acquire_lock", fake_lock), + patch.object( + internal_module, + "build_main_agent", + AsyncMock(return_value=build_result), + ), + patch.object(internal_module, "run_agent", empty_run_agent), + patch.object(internal_module, "register_active_runner"), + patch.object(internal_module, "unregister_active_runner"), + ): + results = [item async for item in obj.process(event, provider_wake_prefix="")] + + assert len(results) == 1 + event.stop_typing.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_process_releases_typing_on_error_fallback_send(stage, event): + event.send_typing = AsyncMock() + event.stop_typing = AsyncMock() + event.send = AsyncMock() + obj = await stage() + + with ( + patch.object(internal_module, "try_capture_follow_up", return_value=None), + patch.object(internal_module, "call_event_hook", AsyncMock(return_value=False)), + patch.object(internal_module.session_lock_manager, "acquire_lock", fake_lock), + patch.object( + internal_module, + "build_main_agent", + AsyncMock(side_effect=RuntimeError("boom")), + ), + ): + results = [item async for item in obj.process(event, provider_wake_prefix="")] + + assert results == [] + event.send.assert_awaited_once() + event.stop_typing.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_process_swallows_stop_typing_error(stage, event): + event.send_typing = AsyncMock() + event.stop_typing = AsyncMock(side_effect=RuntimeError("stop failed")) + obj = await stage() + + with ( + patch.object(internal_module.logger, "warning") as warning_mock, + patch.object(internal_module, "try_capture_follow_up", return_value=None), + patch.object(internal_module, "call_event_hook", AsyncMock(return_value=False)), + patch.object(internal_module.session_lock_manager, "acquire_lock", fake_lock), + patch.object(internal_module, "build_main_agent", AsyncMock(return_value=None)), + ): + results = [item async for item in obj.process(event, provider_wake_prefix="")] + + assert results == [] + event.send_typing.assert_awaited_once() + event.stop_typing.assert_awaited_once() + warning_mock.assert_called_once_with("stop_typing failed", exc_info=True) diff --git a/tests/unit/test_weixin_oc_typing.py b/tests/unit/test_weixin_oc_typing.py new file mode 100644 index 0000000000..c04cdc2144 --- /dev/null +++ b/tests/unit/test_weixin_oc_typing.py @@ -0,0 +1,604 @@ +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import aiohttp +import pytest + +from astrbot.core.message.components import Plain +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember +from astrbot.core.platform.message_type import MessageType +from astrbot.core.platform.platform_metadata import PlatformMetadata +from astrbot.core.platform.sources.weixin_oc.weixin_oc_adapter import ( + TypingSessionState, + WeixinOCAdapter, +) +from astrbot.core.platform.sources.weixin_oc.weixin_oc_client import WeixinOCClient +from astrbot.core.platform.sources.weixin_oc.weixin_oc_event import WeixinOCMessageEvent + + +@pytest.fixture +def client(): + return WeixinOCClient( + adapter_id="wx-1", + base_url="https://example.com", + cdn_base_url="https://cdn.example.com", + api_timeout_ms=15000, + token="token-1", + ) + + +@pytest.fixture +def adapter(): + obj = WeixinOCAdapter( + platform_config={ + "id": "wx-1", + "type": "weixin_oc", + "weixin_oc_token": "token-1", + }, + platform_settings={}, + event_queue=asyncio.Queue(), + ) + obj._context_tokens["user-1"] = "ctx-1" + return obj + + +@pytest.fixture +def weixin_event(): + message = AstrBotMessage() + message.type = MessageType.FRIEND_MESSAGE + message.self_id = "bot123" + message.session_id = "user-1" + message.message_id = "msg123" + message.sender = MessageMember(user_id="user-1", nickname="User") + message.message = [Plain(text="hello")] + message.message_str = "hello" + message.raw_message = None + + platform = MagicMock() + platform.start_typing = AsyncMock() + platform.stop_typing = AsyncMock() + platform.send_by_session = AsyncMock() + + event = WeixinOCMessageEvent( + message_str="hello", + message_obj=message, + platform_meta=PlatformMetadata( + name="weixin_oc", + description="个人微信", + id="wx-1", + support_streaming_message=False, + ), + session_id="user-1", + platform=platform, + ) + return event, platform + + +@pytest.mark.asyncio +async def test_get_typing_config_uses_getconfig(client): + client.request_json = AsyncMock(return_value={"typing_ticket": "ticket-1"}) + + result = await client.get_typing_config("user-1", "ctx-1") + + assert result == {"typing_ticket": "ticket-1"} + client.request_json.assert_awaited_once_with( + "POST", + "ilink/bot/getconfig", + payload={ + "ilink_user_id": "user-1", + "context_token": "ctx-1", + "base_info": {"channel_version": "astrbot"}, + }, + token_required=True, + timeout_ms=client.api_timeout_ms, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("cancel, status", [(False, 1), (True, 2)]) +async def test_send_typing_state_uses_sendtyping(client, cancel, status): + client.request_json = AsyncMock(return_value={}) + + await client.send_typing_state("user-1", "ticket-1", cancel=cancel) + + client.request_json.assert_awaited_once_with( + "POST", + "ilink/bot/sendtyping", + payload={ + "ilink_user_id": "user-1", + "typing_ticket": "ticket-1", + "status": status, + "base_info": {"channel_version": "astrbot"}, + }, + token_required=True, + timeout_ms=client.api_timeout_ms, + ) + + +@pytest.mark.asyncio +async def test_event_delegates_typing_calls(weixin_event): + event, platform = weixin_event + + await event.send_typing() + await event.stop_typing() + + platform.start_typing.assert_awaited_once() + platform.stop_typing.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_event_reuses_stable_owner_id(weixin_event): + event, platform = weixin_event + + await event.send_typing() + await event.stop_typing() + + start_owner = platform.start_typing.await_args.args[1] + stop_owner = platform.stop_typing.await_args.args[1] + assert start_owner == stop_owner + + +@pytest.mark.asyncio +async def test_start_typing_skips_without_token(adapter): + adapter.token = None + adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") + adapter._send_typing_state = AsyncMock() + + await adapter.start_typing("user-1", "owner-a") + + adapter._ensure_typing_ticket.assert_not_awaited() + adapter._send_typing_state.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_start_typing_skips_without_context_token(adapter): + adapter._context_tokens.clear() + adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") + adapter._send_typing_state = AsyncMock() + + await adapter.start_typing("user-1", "owner-a") + + adapter._ensure_typing_ticket.assert_not_awaited() + adapter._send_typing_state.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_ensure_typing_ticket_reuses_fresh_ticket(adapter): + state = TypingSessionState( + ticket="cached-ticket", + ticket_context_token="ctx-1", + refresh_after=float("inf"), + ) + adapter.client.get_typing_config = AsyncMock() + + result = await adapter._ensure_typing_ticket("user-1", state) + + assert result == "cached-ticket" + adapter.client.get_typing_config.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_ensure_typing_ticket_refreshes_stale_ticket(adapter): + state = TypingSessionState(ticket="stale-ticket", refresh_after=0.0) + adapter.client.get_typing_config = AsyncMock( + return_value={"typing_ticket": "fresh-ticket"} + ) + + result = await adapter._ensure_typing_ticket("user-1", state) + + assert result == "fresh-ticket" + assert state.ticket == "fresh-ticket" + adapter.client.get_typing_config.assert_awaited_once_with("user-1", "ctx-1") + + +@pytest.mark.asyncio +async def test_ensure_typing_ticket_refreshes_when_context_token_changes(adapter): + state = TypingSessionState( + ticket="cached-ticket", + ticket_context_token="ctx-1", + refresh_after=float("inf"), + ) + adapter._context_tokens["user-1"] = "ctx-2" + adapter.client.get_typing_config = AsyncMock( + return_value={"typing_ticket": "fresh-ticket"} + ) + + result = await adapter._ensure_typing_ticket("user-1", state) + + assert result == "fresh-ticket" + assert state.ticket_context_token == "ctx-2" + adapter.client.get_typing_config.assert_awaited_once_with("user-1", "ctx-2") + + +@pytest.mark.asyncio +async def test_send_typing_state_raises_on_nonzero_ret(adapter): + adapter.client.send_typing_state = AsyncMock( + return_value={"ret": 1, "errmsg": "expired"} + ) + + with pytest.raises(RuntimeError, match="sendtyping failed"): + await adapter._send_typing_state("user-1", "ticket-1", cancel=False) + + +@pytest.mark.asyncio +async def test_cancel_task_safely_logs_task_errors(adapter): + async def failing_task(): + try: + await asyncio.Event().wait() + except asyncio.CancelledError as exc: + raise RuntimeError("task wait failed") from exc + + task = asyncio.create_task(failing_task()) + await asyncio.sleep(0) + + with patch( + "astrbot.core.platform.sources.weixin_oc.weixin_oc_adapter.logger.warning" + ) as warning_mock: + await adapter._cancel_task_safely( + task, + log_message="weixin_oc(%s): typing cleanup failed", + log_args=(adapter.meta().id,), + ) + + warning_mock.assert_called_once_with( + "weixin_oc(%s): typing cleanup failed", + adapter.meta().id, + exc_info=True, + ) + + +@pytest.mark.asyncio +async def test_start_typing_same_owner_is_idempotent(adapter): + stop_event = asyncio.Event() + adapter._send_typing_state = AsyncMock() + adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") + + async def fake_keepalive(_user_id): + await stop_event.wait() + + adapter._typing_keepalive_loop = fake_keepalive + + await adapter.start_typing("user-1", "owner-a") + await adapter.start_typing("user-1", "owner-a") + + assert adapter._send_typing_state.await_count == 1 + state = adapter._typing_states["user-1"] + assert state.owners == {"owner-a"} + + await adapter.stop_typing("user-1", "owner-a") + stop_event.set() + + +@pytest.mark.asyncio +async def test_stop_typing_only_cancels_on_last_owner(adapter): + stop_event = asyncio.Event() + adapter._send_typing_state = AsyncMock() + adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") + + async def fake_keepalive(_user_id): + await stop_event.wait() + + adapter._typing_keepalive_loop = fake_keepalive + + await adapter.start_typing("user-1", "owner-a") + await adapter.start_typing("user-1", "owner-b") + await adapter.stop_typing("user-1", "owner-a") + + state = adapter._typing_states["user-1"] + assert state.owners == {"owner-b"} + assert adapter._send_typing_state.await_count == 1 + + await adapter.stop_typing("user-1", "owner-b") + await asyncio.sleep(0) + await asyncio.sleep(0) + stop_event.set() + assert adapter._send_typing_state.await_count == 2 + + +@pytest.mark.asyncio +async def test_stop_typing_is_safe_to_repeat(adapter): + adapter._send_typing_state = AsyncMock() + adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") + + async def fake_keepalive(_user_id): + await asyncio.Event().wait() + + adapter._typing_keepalive_loop = fake_keepalive + + await adapter.start_typing("user-1", "owner-a") + await adapter.stop_typing("user-1", "owner-a") + await adapter.stop_typing("user-1", "owner-a") + await asyncio.sleep(0) + await asyncio.sleep(0) + + assert adapter._send_typing_state.await_count == 2 + + +@pytest.mark.asyncio +async def test_keepalive_failure_cleans_state(adapter): + adapter._send_typing_state = AsyncMock() + adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") + + async def fake_keepalive(_user_id): + raise RuntimeError("keepalive failed") + + adapter._typing_keepalive_loop = fake_keepalive + + await adapter.start_typing("user-1", "owner-a") + await asyncio.sleep(0) + + state = adapter._typing_states["user-1"] + assert state.keepalive_task is None + + await adapter.stop_typing("user-1", "owner-a") + await asyncio.sleep(0) + await asyncio.sleep(0) + + assert adapter._send_typing_state.await_count == 2 + + +@pytest.mark.asyncio +async def test_keepalive_failure_restarts_for_active_owner(adapter): + adapter._typing_keepalive_interval_s = 0 + adapter._send_typing_state = AsyncMock() + adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") + keepalive_round = 0 + stop_event = asyncio.Event() + + async def fake_keepalive(_user_id): + nonlocal keepalive_round + keepalive_round += 1 + if keepalive_round == 1: + raise RuntimeError("keepalive failed") + await stop_event.wait() + + adapter._typing_keepalive_loop = fake_keepalive + + await adapter.start_typing("user-1", "owner-a") + for _ in range(4): + await asyncio.sleep(0) + + state = adapter._typing_states["user-1"] + assert keepalive_round >= 2 + assert state.keepalive_task is not None + + stop_event.set() + await adapter.stop_typing("user-1", "owner-a") + for _ in range(2): + await asyncio.sleep(0) + + +@pytest.mark.asyncio +async def test_stop_typing_does_not_cancel_new_owner_session(adapter): + cancel_blocked = asyncio.Event() + allow_cancel_exit = asyncio.Event() + adapter._send_typing_state = AsyncMock() + adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") + + async def fake_keepalive(_user_id): + try: + await asyncio.Event().wait() + except asyncio.CancelledError: + cancel_blocked.set() + await allow_cancel_exit.wait() + raise + + adapter._typing_keepalive_loop = fake_keepalive + + await adapter.start_typing("user-1", "owner-a") + stop_task = asyncio.create_task(adapter.stop_typing("user-1", "owner-a")) + await cancel_blocked.wait() + await adapter.start_typing("user-1", "owner-b") + allow_cancel_exit.set() + await stop_task + + assert adapter._send_typing_state.await_count == 2 + + +@pytest.mark.asyncio +async def test_start_typing_cancels_inflight_cancel_task(adapter): + cancel_started = asyncio.Event() + release_cancel = asyncio.Event() + stop_event = asyncio.Event() + events: list[str] = [] + adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") + + async def fake_send_typing_state(_user_id, ticket, *, cancel): + if cancel: + events.append("cancel-start") + cancel_started.set() + try: + await release_cancel.wait() + except asyncio.CancelledError: + events.append("cancel-cancelled") + raise + events.append("cancel-finished") + return + events.append(f"start-{ticket}") + + async def fake_keepalive(_user_id): + await stop_event.wait() + + adapter._send_typing_state = fake_send_typing_state + adapter._typing_keepalive_loop = fake_keepalive + + await adapter.start_typing("user-1", "owner-a") + await adapter.stop_typing("user-1", "owner-a") + await asyncio.sleep(0) + await asyncio.sleep(0) + await cancel_started.wait() + + start_task = asyncio.create_task(adapter.start_typing("user-1", "owner-b")) + await asyncio.sleep(0) + release_cancel.set() + await start_task + + assert "cancel-cancelled" in events + assert "cancel-finished" not in events + + stop_event.set() + await adapter.stop_typing("user-1", "owner-b") + await asyncio.sleep(0) + await asyncio.sleep(0) + + +@pytest.mark.asyncio +async def test_start_typing_logs_ignored_cancel_task_errors(adapter): + stop_event = asyncio.Event() + adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") + state = adapter._get_typing_state("user-1") + + async def fake_send_typing_state(_user_id, _ticket, *, cancel): + return None + + async def fake_cancel_task(): + try: + await asyncio.Event().wait() + except asyncio.CancelledError as exc: + raise RuntimeError("cancel failed") from exc + + async def fake_keepalive(_user_id): + await stop_event.wait() + + adapter._send_typing_state = fake_send_typing_state + adapter._typing_keepalive_loop = fake_keepalive + state.cancel_task = asyncio.create_task(fake_cancel_task()) + await asyncio.sleep(0) + + with patch( + "astrbot.core.platform.sources.weixin_oc.weixin_oc_adapter.logger.warning" + ) as warning_mock: + await adapter.start_typing("user-1", "owner-a") + + warning_mock.assert_called_once_with( + "weixin_oc(%s): ignored error from cancelled typing task", + adapter.meta().id, + exc_info=True, + ) + + stop_event.set() + await adapter.stop_typing("user-1", "owner-a") + await asyncio.sleep(0) + await asyncio.sleep(0) + + +@pytest.mark.asyncio +async def test_cleanup_typing_tasks_sends_final_cancel(adapter): + adapter._send_typing_state = AsyncMock() + + async def fake_keepalive(_user_id): + await asyncio.Event().wait() + + task = asyncio.create_task(fake_keepalive("user-1")) + adapter._typing_states["user-1"] = TypingSessionState( + ticket="ticket-1", + refresh_after=float("inf"), + keepalive_task=task, + owners={"owner-a"}, + ) + + await adapter._cleanup_typing_tasks() + + adapter._send_typing_state.assert_awaited_once_with( + "user-1", + "ticket-1", + cancel=True, + ) + + +@pytest.mark.asyncio +async def test_run_finally_cancels_keepalive_before_client_close(adapter): + order: list[str] = [] + task = asyncio.create_task(asyncio.Event().wait()) + adapter._typing_states["user-1"] = TypingSessionState( + ticket="ticket-1", + refresh_after=float("inf"), + keepalive_task=task, + owners={"owner-a"}, + ) + adapter._cleanup_typing_tasks = AsyncMock( + side_effect=lambda: order.append("cleanup") + ) + adapter.client.close = AsyncMock(side_effect=lambda: order.append("close")) + + with patch.object( + adapter, + "_poll_inbound_updates", + AsyncMock(side_effect=RuntimeError("boom")), + ): + await adapter.run() + + assert order == ["cleanup", "close"] + + +@pytest.mark.asyncio +async def test_run_recovers_after_server_disconnect(adapter): + poll_count = 0 + adapter._cleanup_typing_tasks = AsyncMock() + adapter.client.close = AsyncMock() + + async def fake_poll(): + nonlocal poll_count + poll_count += 1 + if poll_count == 1: + raise aiohttp.ServerDisconnectedError() + assert adapter.client.close.await_count == 1 + adapter._shutdown_event.set() + + with ( + patch.object(adapter, "_poll_inbound_updates", side_effect=fake_poll), + patch( + "astrbot.core.platform.sources.weixin_oc.weixin_oc_adapter.asyncio.sleep", + new_callable=AsyncMock, + ) as sleep_mock, + patch( + "astrbot.core.platform.sources.weixin_oc.weixin_oc_adapter.logger.warning" + ) as warning_mock, + ): + await adapter.run() + + assert poll_count == 2 + warning_mock.assert_called_once() + assert adapter.client.close.await_count == 2 + sleep_mock.assert_awaited_once_with(2) + assert adapter._last_inbound_error + assert ( + "Server disconnected" in adapter._last_inbound_error + or "ServerDisconnectedError" in adapter._last_inbound_error + ) + + +@pytest.mark.asyncio +async def test_run_keeps_non_network_poll_errors_fatal(adapter): + poll_mock = AsyncMock(side_effect=RuntimeError("boom")) + adapter._cleanup_typing_tasks = AsyncMock() + adapter.client.close = AsyncMock() + + with ( + patch.object(adapter, "_poll_inbound_updates", poll_mock), + patch( + "astrbot.core.platform.sources.weixin_oc.weixin_oc_adapter.logger.exception" + ) as exception_mock, + ): + await adapter.run() + + assert poll_mock.await_count == 1 + adapter._cleanup_typing_tasks.assert_awaited_once() + adapter.client.close.assert_awaited_once() + exception_mock.assert_called_once() + + +@pytest.mark.asyncio +async def test_send_still_works_with_existing_event_behavior(weixin_event): + event, platform = weixin_event + + with patch( + "astrbot.core.platform.astr_message_event.Metric.upload", + new_callable=AsyncMock, + ): + await event.send(MessageChain([Plain("reply")])) + + platform.send_by_session.assert_awaited_once()