Skip to content
Merged
48 changes: 42 additions & 6 deletions slack_bolt/agent/agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Dict, List, Optional, Sequence, Union

from slack_sdk import WebClient
from slack_sdk.web import SlackResponse
Expand All @@ -11,9 +11,6 @@ class BoltAgent:
Experimental:
This API is experimental and may change in future releases.

FIXME: chat_stream() only works when thread_ts is available (DMs and threaded replies).
It does not work on channel messages because ts is not provided to BoltAgent yet.

@app.event("app_mention")
def handle_mention(agent):
stream = agent.chat_stream()
Expand All @@ -27,12 +24,14 @@ def __init__(
client: WebClient,
channel_id: Optional[str] = None,
thread_ts: Optional[str] = None,
ts: Optional[str] = None,
team_id: Optional[str] = None,
user_id: Optional[str] = None,
):
self._client = client
self._channel_id = channel_id
self._thread_ts = thread_ts
self._ts = ts
self._team_id = team_id
self._user_id = user_id

Expand Down Expand Up @@ -67,7 +66,7 @@ def chat_stream(
# Argument validation is delegated to chat_stream() and the API
return self._client.chat_stream(
channel=channel or self._channel_id, # type: ignore[arg-type]
thread_ts=thread_ts or self._thread_ts, # type: ignore[arg-type]
thread_ts=thread_ts or self._thread_ts or self._ts, # type: ignore[arg-type]
recipient_team_id=recipient_team_id or self._team_id,
recipient_user_id=recipient_user_id or self._user_id,
**kwargs,
Expand Down Expand Up @@ -96,8 +95,45 @@ def set_status(
"""
return self._client.assistant_threads_setStatus(
channel_id=channel or self._channel_id, # type: ignore[arg-type]
thread_ts=thread_ts or self._thread_ts, # type: ignore[arg-type]
thread_ts=thread_ts or self._thread_ts or self._ts, # type: ignore[arg-type]
status=status,
loading_messages=loading_messages,
**kwargs,
)

def set_suggested_prompts(
self,
*,
prompts: Sequence[Union[str, Dict[str, str]]],
title: Optional[str] = None,
channel: Optional[str] = None,
thread_ts: Optional[str] = None,
**kwargs,
) -> SlackResponse:
"""Sets suggested prompts for an assistant thread.

Args:
prompts: A sequence of prompts. Each prompt can be either a string
(used as both title and message) or a dict with 'title' and 'message' keys.
title: Optional title for the suggested prompts section.
channel: Channel ID. Defaults to the channel from the event context.
thread_ts: Thread timestamp. Defaults to the thread_ts from the event context.
**kwargs: Additional arguments passed to ``WebClient.assistant_threads_setSuggestedPrompts()``.

Returns:
``SlackResponse`` from the API call.
"""
prompts_arg: List[Dict[str, str]] = []
for prompt in prompts:
if isinstance(prompt, str):
prompts_arg.append({"title": prompt, "message": prompt})
else:
prompts_arg.append(prompt)

return self._client.assistant_threads_setSuggestedPrompts(
channel_id=channel or self._channel_id, # type: ignore[arg-type]
thread_ts=thread_ts or self._thread_ts or self._ts, # type: ignore[arg-type]
prompts=prompts_arg,
title=title,
**kwargs,
)
45 changes: 42 additions & 3 deletions slack_bolt/agent/async_agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Dict, List, Optional, Sequence, Union

from slack_sdk.web.async_client import AsyncSlackResponse, AsyncWebClient
from slack_sdk.web.async_chat_stream import AsyncChatStream
Expand All @@ -23,12 +23,14 @@ def __init__(
client: AsyncWebClient,
channel_id: Optional[str] = None,
thread_ts: Optional[str] = None,
ts: Optional[str] = None,
team_id: Optional[str] = None,
user_id: Optional[str] = None,
):
self._client = client
self._channel_id = channel_id
self._thread_ts = thread_ts
self._ts = ts
self._team_id = team_id
self._user_id = user_id

Expand Down Expand Up @@ -63,7 +65,7 @@ async def chat_stream(
# Argument validation is delegated to chat_stream() and the API
return await self._client.chat_stream(
channel=channel or self._channel_id, # type: ignore[arg-type]
thread_ts=thread_ts or self._thread_ts, # type: ignore[arg-type]
thread_ts=thread_ts or self._thread_ts or self._ts, # type: ignore[arg-type]
recipient_team_id=recipient_team_id or self._team_id,
recipient_user_id=recipient_user_id or self._user_id,
**kwargs,
Expand Down Expand Up @@ -92,8 +94,45 @@ async def set_status(
"""
return await self._client.assistant_threads_setStatus(
channel_id=channel or self._channel_id, # type: ignore[arg-type]
thread_ts=thread_ts or self._thread_ts, # type: ignore[arg-type]
thread_ts=thread_ts or self._thread_ts or self._ts, # type: ignore[arg-type]
status=status,
loading_messages=loading_messages,
**kwargs,
)

async def set_suggested_prompts(
self,
*,
prompts: Sequence[Union[str, Dict[str, str]]],
title: Optional[str] = None,
channel: Optional[str] = None,
thread_ts: Optional[str] = None,
**kwargs,
) -> AsyncSlackResponse:
"""Sets suggested prompts for an assistant thread.

Args:
prompts: A sequence of prompts. Each prompt can be either a string
(used as both title and message) or a dict with 'title' and 'message' keys.
title: Optional title for the suggested prompts section.
channel: Channel ID. Defaults to the channel from the event context.
thread_ts: Thread timestamp. Defaults to the thread_ts from the event context.
**kwargs: Additional arguments passed to ``AsyncWebClient.assistant_threads_setSuggestedPrompts()``.

Returns:
``AsyncSlackResponse`` from the API call.
"""
prompts_arg: List[Dict[str, str]] = []
for prompt in prompts:
if isinstance(prompt, str):
prompts_arg.append({"title": prompt, "message": prompt})
else:
prompts_arg.append(prompt)

return await self._client.assistant_threads_setSuggestedPrompts(
channel_id=channel or self._channel_id, # type: ignore[arg-type]
thread_ts=thread_ts or self._thread_ts or self._ts, # type: ignore[arg-type]
prompts=prompts_arg,
title=title,
**kwargs,
)
5 changes: 4 additions & 1 deletion slack_bolt/kwargs_injection/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,13 @@ def build_async_required_kwargs(
if "agent" in required_arg_names:
from slack_bolt.agent.async_agent import AsyncBoltAgent

event = request.body.get("event", {})

all_available_args["agent"] = AsyncBoltAgent(
client=request.context.client,
channel_id=request.context.channel_id,
thread_ts=request.context.thread_ts,
thread_ts=request.context.thread_ts or event.get("thread_ts"),
ts=event.get("ts"),
Comment on lines +97 to +98
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

praise: such a clean implementation! 🎉

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mwbrooks Thanks! I'm still keeping a note that we might want to add ts to context but for now I think this is an alright approach 👾

team_id=request.context.team_id,
user_id=request.context.user_id,
)
Expand Down
5 changes: 4 additions & 1 deletion slack_bolt/kwargs_injection/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,13 @@ def build_required_kwargs(
if "agent" in required_arg_names:
from slack_bolt.agent.agent import BoltAgent

event = request.body.get("event", {})

all_available_args["agent"] = BoltAgent(
client=request.context.client,
channel_id=request.context.channel_id,
thread_ts=request.context.thread_ts,
thread_ts=request.context.thread_ts or event.get("thread_ts"),
ts=event.get("ts"),
team_id=request.context.team_id,
user_id=request.context.user_id,
)
Expand Down
3 changes: 3 additions & 0 deletions slack_bolt/request/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,9 @@ def extract_thread_ts(payload: Dict[str, Any]) -> Optional[str]:
# This utility initially supports only the use cases for AI assistants, but it may be fine to add more patterns.
# That said, note that thread_ts is always required for assistant threads, but it's not for channels.
# Thus, blindly setting this thread_ts to say utility can break existing apps' behaviors.
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👁️‍🗨️ thought: I'm surprised that say posts top-level messages in response to threaded messages by default TBH!

I agree that a "fix" for this, to respond in thread if a thread_ts is present, might cause new behavior for apps but am wondering if this is intended behavior or something to ponder changing in the future?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmmm good question! id love to hear what changing it in the future would look like 🤔 to me it makes sense for apps to respond in thread

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@srtaalej I'm most interested in the context.thread_ts containing the thread_ts from all events and not just assistant related ones 🤖

We might find the say helper to make use of that or not - I wonder if we can make this a non-breaking change - but we should document changes in either case!

#
# The BoltAgent class handles non-assistant thread_ts separately by reading from the event directly,
# allowing it to work correctly without affecting say() behavior.
if is_assistant_event(payload):
event = payload["event"]
if (
Expand Down
157 changes: 157 additions & 0 deletions tests/slack_bolt/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,51 @@ def test_chat_stream_passes_extra_kwargs(self):
buffer_size=512,
)

def test_chat_stream_falls_back_to_ts(self):
"""When thread_ts is not set, chat_stream() falls back to ts."""
client = MagicMock(spec=WebClient)
client.chat_stream.return_value = MagicMock(spec=ChatStream)

agent = BoltAgent(
client=client,
channel_id="C111",
team_id="T111",
ts="1111111111.111111",
user_id="W222",
)
stream = agent.chat_stream()

client.chat_stream.assert_called_once_with(
channel="C111",
thread_ts="1111111111.111111",
recipient_team_id="T111",
recipient_user_id="W222",
)
assert stream is not None

def test_chat_stream_prefers_thread_ts_over_ts(self):
"""thread_ts takes priority over ts."""
client = MagicMock(spec=WebClient)
client.chat_stream.return_value = MagicMock(spec=ChatStream)

agent = BoltAgent(
client=client,
channel_id="C111",
team_id="T111",
thread_ts="1234567890.123456",
ts="1111111111.111111",
user_id="W222",
)
stream = agent.chat_stream()

client.chat_stream.assert_called_once_with(
channel="C111",
thread_ts="1234567890.123456",
recipient_team_id="T111",
recipient_user_id="W222",
)
assert stream is not None

def test_set_status_uses_context_defaults(self):
"""BoltAgent.set_status() passes context defaults to WebClient.assistant_threads_setStatus()."""
client = MagicMock(spec=WebClient)
Expand Down Expand Up @@ -197,6 +242,118 @@ def test_set_status_requires_status(self):
with pytest.raises(TypeError):
agent.set_status()

def test_set_suggested_prompts_uses_context_defaults(self):
"""BoltAgent.set_suggested_prompts() passes context defaults to WebClient.assistant_threads_setSuggestedPrompts()."""
client = MagicMock(spec=WebClient)
client.assistant_threads_setSuggestedPrompts.return_value = MagicMock()

agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
agent.set_suggested_prompts(prompts=["What can you do?", "Help me write code"])

client.assistant_threads_setSuggestedPrompts.assert_called_once_with(
channel_id="C111",
thread_ts="1234567890.123456",
prompts=[
{"title": "What can you do?", "message": "What can you do?"},
{"title": "Help me write code", "message": "Help me write code"},
],
title=None,
)

def test_set_suggested_prompts_with_dict_prompts(self):
"""BoltAgent.set_suggested_prompts() accepts dict prompts with title and message."""
client = MagicMock(spec=WebClient)
client.assistant_threads_setSuggestedPrompts.return_value = MagicMock()

agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
agent.set_suggested_prompts(
prompts=[
{"title": "Short title", "message": "A much longer message for this prompt"},
],
title="Suggestions",
)

client.assistant_threads_setSuggestedPrompts.assert_called_once_with(
channel_id="C111",
thread_ts="1234567890.123456",
prompts=[
{"title": "Short title", "message": "A much longer message for this prompt"},
],
title="Suggestions",
)

def test_set_suggested_prompts_overrides_context_defaults(self):
"""Explicit channel/thread_ts override context defaults."""
client = MagicMock(spec=WebClient)
client.assistant_threads_setSuggestedPrompts.return_value = MagicMock()

agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
agent.set_suggested_prompts(
prompts=["Hello"],
channel="C999",
thread_ts="9999999999.999999",
)

client.assistant_threads_setSuggestedPrompts.assert_called_once_with(
channel_id="C999",
thread_ts="9999999999.999999",
prompts=[{"title": "Hello", "message": "Hello"}],
title=None,
)

def test_set_suggested_prompts_passes_extra_kwargs(self):
"""Extra kwargs are forwarded to WebClient.assistant_threads_setSuggestedPrompts()."""
client = MagicMock(spec=WebClient)
client.assistant_threads_setSuggestedPrompts.return_value = MagicMock()

agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
agent.set_suggested_prompts(prompts=["Hello"], token="xoxb-override")

client.assistant_threads_setSuggestedPrompts.assert_called_once_with(
channel_id="C111",
thread_ts="1234567890.123456",
prompts=[{"title": "Hello", "message": "Hello"}],
title=None,
token="xoxb-override",
)

def test_set_suggested_prompts_requires_prompts(self):
"""set_suggested_prompts() raises TypeError when prompts is not provided."""
client = MagicMock(spec=WebClient)
agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
with pytest.raises(TypeError):
agent.set_suggested_prompts()

def test_import_from_slack_bolt(self):
from slack_bolt import BoltAgent as ImportedBoltAgent

Expand Down
Loading