Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 68 additions & 1 deletion slack_bolt/agent/agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional
from typing import Dict, List, Optional, Sequence, Union

from slack_sdk import WebClient
from slack_sdk.web import SlackResponse
from slack_sdk.web.chat_stream import ChatStream


Expand Down Expand Up @@ -71,3 +72,69 @@ def chat_stream(
recipient_user_id=recipient_user_id or self._user_id,
**kwargs,
)

def set_status(
self,
*,
status: str,
loading_messages: Optional[List[str]] = None,
channel: Optional[str] = None,
thread_ts: Optional[str] = None,
**kwargs,
) -> SlackResponse:
"""Sets the status of an assistant thread.

Args:
status: The status text to display.
loading_messages: Optional list of loading messages to cycle through.
channel: Channel ID. Defaults to the channel from the event context.
thread_ts: Thread timestamp. Defaults to the thread_ts from the event context.
**kwargs: Additional arguments passed to ``WebClient.assistant_threads_setStatus()``.

Returns:
``SlackResponse`` from the API call.
"""
return self._client.assistant_threads_setStatus(
channel_id=channel or self._channel_id, # type: ignore[arg-type]
thread_ts=thread_ts or self._thread_ts, # type: ignore[arg-type]
status=status,
loading_messages=loading_messages,
**kwargs,
)

def set_suggested_prompts(
self,
*,
prompts: Sequence[Union[str, Dict[str, str]]],
title: Optional[str] = None,
channel: Optional[str] = None,
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📝 note: I'm noticing this is different than the expected API value!

🔗 https://docs.slack.dev/reference/methods/assistant.threads.setSuggestedPrompts/

thread_ts: Optional[str] = None,
**kwargs,
) -> SlackResponse:
"""Sets suggested prompts for an assistant thread.

Args:
prompts: A sequence of prompts. Each prompt can be either a string
(used as both title and message) or a dict with 'title' and 'message' keys.
title: Optional title for the suggested prompts section.
channel: Channel ID. Defaults to the channel from the event context.
thread_ts: Thread timestamp. Defaults to the thread_ts from the event context.
**kwargs: Additional arguments passed to ``WebClient.assistant_threads_setSuggestedPrompts()``.

Returns:
``SlackResponse`` from the API call.
"""
prompts_arg: List[Dict[str, str]] = []
for prompt in prompts:
if isinstance(prompt, str):
prompts_arg.append({"title": prompt, "message": prompt})
else:
prompts_arg.append(prompt)

return self._client.assistant_threads_setSuggestedPrompts(
channel_id=channel or self._channel_id, # type: ignore[arg-type]
thread_ts=thread_ts or self._thread_ts, # type: ignore[arg-type]
prompts=prompts_arg,
title=title,
**kwargs,
)
69 changes: 68 additions & 1 deletion slack_bolt/agent/async_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional
from typing import Dict, List, Optional, Sequence, Union

from slack_sdk.web import SlackResponse
from slack_sdk.web.async_client import AsyncWebClient
from slack_sdk.web.async_chat_stream import AsyncChatStream

Expand Down Expand Up @@ -68,3 +69,69 @@ async def chat_stream(
recipient_user_id=recipient_user_id or self._user_id,
**kwargs,
)

async def set_status(
self,
*,
status: str,
loading_messages: Optional[List[str]] = None,
channel: Optional[str] = None,
thread_ts: Optional[str] = None,
**kwargs,
) -> SlackResponse:
"""Sets the status of an assistant thread.

Args:
status: The status text to display.
loading_messages: Optional list of loading messages to cycle through.
channel: Channel ID. Defaults to the channel from the event context.
thread_ts: Thread timestamp. Defaults to the thread_ts from the event context.
**kwargs: Additional arguments passed to ``AsyncWebClient.assistant_threads_setStatus()``.

Returns:
``SlackResponse`` from the API call.
"""
return await self._client.assistant_threads_setStatus(
channel_id=channel or self._channel_id, # type: ignore[arg-type]
thread_ts=thread_ts or self._thread_ts, # type: ignore[arg-type]
status=status,
loading_messages=loading_messages,
**kwargs,
)

async def set_suggested_prompts(
self,
*,
prompts: Sequence[Union[str, Dict[str, str]]],
title: Optional[str] = None,
channel: Optional[str] = None,
thread_ts: Optional[str] = None,
**kwargs,
) -> SlackResponse:
"""Sets suggested prompts for an assistant thread.

Args:
prompts: A sequence of prompts. Each prompt can be either a string
(used as both title and message) or a dict with 'title' and 'message' keys.
title: Optional title for the suggested prompts section.
channel: Channel ID. Defaults to the channel from the event context.
thread_ts: Thread timestamp. Defaults to the thread_ts from the event context.
**kwargs: Additional arguments passed to ``AsyncWebClient.assistant_threads_setSuggestedPrompts()``.

Returns:
``SlackResponse`` from the API call.
"""
prompts_arg: List[Dict[str, str]] = []
for prompt in prompts:
if isinstance(prompt, str):
prompts_arg.append({"title": prompt, "message": prompt})
else:
prompts_arg.append(prompt)

return await self._client.assistant_threads_setSuggestedPrompts(
channel_id=channel or self._channel_id, # type: ignore[arg-type]
thread_ts=thread_ts or self._thread_ts, # type: ignore[arg-type]
prompts=prompts_arg,
title=title,
**kwargs,
)
217 changes: 217 additions & 0 deletions tests/slack_bolt/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,223 @@ def test_chat_stream_passes_extra_kwargs(self):
buffer_size=512,
)

def test_set_status_uses_context_defaults(self):
"""BoltAgent.set_status() passes context defaults to WebClient.assistant_threads_setStatus()."""
client = MagicMock(spec=WebClient)
client.assistant_threads_setStatus.return_value = MagicMock()

agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
agent.set_status(status="Thinking...")

client.assistant_threads_setStatus.assert_called_once_with(
channel_id="C111",
thread_ts="1234567890.123456",
status="Thinking...",
loading_messages=None,
)

def test_set_status_with_loading_messages(self):
"""BoltAgent.set_status() forwards loading_messages."""
client = MagicMock(spec=WebClient)
client.assistant_threads_setStatus.return_value = MagicMock()

agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
agent.set_status(
status="Thinking...",
loading_messages=["Sitting...", "Waiting..."],
)

client.assistant_threads_setStatus.assert_called_once_with(
channel_id="C111",
thread_ts="1234567890.123456",
status="Thinking...",
loading_messages=["Sitting...", "Waiting..."],
)

def test_set_status_overrides_context_defaults(self):
"""Explicit channel/thread_ts override context defaults."""
client = MagicMock(spec=WebClient)
client.assistant_threads_setStatus.return_value = MagicMock()

agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
agent.set_status(
status="Thinking...",
channel="C999",
thread_ts="9999999999.999999",
)

client.assistant_threads_setStatus.assert_called_once_with(
channel_id="C999",
thread_ts="9999999999.999999",
status="Thinking...",
loading_messages=None,
)

def test_set_status_passes_extra_kwargs(self):
"""Extra kwargs are forwarded to WebClient.assistant_threads_setStatus()."""
client = MagicMock(spec=WebClient)
client.assistant_threads_setStatus.return_value = MagicMock()

agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
agent.set_status(status="Thinking...", token="xoxb-override")

client.assistant_threads_setStatus.assert_called_once_with(
channel_id="C111",
thread_ts="1234567890.123456",
status="Thinking...",
loading_messages=None,
token="xoxb-override",
)

def test_set_status_requires_status(self):
"""set_status() raises TypeError when status is not provided."""
client = MagicMock(spec=WebClient)
agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
with pytest.raises(TypeError):
agent.set_status()

def test_set_suggested_prompts_uses_context_defaults(self):
"""BoltAgent.set_suggested_prompts() passes context defaults to WebClient.assistant_threads_setSuggestedPrompts()."""
client = MagicMock(spec=WebClient)
client.assistant_threads_setSuggestedPrompts.return_value = MagicMock()

agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
agent.set_suggested_prompts(prompts=["What can you do?", "Help me write code"])

client.assistant_threads_setSuggestedPrompts.assert_called_once_with(
channel_id="C111",
thread_ts="1234567890.123456",
prompts=[
{"title": "What can you do?", "message": "What can you do?"},
{"title": "Help me write code", "message": "Help me write code"},
],
title=None,
)

def test_set_suggested_prompts_with_dict_prompts(self):
"""BoltAgent.set_suggested_prompts() accepts dict prompts with title and message."""
client = MagicMock(spec=WebClient)
client.assistant_threads_setSuggestedPrompts.return_value = MagicMock()

agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
agent.set_suggested_prompts(
prompts=[
{"title": "Short title", "message": "A much longer message for this prompt"},
],
title="Suggestions",
)

client.assistant_threads_setSuggestedPrompts.assert_called_once_with(
channel_id="C111",
thread_ts="1234567890.123456",
prompts=[
{"title": "Short title", "message": "A much longer message for this prompt"},
],
title="Suggestions",
)

def test_set_suggested_prompts_overrides_context_defaults(self):
"""Explicit channel/thread_ts override context defaults."""
client = MagicMock(spec=WebClient)
client.assistant_threads_setSuggestedPrompts.return_value = MagicMock()

agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
agent.set_suggested_prompts(
prompts=["Hello"],
channel="C999",
thread_ts="9999999999.999999",
)

client.assistant_threads_setSuggestedPrompts.assert_called_once_with(
channel_id="C999",
thread_ts="9999999999.999999",
prompts=[{"title": "Hello", "message": "Hello"}],
title=None,
)

def test_set_suggested_prompts_passes_extra_kwargs(self):
"""Extra kwargs are forwarded to WebClient.assistant_threads_setSuggestedPrompts()."""
client = MagicMock(spec=WebClient)
client.assistant_threads_setSuggestedPrompts.return_value = MagicMock()

agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
agent.set_suggested_prompts(prompts=["Hello"], token="xoxb-override")

client.assistant_threads_setSuggestedPrompts.assert_called_once_with(
channel_id="C111",
thread_ts="1234567890.123456",
prompts=[{"title": "Hello", "message": "Hello"}],
title=None,
token="xoxb-override",
)

def test_set_suggested_prompts_requires_prompts(self):
"""set_suggested_prompts() raises TypeError when prompts is not provided."""
client = MagicMock(spec=WebClient)
agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
with pytest.raises(TypeError):
agent.set_suggested_prompts()

def test_import_from_slack_bolt(self):
from slack_bolt import BoltAgent as ImportedBoltAgent

Expand Down
Loading
Loading