Skip to content

Commit 92bff60

Browse files
zimegsrtaalej
andauthored
feat(agent): add set_suggested_prompts helper (#1442)
Co-authored-by: Ale Mercado <maria.mercado@slack-corp.com>
1 parent 5cb6182 commit 92bff60

File tree

4 files changed

+305
-2
lines changed

4 files changed

+305
-2
lines changed

slack_bolt/agent/agent.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional
1+
from typing import Dict, List, Optional, Sequence, Union
22

33
from slack_sdk import WebClient
44
from slack_sdk.web import SlackResponse
@@ -101,3 +101,40 @@ def set_status(
101101
loading_messages=loading_messages,
102102
**kwargs,
103103
)
104+
105+
def set_suggested_prompts(
106+
self,
107+
*,
108+
prompts: Sequence[Union[str, Dict[str, str]]],
109+
title: Optional[str] = None,
110+
channel: Optional[str] = None,
111+
thread_ts: Optional[str] = None,
112+
**kwargs,
113+
) -> SlackResponse:
114+
"""Sets suggested prompts for an assistant thread.
115+
116+
Args:
117+
prompts: A sequence of prompts. Each prompt can be either a string
118+
(used as both title and message) or a dict with 'title' and 'message' keys.
119+
title: Optional title for the suggested prompts section.
120+
channel: Channel ID. Defaults to the channel from the event context.
121+
thread_ts: Thread timestamp. Defaults to the thread_ts from the event context.
122+
**kwargs: Additional arguments passed to ``WebClient.assistant_threads_setSuggestedPrompts()``.
123+
124+
Returns:
125+
``SlackResponse`` from the API call.
126+
"""
127+
prompts_arg: List[Dict[str, str]] = []
128+
for prompt in prompts:
129+
if isinstance(prompt, str):
130+
prompts_arg.append({"title": prompt, "message": prompt})
131+
else:
132+
prompts_arg.append(prompt)
133+
134+
return self._client.assistant_threads_setSuggestedPrompts(
135+
channel_id=channel or self._channel_id, # type: ignore[arg-type]
136+
thread_ts=thread_ts or self._thread_ts, # type: ignore[arg-type]
137+
prompts=prompts_arg,
138+
title=title,
139+
**kwargs,
140+
)

slack_bolt/agent/async_agent.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional
1+
from typing import Dict, List, Optional, Sequence, Union
22

33
from slack_sdk.web.async_client import AsyncSlackResponse, AsyncWebClient
44
from slack_sdk.web.async_chat_stream import AsyncChatStream
@@ -97,3 +97,40 @@ async def set_status(
9797
loading_messages=loading_messages,
9898
**kwargs,
9999
)
100+
101+
async def set_suggested_prompts(
102+
self,
103+
*,
104+
prompts: Sequence[Union[str, Dict[str, str]]],
105+
title: Optional[str] = None,
106+
channel: Optional[str] = None,
107+
thread_ts: Optional[str] = None,
108+
**kwargs,
109+
) -> AsyncSlackResponse:
110+
"""Sets suggested prompts for an assistant thread.
111+
112+
Args:
113+
prompts: A sequence of prompts. Each prompt can be either a string
114+
(used as both title and message) or a dict with 'title' and 'message' keys.
115+
title: Optional title for the suggested prompts section.
116+
channel: Channel ID. Defaults to the channel from the event context.
117+
thread_ts: Thread timestamp. Defaults to the thread_ts from the event context.
118+
**kwargs: Additional arguments passed to ``AsyncWebClient.assistant_threads_setSuggestedPrompts()``.
119+
120+
Returns:
121+
``AsyncSlackResponse`` from the API call.
122+
"""
123+
prompts_arg: List[Dict[str, str]] = []
124+
for prompt in prompts:
125+
if isinstance(prompt, str):
126+
prompts_arg.append({"title": prompt, "message": prompt})
127+
else:
128+
prompts_arg.append(prompt)
129+
130+
return await self._client.assistant_threads_setSuggestedPrompts(
131+
channel_id=channel or self._channel_id, # type: ignore[arg-type]
132+
thread_ts=thread_ts or self._thread_ts, # type: ignore[arg-type]
133+
prompts=prompts_arg,
134+
title=title,
135+
**kwargs,
136+
)

tests/slack_bolt/agent/test_agent.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,118 @@ def test_set_status_requires_status(self):
197197
with pytest.raises(TypeError):
198198
agent.set_status()
199199

200+
def test_set_suggested_prompts_uses_context_defaults(self):
201+
"""BoltAgent.set_suggested_prompts() passes context defaults to WebClient.assistant_threads_setSuggestedPrompts()."""
202+
client = MagicMock(spec=WebClient)
203+
client.assistant_threads_setSuggestedPrompts.return_value = MagicMock()
204+
205+
agent = BoltAgent(
206+
client=client,
207+
channel_id="C111",
208+
thread_ts="1234567890.123456",
209+
team_id="T111",
210+
user_id="W222",
211+
)
212+
agent.set_suggested_prompts(prompts=["What can you do?", "Help me write code"])
213+
214+
client.assistant_threads_setSuggestedPrompts.assert_called_once_with(
215+
channel_id="C111",
216+
thread_ts="1234567890.123456",
217+
prompts=[
218+
{"title": "What can you do?", "message": "What can you do?"},
219+
{"title": "Help me write code", "message": "Help me write code"},
220+
],
221+
title=None,
222+
)
223+
224+
def test_set_suggested_prompts_with_dict_prompts(self):
225+
"""BoltAgent.set_suggested_prompts() accepts dict prompts with title and message."""
226+
client = MagicMock(spec=WebClient)
227+
client.assistant_threads_setSuggestedPrompts.return_value = MagicMock()
228+
229+
agent = BoltAgent(
230+
client=client,
231+
channel_id="C111",
232+
thread_ts="1234567890.123456",
233+
team_id="T111",
234+
user_id="W222",
235+
)
236+
agent.set_suggested_prompts(
237+
prompts=[
238+
{"title": "Short title", "message": "A much longer message for this prompt"},
239+
],
240+
title="Suggestions",
241+
)
242+
243+
client.assistant_threads_setSuggestedPrompts.assert_called_once_with(
244+
channel_id="C111",
245+
thread_ts="1234567890.123456",
246+
prompts=[
247+
{"title": "Short title", "message": "A much longer message for this prompt"},
248+
],
249+
title="Suggestions",
250+
)
251+
252+
def test_set_suggested_prompts_overrides_context_defaults(self):
253+
"""Explicit channel/thread_ts override context defaults."""
254+
client = MagicMock(spec=WebClient)
255+
client.assistant_threads_setSuggestedPrompts.return_value = MagicMock()
256+
257+
agent = BoltAgent(
258+
client=client,
259+
channel_id="C111",
260+
thread_ts="1234567890.123456",
261+
team_id="T111",
262+
user_id="W222",
263+
)
264+
agent.set_suggested_prompts(
265+
prompts=["Hello"],
266+
channel="C999",
267+
thread_ts="9999999999.999999",
268+
)
269+
270+
client.assistant_threads_setSuggestedPrompts.assert_called_once_with(
271+
channel_id="C999",
272+
thread_ts="9999999999.999999",
273+
prompts=[{"title": "Hello", "message": "Hello"}],
274+
title=None,
275+
)
276+
277+
def test_set_suggested_prompts_passes_extra_kwargs(self):
278+
"""Extra kwargs are forwarded to WebClient.assistant_threads_setSuggestedPrompts()."""
279+
client = MagicMock(spec=WebClient)
280+
client.assistant_threads_setSuggestedPrompts.return_value = MagicMock()
281+
282+
agent = BoltAgent(
283+
client=client,
284+
channel_id="C111",
285+
thread_ts="1234567890.123456",
286+
team_id="T111",
287+
user_id="W222",
288+
)
289+
agent.set_suggested_prompts(prompts=["Hello"], token="xoxb-override")
290+
291+
client.assistant_threads_setSuggestedPrompts.assert_called_once_with(
292+
channel_id="C111",
293+
thread_ts="1234567890.123456",
294+
prompts=[{"title": "Hello", "message": "Hello"}],
295+
title=None,
296+
token="xoxb-override",
297+
)
298+
299+
def test_set_suggested_prompts_requires_prompts(self):
300+
"""set_suggested_prompts() raises TypeError when prompts is not provided."""
301+
client = MagicMock(spec=WebClient)
302+
agent = BoltAgent(
303+
client=client,
304+
channel_id="C111",
305+
thread_ts="1234567890.123456",
306+
team_id="T111",
307+
user_id="W222",
308+
)
309+
with pytest.raises(TypeError):
310+
agent.set_suggested_prompts()
311+
200312
def test_import_from_slack_bolt(self):
201313
from slack_bolt import BoltAgent as ImportedBoltAgent
202314

tests/slack_bolt_async/agent/test_async_agent.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,123 @@ async def test_set_status_requires_status(self):
228228
with pytest.raises(TypeError):
229229
await agent.set_status()
230230

231+
@pytest.mark.asyncio
232+
async def test_set_suggested_prompts_uses_context_defaults(self):
233+
"""AsyncBoltAgent.set_suggested_prompts() passes context defaults to AsyncWebClient.assistant_threads_setSuggestedPrompts()."""
234+
client = MagicMock(spec=AsyncWebClient)
235+
client.assistant_threads_setSuggestedPrompts, call_tracker, _ = _make_async_api_mock()
236+
237+
agent = AsyncBoltAgent(
238+
client=client,
239+
channel_id="C111",
240+
thread_ts="1234567890.123456",
241+
team_id="T111",
242+
user_id="W222",
243+
)
244+
await agent.set_suggested_prompts(prompts=["What can you do?", "Help me write code"])
245+
246+
call_tracker.assert_called_once_with(
247+
channel_id="C111",
248+
thread_ts="1234567890.123456",
249+
prompts=[
250+
{"title": "What can you do?", "message": "What can you do?"},
251+
{"title": "Help me write code", "message": "Help me write code"},
252+
],
253+
title=None,
254+
)
255+
256+
@pytest.mark.asyncio
257+
async def test_set_suggested_prompts_with_dict_prompts(self):
258+
"""AsyncBoltAgent.set_suggested_prompts() accepts dict prompts with title and message."""
259+
client = MagicMock(spec=AsyncWebClient)
260+
client.assistant_threads_setSuggestedPrompts, call_tracker, _ = _make_async_api_mock()
261+
262+
agent = AsyncBoltAgent(
263+
client=client,
264+
channel_id="C111",
265+
thread_ts="1234567890.123456",
266+
team_id="T111",
267+
user_id="W222",
268+
)
269+
await agent.set_suggested_prompts(
270+
prompts=[
271+
{"title": "Short title", "message": "A much longer message for this prompt"},
272+
],
273+
title="Suggestions",
274+
)
275+
276+
call_tracker.assert_called_once_with(
277+
channel_id="C111",
278+
thread_ts="1234567890.123456",
279+
prompts=[
280+
{"title": "Short title", "message": "A much longer message for this prompt"},
281+
],
282+
title="Suggestions",
283+
)
284+
285+
@pytest.mark.asyncio
286+
async def test_set_suggested_prompts_overrides_context_defaults(self):
287+
"""Explicit channel/thread_ts override context defaults."""
288+
client = MagicMock(spec=AsyncWebClient)
289+
client.assistant_threads_setSuggestedPrompts, call_tracker, _ = _make_async_api_mock()
290+
291+
agent = AsyncBoltAgent(
292+
client=client,
293+
channel_id="C111",
294+
thread_ts="1234567890.123456",
295+
team_id="T111",
296+
user_id="W222",
297+
)
298+
await agent.set_suggested_prompts(
299+
prompts=["Hello"],
300+
channel="C999",
301+
thread_ts="9999999999.999999",
302+
)
303+
304+
call_tracker.assert_called_once_with(
305+
channel_id="C999",
306+
thread_ts="9999999999.999999",
307+
prompts=[{"title": "Hello", "message": "Hello"}],
308+
title=None,
309+
)
310+
311+
@pytest.mark.asyncio
312+
async def test_set_suggested_prompts_passes_extra_kwargs(self):
313+
"""Extra kwargs are forwarded to AsyncWebClient.assistant_threads_setSuggestedPrompts()."""
314+
client = MagicMock(spec=AsyncWebClient)
315+
client.assistant_threads_setSuggestedPrompts, call_tracker, _ = _make_async_api_mock()
316+
317+
agent = AsyncBoltAgent(
318+
client=client,
319+
channel_id="C111",
320+
thread_ts="1234567890.123456",
321+
team_id="T111",
322+
user_id="W222",
323+
)
324+
await agent.set_suggested_prompts(prompts=["Hello"], token="xoxb-override")
325+
326+
call_tracker.assert_called_once_with(
327+
channel_id="C111",
328+
thread_ts="1234567890.123456",
329+
prompts=[{"title": "Hello", "message": "Hello"}],
330+
title=None,
331+
token="xoxb-override",
332+
)
333+
334+
@pytest.mark.asyncio
335+
async def test_set_suggested_prompts_requires_prompts(self):
336+
"""set_suggested_prompts() raises TypeError when prompts is not provided."""
337+
client = MagicMock(spec=AsyncWebClient)
338+
agent = AsyncBoltAgent(
339+
client=client,
340+
channel_id="C111",
341+
thread_ts="1234567890.123456",
342+
team_id="T111",
343+
user_id="W222",
344+
)
345+
with pytest.raises(TypeError):
346+
await agent.set_suggested_prompts()
347+
231348
@pytest.mark.asyncio
232349
async def test_import_from_agent_module(self):
233350
from slack_bolt.agent.async_agent import AsyncBoltAgent as ImportedAsyncBoltAgent

0 commit comments

Comments
 (0)