Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions astrbot/core/platform/sources/slack/session_codec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
THREAD_SESSION_MARKER = "__thread__"
LEGACY_GROUP_SESSION_PREFIX = "group_"
SLACK_DEFAULT_TEXT_FALLBACKS = {
"safe_text": "message",
"image": "[image]",
"file_template": "[file:{name}]",
"generic": "[message]",
"image_upload_failed": "Image upload failed",
"file_upload_failed": "File upload failed",
}
SLACK_SAFE_TEXT_FALLBACK = SLACK_DEFAULT_TEXT_FALLBACKS["safe_text"]


def build_slack_text_fallbacks(overrides: dict | None = None) -> dict[str, str]:
"""Build Slack text fallback rules.

Only keys defined in `SLACK_DEFAULT_TEXT_FALLBACKS` are honored; unknown
override keys are intentionally ignored.
"""
text_fallbacks = dict(SLACK_DEFAULT_TEXT_FALLBACKS)
if not isinstance(overrides, dict):
return text_fallbacks

for key in text_fallbacks:
candidate = overrides.get(key)
if isinstance(candidate, str) and candidate.strip():
text_fallbacks[key] = candidate
return text_fallbacks


def encode_thread_session_id(channel_id: str, thread_ts: str) -> str:
if not channel_id or not thread_ts:
return channel_id
return f"{channel_id}{THREAD_SESSION_MARKER}{thread_ts}"


def decode_slack_session_id(session_id: str) -> tuple[str, str | None]:
"""Decode a Slack session id into (channel_id, thread_ts|None)."""
if not session_id:
return "", None

if THREAD_SESSION_MARKER in session_id:
channel_id, thread_ts = session_id.split(THREAD_SESSION_MARKER, 1)
return channel_id, thread_ts or None

if session_id.startswith(LEGACY_GROUP_SESSION_PREFIX):
return session_id[len(LEGACY_GROUP_SESSION_PREFIX) :], None

return session_id, None


def resolve_target_from_event(
*,
session_id: str,
raw_message: dict,
group_id: str = "",
) -> tuple[str, str | None]:
"""Resolve target for received Slack events (uses event raw payload)."""
return resolve_slack_message_target(
session_id=session_id,
raw_message=raw_message,
group_id=group_id,
)


def resolve_target_from_session(
*,
session_id: str,
group_id: str = "",
fallback_channel_id: str = "",
) -> tuple[str, str | None]:
"""Resolve target when only session metadata is available (no raw event)."""
return resolve_slack_message_target(
session_id=session_id,
group_id=group_id,
sender_id=fallback_channel_id,
)


def resolve_slack_message_target(
session_id: str,
*,
raw_message: dict | None = None,
group_id: str = "",
sender_id: str = "",
) -> tuple[str, str | None]:
"""Backward-compatible resolver shared by legacy and new Slack call sites.

Precedence for channel: group_id > raw_message.channel > parsed(session_id) > sender_id
Precedence for thread: raw_message.thread_ts > parsed(session_id)
"""
parsed_channel_id, parsed_thread_ts = decode_slack_session_id(session_id)

raw_channel_id = ""
raw_thread_ts = None
if isinstance(raw_message, dict):
raw_channel = raw_message.get("channel")
if raw_channel not in (None, ""):
raw_channel_id = str(raw_channel)

raw_thread = raw_message.get("thread_ts")
if raw_thread not in (None, ""):
raw_thread_ts = str(raw_thread)

channel_id = group_id or raw_channel_id or parsed_channel_id or sender_id
thread_ts = raw_thread_ts or parsed_thread_ts
return channel_id, thread_ts
74 changes: 46 additions & 28 deletions astrbot/core/platform/sources/slack/slack_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,18 @@
Platform,
PlatformMetadata,
)
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.platform.astr_message_event import MessageSession
from astrbot.core.utils.webhook_utils import log_webhook_info

from ...register import register_platform_adapter
from .client import SlackSocketClient, SlackWebhookClient
from .session_codec import (
build_slack_text_fallbacks,
encode_thread_session_id,
resolve_target_from_session,
)
from .slack_event import SlackMessageEvent
from .slack_send_utils import send_with_blocks_and_fallback


@register_platform_adapter(
Expand Down Expand Up @@ -76,43 +82,48 @@ def __init__(
self.webhook_client = None

self.bot_self_id = None
# Canonical fallback configuration for Slack sends. Both adapter and event
# paths consume these via explicit arguments.
self.text_fallbacks = build_slack_text_fallbacks(
platform_config.get("text_fallbacks"),
)

async def send_by_session(
self,
session: MessageSesion,
session: MessageSession,
message_chain: MessageChain,
) -> None:
blocks, text = await SlackMessageEvent._parse_slack_blocks(
message_chain=message_chain,
channel_id, thread_ts = resolve_target_from_session(
session_id=session.session_id
)
await send_with_blocks_and_fallback(
web_client=self.web_client,
channel=channel_id,
thread_ts=thread_ts,
message_chain=message_chain,
fallbacks=self.text_fallbacks,
parse_blocks=SlackMessageEvent._parse_slack_blocks,
build_text_fallback=SlackMessageEvent._build_text_fallback_from_chain,
session_id=session.session_id,
)

try:
if session.message_type == MessageType.GROUP_MESSAGE:
# 发送到频道
channel_id = (
session.session_id.split("_")[-1]
if "_" in session.session_id
else session.session_id
)
await self.web_client.chat_postMessage(
channel=channel_id,
text=text,
blocks=blocks if blocks else None,
)
else:
# 发送私信
await self.web_client.chat_postMessage(
channel=session.session_id,
text=text,
blocks=blocks if blocks else None,
)
except Exception as e:
logger.error(f"Slack 发送消息失败: {e}")

await super().send_by_session(session, message_chain)

@staticmethod
def _unwrap_message_replied_event(event: dict) -> dict:
"""Flatten Slack message_replied envelopes for normal message processing."""
if event.get("subtype") == "message_replied":
nested_message = event.get("message")
if isinstance(nested_message, dict):
merged = dict(event)
merged.update(nested_message)
if not merged.get("channel"):
merged["channel"] = event.get("channel", "")
return merged
return event

async def convert_message(self, event: dict) -> AstrBotMessage:
event = self._unwrap_message_replied_event(event)
logger.debug(f"[slack] RawMessage {event}")

abm = AstrBotMessage()
Expand Down Expand Up @@ -146,7 +157,13 @@ async def convert_message(self, event: dict) -> AstrBotMessage:
abm.group_id = channel_id

# 设置会话ID
if abm.type == MessageType.GROUP_MESSAGE:
channel_id_for_session = str(channel_id or "")
# thread_ts may come from unwrapped `message_replied` payloads.
thread_ts = str(event.get("thread_ts", "") or "")
if thread_ts and channel_id_for_session:
# Slack threads can appear in channels and DMs.
abm.session_id = encode_thread_session_id(channel_id_for_session, thread_ts)
elif abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = abm.group_id
else:
abm.session_id = user_id
Expand Down Expand Up @@ -418,6 +435,7 @@ async def handle_msg(self, message: AstrBotMessage) -> None:
platform_meta=self.meta(),
session_id=message.session_id,
web_client=self.web_client,
text_fallbacks=self.text_fallbacks,
)

self.commit_event(message_event)
Expand Down
Loading