Skip to content

Commit 63ff234

Browse files
authored
feat: implement websockets transport mode selection for chat (#5410)
* feat: implement websockets transport mode selection for chat - Added transport mode selection (SSE/WebSocket) in the chat component. - Updated conversation sidebar to include transport mode options. - Integrated transport mode handling in message sending logic. - Refactored message sending functions to support both SSE and WebSocket. - Enhanced WebSocket connection management and message handling. - Updated localization files for transport mode labels. - Configured Vite to support WebSocket proxying. * feat(webchat): refactor message parsing logic and integrate new parsing function * feat(chat): add websocket API key extraction and scope validation
1 parent 5219ba5 commit 63ff234

File tree

14 files changed

+2264
-554
lines changed

14 files changed

+2264
-554
lines changed

astrbot/core/platform/sources/webchat/message_parts_helper.py

Lines changed: 465 additions & 0 deletions
Large diffs are not rendered by default.

astrbot/core/platform/sources/webchat/webchat_adapter.py

Lines changed: 99 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
import time
44
import uuid
55
from collections.abc import Callable, Coroutine
6+
from pathlib import Path
67
from typing import Any
78

89
from astrbot import logger
910
from astrbot.core import db_helper
1011
from astrbot.core.db.po import PlatformMessageHistory
11-
from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video
1212
from astrbot.core.message.message_event_result import MessageChain
1313
from astrbot.core.platform import (
1414
AstrBotMessage,
@@ -21,10 +21,23 @@
2121
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
2222

2323
from ...register import register_platform_adapter
24+
from .message_parts_helper import (
25+
message_chain_to_storage_message_parts,
26+
parse_webchat_message_parts,
27+
)
2428
from .webchat_event import WebChatMessageEvent
2529
from .webchat_queue_mgr import WebChatQueueMgr, webchat_queue_mgr
2630

2731

32+
def _extract_conversation_id(session_id: str) -> str:
33+
"""Extract raw webchat conversation id from event/session id."""
34+
if session_id.startswith("webchat!"):
35+
parts = session_id.split("!", 2)
36+
if len(parts) == 3:
37+
return parts[2]
38+
return session_id
39+
40+
2841
class QueueListener:
2942
def __init__(
3043
self,
@@ -57,13 +70,15 @@ def __init__(
5770

5871
self.settings = platform_settings
5972
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
73+
self.attachments_dir = Path(get_astrbot_data_path()) / "attachments"
6074
os.makedirs(self.imgs_dir, exist_ok=True)
75+
self.attachments_dir.mkdir(parents=True, exist_ok=True)
6176

6277
self.metadata = PlatformMetadata(
6378
name="webchat",
6479
description="webchat",
6580
id="webchat",
66-
support_proactive_message=False,
81+
support_proactive_message=True,
6782
)
6883
self._shutdown_event = asyncio.Event()
6984
self._webchat_queue_mgr = webchat_queue_mgr
@@ -73,10 +88,67 @@ async def send_by_session(
7388
session: MessageSesion,
7489
message_chain: MessageChain,
7590
) -> None:
76-
message_id = f"active_{str(uuid.uuid4())}"
77-
await WebChatMessageEvent._send(message_id, message_chain, session.session_id)
91+
conversation_id = _extract_conversation_id(session.session_id)
92+
active_request_ids = self._webchat_queue_mgr.list_back_request_ids(
93+
conversation_id
94+
)
95+
subscription_request_ids = [
96+
req_id for req_id in active_request_ids if req_id.startswith("ws_sub_")
97+
]
98+
target_request_ids = subscription_request_ids or active_request_ids
99+
100+
if target_request_ids:
101+
for request_id in target_request_ids:
102+
await WebChatMessageEvent._send(
103+
request_id,
104+
message_chain,
105+
session.session_id,
106+
)
107+
else:
108+
message_id = f"active_{uuid.uuid4()!s}"
109+
await WebChatMessageEvent._send(
110+
message_id,
111+
message_chain,
112+
session.session_id,
113+
)
114+
115+
should_persist = (
116+
bool(subscription_request_ids)
117+
or not active_request_ids
118+
or all(req_id.startswith("active_") for req_id in active_request_ids)
119+
)
120+
if should_persist:
121+
try:
122+
await self._save_proactive_message(conversation_id, message_chain)
123+
except Exception as e:
124+
logger.error(
125+
f"[WebChatAdapter] Failed to save proactive message: {e}",
126+
exc_info=True,
127+
)
128+
78129
await super().send_by_session(session, message_chain)
79130

131+
async def _save_proactive_message(
132+
self,
133+
conversation_id: str,
134+
message_chain: MessageChain,
135+
) -> None:
136+
message_parts = await message_chain_to_storage_message_parts(
137+
message_chain,
138+
insert_attachment=db_helper.insert_attachment,
139+
attachments_dir=self.attachments_dir,
140+
)
141+
if not message_parts:
142+
return
143+
144+
await db_helper.insert_platform_message_history(
145+
platform_id="webchat",
146+
user_id=conversation_id,
147+
content={"type": "bot", "message": message_parts},
148+
sender_id="bot",
149+
sender_name="bot",
150+
)
151+
80152
async def _get_message_history(
81153
self, message_id: int
82154
) -> PlatformMessageHistory | None:
@@ -98,72 +170,30 @@ async def _parse_message_parts(
98170
Returns:
99171
tuple[list, list[str]]: (消息组件列表, 纯文本列表)
100172
"""
101-
components = []
102-
text_parts = []
103-
104-
for part in message_parts:
105-
part_type = part.get("type")
106-
if part_type == "plain":
107-
text = part.get("text", "")
108-
components.append(Plain(text=text))
109-
text_parts.append(text)
110-
elif part_type == "reply":
111-
message_id = part.get("message_id")
112-
reply_chain = []
113-
reply_message_str = part.get("selected_text", "")
114-
sender_id = None
115-
sender_name = None
116-
117-
if reply_message_str:
118-
reply_chain = [Plain(text=reply_message_str)]
119-
120-
# recursively get the content of the referenced message, if selected_text is empty
121-
if not reply_message_str and depth < max_depth and message_id:
122-
history = await self._get_message_history(message_id)
123-
if history and history.content:
124-
reply_parts = history.content.get("message", [])
125-
if isinstance(reply_parts, list):
126-
(
127-
reply_chain,
128-
reply_text_parts,
129-
) = await self._parse_message_parts(
130-
reply_parts,
131-
depth=depth + 1,
132-
max_depth=max_depth,
133-
)
134-
reply_message_str = "".join(reply_text_parts)
135-
sender_id = history.sender_id
136-
sender_name = history.sender_name
137-
138-
components.append(
139-
Reply(
140-
id=message_id,
141-
chain=reply_chain,
142-
message_str=reply_message_str,
143-
sender_id=sender_id,
144-
sender_nickname=sender_name,
145-
)
146-
)
147-
elif part_type == "image":
148-
path = part.get("path")
149-
if path:
150-
components.append(Image.fromFileSystem(path))
151-
elif part_type == "record":
152-
path = part.get("path")
153-
if path:
154-
components.append(Record.fromFileSystem(path))
155-
elif part_type == "file":
156-
path = part.get("path")
157-
if path:
158-
filename = part.get("filename") or (
159-
os.path.basename(path) if path else "file"
160-
)
161-
components.append(File(name=filename, file=path))
162-
elif part_type == "video":
163-
path = part.get("path")
164-
if path:
165-
components.append(Video.fromFileSystem(path))
166173

174+
async def get_reply_parts(
175+
message_id: Any,
176+
) -> tuple[list[dict], str | None, str | None] | None:
177+
history = await self._get_message_history(message_id)
178+
if not history or not history.content:
179+
return None
180+
181+
reply_parts = history.content.get("message", [])
182+
if not isinstance(reply_parts, list):
183+
return None
184+
185+
return reply_parts, history.sender_id, history.sender_name
186+
187+
components, text_parts, _ = await parse_webchat_message_parts(
188+
message_parts,
189+
strict=False,
190+
include_empty_plain=True,
191+
verify_media_path_exists=False,
192+
reply_history_getter=get_reply_parts,
193+
current_depth=depth,
194+
max_reply_depth=max_depth,
195+
cast_reply_id_to_str=False,
196+
)
167197
return components, text_parts
168198

169199
async def convert_message(self, data: tuple) -> AstrBotMessage:

astrbot/core/platform/sources/webchat/webchat_event.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@
1414
attachments_dir = os.path.join(get_astrbot_data_path(), "attachments")
1515

1616

17+
def _extract_conversation_id(session_id: str) -> str:
18+
"""Extract raw webchat conversation id from event/session id."""
19+
if session_id.startswith("webchat!"):
20+
parts = session_id.split("!", 2)
21+
if len(parts) == 3:
22+
return parts[2]
23+
return session_id
24+
25+
1726
class WebChatMessageEvent(AstrMessageEvent):
1827
def __init__(self, message_str, message_obj, platform_meta, session_id) -> None:
1928
super().__init__(message_str, message_obj, platform_meta, session_id)
@@ -27,7 +36,7 @@ async def _send(
2736
streaming: bool = False,
2837
) -> str | None:
2938
request_id = str(message_id)
30-
conversation_id = session_id.split("!")[-1]
39+
conversation_id = _extract_conversation_id(session_id)
3140
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(
3241
request_id,
3342
conversation_id,
@@ -130,7 +139,7 @@ async def send_streaming(self, generator, use_fallback: bool = False) -> None:
130139
reasoning_content = ""
131140
message_id = self.message_obj.message_id
132141
request_id = str(message_id)
133-
conversation_id = self.session_id.split("!")[-1]
142+
conversation_id = _extract_conversation_id(self.session_id)
134143
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(
135144
request_id,
136145
conversation_id,

astrbot/core/platform/sources/webchat/webchat_queue_mgr.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ def remove_queue(self, conversation_id: str):
7575
if task is not None:
7676
task.cancel()
7777

78+
def list_back_request_ids(self, conversation_id: str) -> list[str]:
79+
"""List active back-queue request IDs for a conversation."""
80+
return list(self._conversation_back_requests.get(conversation_id, set()))
81+
7882
def has_queue(self, conversation_id: str) -> bool:
7983
"""Check if a queue exists for the given conversation ID"""
8084
return conversation_id in self.queues

0 commit comments

Comments
 (0)