Skip to content

Commit bf3fa3e

Browse files
Trance-0Soulter
andauthored
fix: 改进微信公众号被动回复处理机制,引入缓冲与分片回复,并优化超时行为 (#5224)
* 修复wechat official 被动回复功能 * ruff format --------- Co-authored-by: Soulter <905617992@qq.com>
1 parent 3b2ce9f commit bf3fa3e

File tree

2 files changed

+215
-38
lines changed

2 files changed

+215
-38
lines changed

astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py

Lines changed: 197 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import asyncio
22
import os
33
import sys
4+
import time
45
import uuid
56
from collections.abc import Awaitable, Callable
67
from typing import Any, cast
78

89
import quart
910
from requests import Response
10-
from wechatpy import WeChatClient, parse_message
11+
from wechatpy import WeChatClient, create_reply, parse_message
1112
from wechatpy.crypto import WeChatCrypto
1213
from wechatpy.exceptions import InvalidSignatureException
1314
from wechatpy.messages import BaseMessage, ImageMessage, TextMessage, VoiceMessage
@@ -38,7 +39,12 @@
3839

3940

4041
class WeixinOfficialAccountServer:
41-
def __init__(self, event_queue: asyncio.Queue, config: dict) -> None:
42+
def __init__(
43+
self,
44+
event_queue: asyncio.Queue,
45+
config: dict,
46+
user_buffer: dict[Any, dict[str, Any]],
47+
) -> None:
4248
self.server = quart.Quart(__name__)
4349
self.port = int(cast(int | str, config.get("port")))
4450
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
@@ -62,6 +68,10 @@ def __init__(self, event_queue: asyncio.Queue, config: dict) -> None:
6268
self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
6369
self.shutdown_event = asyncio.Event()
6470

71+
self._wx_msg_time_out = 4.0 # 微信服务器要求 5 秒内回复
72+
self.user_buffer: dict[str, dict[str, Any]] = user_buffer # from_user -> state
73+
self.active_send_mode = False # 是否启用主动发送模式,启用后 callback 将直接返回回复内容,无需等待微信回调
74+
6575
async def verify(self):
6676
"""内部服务器的 GET 验证入口"""
6777
return await self.handle_verify(quart.request)
@@ -98,6 +108,22 @@ async def callback_command(self):
98108
"""内部服务器的 POST 回调入口"""
99109
return await self.handle_callback(quart.request)
100110

111+
def _maybe_encrypt(self, xml: str, nonce: str | None, timestamp: str | None) -> str:
112+
if xml and "<Encrypt>" not in xml and nonce and timestamp:
113+
return self.crypto.encrypt_message(xml, nonce, timestamp)
114+
return xml or "success"
115+
116+
def _preview(self, msg: BaseMessage, limit: int = 24) -> str:
117+
"""生成消息预览文本,供占位符使用"""
118+
if isinstance(msg, TextMessage):
119+
t = cast(str, msg.content).strip()
120+
return (t[:limit] + "...") if len(t) > limit else (t or "空消息")
121+
if isinstance(msg, ImageMessage):
122+
return "图片"
123+
if isinstance(msg, VoiceMessage):
124+
return "语音"
125+
return getattr(msg, "type", "未知消息")
126+
101127
async def handle_callback(self, request) -> str:
102128
"""处理回调请求,可被统一 webhook 入口复用
103129
@@ -123,14 +149,152 @@ async def handle_callback(self, request) -> str:
123149
raise
124150
logger.info(f"解析成功: {msg}")
125151

126-
if self.callback:
152+
if not self.callback:
153+
return "success"
154+
155+
# by pass passive reply logic and return active reply directly.
156+
if self.active_send_mode:
127157
result_xml = await self.callback(msg)
128158
if not result_xml:
129159
return "success"
130160
if isinstance(result_xml, str):
131161
return result_xml
132162

133-
return "success"
163+
# passive reply
164+
from_user = str(getattr(msg, "source", ""))
165+
msg_id = str(cast(str | int, getattr(msg, "id", "")))
166+
state = self.user_buffer.get(from_user)
167+
168+
def _reply_text(text: str) -> str:
169+
reply_obj = create_reply(text, msg)
170+
reply_xml = reply_obj if isinstance(reply_obj, str) else str(reply_obj)
171+
return self._maybe_encrypt(reply_xml, nonce, timestamp)
172+
173+
# if in cached state, return cached result or placeholder
174+
if state:
175+
logger.debug(f"用户消息缓冲状态: user={from_user} state={state}")
176+
cached = state.get("cached_xml")
177+
# send one cached each time, if cached is empty after pop, remove the buffer
178+
if cached and len(cached) > 0:
179+
logger.info(f"wx buffer hit on trigger: user={from_user}")
180+
cached_xml = cached.pop(0)
181+
if len(cached) == 0:
182+
self.user_buffer.pop(from_user, None)
183+
return _reply_text(cached_xml)
184+
else:
185+
return _reply_text(
186+
cached_xml
187+
+ "\n【后续消息还在缓冲中,回复任意文字继续获取】"
188+
)
189+
190+
task: asyncio.Task | None = cast(asyncio.Task | None, state.get("task"))
191+
placeholder = (
192+
f"【正在思考'{state.get('preview', '...')}'中,已思考"
193+
f"{int(time.monotonic() - state.get('started_at', time.monotonic()))}s,回复任意文字尝试获取回复】"
194+
)
195+
196+
# same msgid => WeChat retry: wait a little; new msgid => user trigger: just placeholder
197+
if task and state.get("msg_id") == msg_id:
198+
done, _ = await asyncio.wait(
199+
{task},
200+
timeout=self._wx_msg_time_out,
201+
return_when=asyncio.FIRST_COMPLETED,
202+
)
203+
if done:
204+
try:
205+
cached = state.get("cached_xml")
206+
# send one cached each time, if cached is empty after pop, remove the buffer
207+
if cached and len(cached) > 0:
208+
logger.info(
209+
f"wx buffer hit on retry window: user={from_user}"
210+
)
211+
cached_xml = cached.pop(0)
212+
if len(cached) == 0:
213+
self.user_buffer.pop(from_user, None)
214+
logger.debug(
215+
f"wx finished message sending in passive window: user={from_user} msg_id={msg_id} "
216+
)
217+
return _reply_text(cached_xml)
218+
else:
219+
logger.debug(
220+
f"wx finished message sending in passive window but not final: user={from_user} msg_id={msg_id} "
221+
)
222+
return _reply_text(
223+
cached_xml
224+
+ "\n【后续消息还在缓冲中,回复任意文字继续获取】"
225+
)
226+
logger.info(
227+
f"wx finished in window but not final; return placeholder: user={from_user} msg_id={msg_id} "
228+
)
229+
return _reply_text(placeholder)
230+
except Exception:
231+
logger.critical(
232+
"wx task failed in passive window", exc_info=True
233+
)
234+
self.user_buffer.pop(from_user, None)
235+
return _reply_text("处理消息失败,请稍后再试。")
236+
237+
logger.info(
238+
f"wx passive window timeout: user={from_user} msg_id={msg_id}"
239+
)
240+
return _reply_text(placeholder)
241+
242+
logger.debug(f"wx trigger while thinking: user={from_user}")
243+
return _reply_text(placeholder)
244+
245+
# create new trigger when state is empty, and store state in buffer
246+
logger.debug(f"wx new trigger: user={from_user} msg_id={msg_id}")
247+
preview = self._preview(msg)
248+
placeholder = (
249+
f"【正在思考'{preview}'中,已思考0s,回复任意文字尝试获取回复】"
250+
)
251+
logger.info(
252+
f"wx start task: user={from_user} msg_id={msg_id} preview={preview}"
253+
)
254+
255+
self.user_buffer[from_user] = state = {
256+
"msg_id": msg_id,
257+
"preview": preview,
258+
"task": None, # set later after task created
259+
"cached_xml": [], # for passive reply
260+
"started_at": time.monotonic(),
261+
}
262+
self.user_buffer[from_user]["task"] = task = asyncio.create_task(
263+
self.callback(msg)
264+
)
265+
266+
# immediate return if done
267+
done, _ = await asyncio.wait(
268+
{task},
269+
timeout=self._wx_msg_time_out,
270+
return_when=asyncio.FIRST_COMPLETED,
271+
)
272+
if done:
273+
try:
274+
cached = state.get("cached_xml", None)
275+
# send one cached each time, if cached is empty after pop, remove the buffer
276+
if cached and len(cached) > 0:
277+
logger.info(f"wx buffer hit immediately: user={from_user}")
278+
cached_xml = cached.pop(0)
279+
if len(cached) == 0:
280+
self.user_buffer.pop(from_user, None)
281+
return _reply_text(cached_xml)
282+
else:
283+
return _reply_text(
284+
cached_xml
285+
+ "\n【后续消息还在缓冲中,回复任意文字继续获取】"
286+
)
287+
logger.info(
288+
f"wx not finished in first window; return placeholder: user={from_user} msg_id={msg_id} "
289+
)
290+
return _reply_text(placeholder)
291+
except Exception:
292+
logger.critical("wx task failed in first window", exc_info=True)
293+
self.user_buffer.pop(from_user, None)
294+
return _reply_text("处理消息失败,请稍后再试。")
295+
296+
logger.info(f"wx first window timeout: user={from_user} msg_id={msg_id}")
297+
return _reply_text(placeholder)
134298

135299
async def start_polling(self) -> None:
136300
logger.info(
@@ -176,7 +340,10 @@ def __init__(
176340
if not self.api_base_url.endswith("/"):
177341
self.api_base_url += "/"
178342

179-
self.server = WeixinOfficialAccountServer(self._event_queue, self.config)
343+
self.user_buffer: dict[str, dict[str, Any]] = {} # from_user -> state
344+
self.server = WeixinOfficialAccountServer(
345+
self._event_queue, self.config, self.user_buffer
346+
)
180347

181348
self.client = WeChatClient(
182349
self.config["appid"].strip(),
@@ -193,28 +360,33 @@ async def callback(msg: BaseMessage):
193360
try:
194361
if self.active_send_mode:
195362
await self.convert_message(msg, None)
363+
return None
364+
365+
msg_id = str(cast(str | int, msg.id))
366+
future = self.wexin_event_workers.get(msg_id)
367+
if future:
368+
logger.debug(f"duplicate message id checked: {msg.id}")
196369
else:
197-
if str(msg.id) in self.wexin_event_workers:
198-
future = self.wexin_event_workers[str(cast(str | int, msg.id))]
199-
logger.debug(f"duplicate message id checked: {msg.id}")
200-
else:
201-
future = asyncio.get_event_loop().create_future()
202-
self.wexin_event_workers[str(cast(str | int, msg.id))] = future
203-
await self.convert_message(msg, future)
370+
future = asyncio.get_event_loop().create_future()
371+
self.wexin_event_workers[msg_id] = future
372+
await self.convert_message(msg, future)
204373
# I love shield so much!
205374
result = await asyncio.wait_for(
206375
asyncio.shield(future),
207-
60,
208-
) # wait for 60s
209-
logger.debug(f"Got future result: {result}")
210-
self.wexin_event_workers.pop(str(cast(str | int, msg.id)), None)
211-
return result # xml. see weixin_offacc_event.py
376+
180,
377+
) # wait for 180s
378+
logger.debug(f"Got future result: {result}")
379+
return result
212380
except asyncio.TimeoutError:
213-
pass
381+
logger.info(f"callback 处理消息超时: message_id={msg.id}")
382+
return create_reply("处理消息超时,请稍后再试。", msg)
214383
except Exception as e:
215384
logger.error(f"转换消息时出现异常: {e}")
385+
finally:
386+
self.wexin_event_workers.pop(str(cast(str | int, msg.id)), None)
216387

217388
self.server.callback = callback
389+
self.server.active_send_mode = self.active_send_mode
218390

219391
@override
220392
async def send_by_session(
@@ -336,12 +508,19 @@ async def convert_message(
336508
await self.handle_msg(abm)
337509

338510
async def handle_msg(self, message: AstrBotMessage) -> None:
511+
buffer = self.user_buffer.get(message.sender.user_id, None)
512+
if buffer is None:
513+
logger.critical(
514+
f"用户消息未找到缓冲状态,无法处理消息: user={message.sender.user_id} message_id={message.message_id}"
515+
)
516+
return
339517
message_event = WeixinOfficialAccountPlatformEvent(
340518
message_str=message.message_str,
341519
message_obj=message,
342520
platform_meta=self.meta(),
343521
session_id=message.session_id,
344522
client=self.client,
523+
message_out=buffer,
345524
)
346525
self.commit_event(message_event)
347526

astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import asyncio
22
import os
3-
from typing import cast
3+
from typing import Any, cast
44

55
from wechatpy import WeChatClient
6-
from wechatpy.replies import ImageReply, TextReply, VoiceReply
6+
from wechatpy.replies import ImageReply, VoiceReply
77

88
from astrbot.api import logger
99
from astrbot.api.event import AstrMessageEvent, MessageChain
@@ -20,9 +20,11 @@ def __init__(
2020
platform_meta: PlatformMetadata,
2121
session_id: str,
2222
client: WeChatClient,
23+
message_out: dict[Any, Any],
2324
) -> None:
2425
super().__init__(message_str, message_obj, platform_meta, session_id)
2526
self.client = client
27+
self.message_out = message_out
2628

2729
@staticmethod
2830
async def send_with_client(
@@ -32,27 +34,27 @@ async def send_with_client(
3234
) -> None:
3335
pass
3436

35-
async def split_plain(self, plain: str) -> list[str]:
36-
"""将长文本分割成多个小文本, 每个小文本长度不超过 2048 字符
37+
async def split_plain(self, plain: str, max_length: int = 1024) -> list[str]:
38+
"""将长文本分割成多个小文本, 每个小文本长度不超过 max_length 字符
3739
3840
Args:
3941
plain (str): 要分割的长文本
4042
Returns:
4143
list[str]: 分割后的文本列表
4244
4345
"""
44-
if len(plain) <= 2048:
46+
if len(plain) <= max_length:
4547
return [plain]
4648
result = []
4749
start = 0
4850
while start < len(plain):
49-
# 剩下的字符串长度<2048时结束
50-
if start + 2048 >= len(plain):
51+
# 剩下的字符串长度<max_length时结束
52+
if start + max_length >= len(plain):
5153
result.append(plain[start:])
5254
break
5355

5456
# 向前搜索分割标点符号
55-
end = min(start + 2048, len(plain))
57+
end = min(start + max_length, len(plain))
5658
cut_position = end
5759
for i in range(end, start, -1):
5860
if i < len(plain) and plain[i - 1] in [
@@ -87,19 +89,15 @@ async def send(self, message: MessageChain) -> None:
8789
if isinstance(comp, Plain):
8890
# Split long text messages if needed
8991
plain_chunks = await self.split_plain(comp.text)
90-
for chunk in plain_chunks:
91-
if active_send_mode:
92+
if active_send_mode:
93+
for chunk in plain_chunks:
9294
self.client.message.send_text(message_obj.sender.user_id, chunk)
93-
else:
94-
reply = TextReply(
95-
content=chunk,
96-
message=cast(dict, self.message_obj.raw_message)["message"],
97-
)
98-
xml = reply.render()
99-
future = cast(dict, self.message_obj.raw_message)["future"]
100-
assert isinstance(future, asyncio.Future)
101-
future.set_result(xml)
102-
await asyncio.sleep(0.5) # Avoid sending too fast
95+
else:
96+
# disable passive sending, just store the chunks in
97+
logger.debug(
98+
f"split plain into {len(plain_chunks)} chunks for passive reply. Message not sent."
99+
)
100+
self.message_out["cached_xml"] = plain_chunks
103101
elif isinstance(comp, Image):
104102
img_path = await comp.convert_to_file_path()
105103

0 commit comments

Comments
 (0)