Skip to content

Commit 69f8f93

Browse files
fix(assistant): improve middleware dispatch and extract AttachingAssistantKwargs
1 parent 72a90d2 commit 69f8f93

File tree

12 files changed

+493
-65
lines changed

12 files changed

+493
-65
lines changed

slack_bolt/app/app.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
from slack_bolt.context.assistant.thread_context_store.store import AssistantThreadContextStore
2424

25-
from slack_bolt.context.assistant.assistant_utilities import AssistantUtilities
2625
from slack_bolt.error import BoltError, BoltUnhandledRequestError
2726
from slack_bolt.lazy_listener.thread_runner import ThreadLazyListenerRunner
2827
from slack_bolt.listener.builtins import TokenRevocationListeners
@@ -83,10 +82,6 @@
8382
from slack_bolt.oauth.internals import select_consistent_installation_store
8483
from slack_bolt.oauth.oauth_settings import OAuthSettings
8584
from slack_bolt.request import BoltRequest
86-
from slack_bolt.request.payload_utils import (
87-
is_assistant_event,
88-
to_event,
89-
)
9085
from slack_bolt.response import BoltResponse
9186
from slack_bolt.util.utils import (
9287
create_web_client,
@@ -1398,20 +1393,6 @@ def _init_context(self, req: BoltRequest):
13981393
# It is intended for apps that start lazy listeners from their custom global middleware.
13991394
req.context["listener_runner"] = self.listener_runner
14001395

1401-
# For AI Agents & Assistants
1402-
if is_assistant_event(req.body):
1403-
assistant = AssistantUtilities(
1404-
payload=to_event(req.body), # type:ignore[arg-type]
1405-
context=req.context,
1406-
thread_context_store=self._assistant_thread_context_store,
1407-
)
1408-
req.context["say"] = assistant.say
1409-
req.context["set_status"] = assistant.set_status
1410-
req.context["set_title"] = assistant.set_title
1411-
req.context["set_suggested_prompts"] = assistant.set_suggested_prompts
1412-
req.context["get_thread_context"] = assistant.get_thread_context
1413-
req.context["save_thread_context"] = assistant.save_thread_context
1414-
14151396
@staticmethod
14161397
def _to_listener_functions(
14171398
kwargs: dict,

slack_bolt/app/async_app.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from aiohttp import web
99

1010
from slack_bolt.app.async_server import AsyncSlackAppServer
11-
from slack_bolt.context.assistant.async_assistant_utilities import AsyncAssistantUtilities
1211
from slack_bolt.context.assistant.thread_context_store.async_store import (
1312
AsyncAssistantThreadContextStore,
1413
)
@@ -30,7 +29,6 @@
3029
AsyncMessageListenerMatches,
3130
)
3231
from slack_bolt.oauth.async_internals import select_consistent_installation_store
33-
from slack_bolt.request.payload_utils import is_assistant_event, to_event
3432
from slack_bolt.util.utils import get_name_for_callable, is_callable_coroutine
3533
from slack_bolt.workflows.step.async_step import (
3634
AsyncWorkflowStep,
@@ -1431,20 +1429,6 @@ def _init_context(self, req: AsyncBoltRequest):
14311429
# It is intended for apps that start lazy listeners from their custom global middleware.
14321430
req.context["listener_runner"] = self.listener_runner
14331431

1434-
# For AI Agents & Assistants
1435-
if is_assistant_event(req.body):
1436-
assistant = AsyncAssistantUtilities(
1437-
payload=to_event(req.body), # type:ignore[arg-type]
1438-
context=req.context,
1439-
thread_context_store=self._assistant_thread_context_store,
1440-
)
1441-
req.context["say"] = assistant.say
1442-
req.context["set_status"] = assistant.set_status
1443-
req.context["set_title"] = assistant.set_title
1444-
req.context["set_suggested_prompts"] = assistant.set_suggested_prompts
1445-
req.context["get_thread_context"] = assistant.get_thread_context
1446-
req.context["save_thread_context"] = assistant.save_thread_context
1447-
14481432
@staticmethod
14491433
def _to_listener_functions(
14501434
kwargs: dict,

slack_bolt/context/async_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ async def handle_button_clicks(ack, say):
110110
Callable `say()` function
111111
"""
112112
if "say" not in self:
113-
self["say"] = AsyncSay(client=self.client, channel=self.channel_id, thread_ts=self.thread_ts)
113+
self["say"] = AsyncSay(client=self.client, channel=self.channel_id)
114114
return self["say"]
115115

116116
@property

slack_bolt/context/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def handle_button_clicks(ack, say):
111111
Callable `say()` function
112112
"""
113113
if "say" not in self:
114-
self["say"] = Say(client=self.client, channel=self.channel_id, thread_ts=self.thread_ts)
114+
self["say"] = Say(client=self.client, channel=self.channel_id)
115115
return self["say"]
116116

117117
@property

slack_bolt/middleware/assistant/assistant.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from slack_bolt.context.assistant.thread_context_store.store import AssistantThreadContextStore
88
from slack_bolt.listener_matcher.builtins import build_listener_matcher
99

10+
from slack_bolt.middleware.assistant.attaching_assistant_kwargs import AttachingAssistantKwargs
1011
from slack_bolt.request.request import BoltRequest
1112
from slack_bolt.response.response import BoltResponse
1213
from slack_bolt.listener_matcher import CustomListenerMatcher
@@ -236,6 +237,15 @@ def process( # type:ignore[return]
236237
if listeners is not None:
237238
for listener in listeners:
238239
if listener.matches(req=req, resp=resp):
240+
middleware_resp, next_was_not_called = listener.run_middleware(req=req, resp=resp)
241+
if next_was_not_called:
242+
if middleware_resp is not None:
243+
return middleware_resp
244+
# The listener middleware didn't call next().
245+
# This means the listener is not for this incoming request.
246+
continue
247+
if middleware_resp is not None:
248+
resp = middleware_resp
239249
return listener_runner.run(
240250
request=req,
241251
response=resp,
@@ -262,6 +272,7 @@ def build_listener(
262272
return listener_or_functions
263273
elif isinstance(listener_or_functions, list):
264274
middleware = middleware if middleware else []
275+
middleware.insert(0, AttachingAssistantKwargs(self.thread_context_store))
265276
functions = listener_or_functions
266277
ack_function = functions.pop(0)
267278

slack_bolt/middleware/assistant/async_assistant.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from slack_bolt.listener.asyncio_runner import AsyncioListenerRunner
1010
from slack_bolt.listener_matcher.builtins import build_listener_matcher
11+
from slack_bolt.middleware.assistant.async_attaching_assistant_kwargs import AsyncAttachingAssistantKwargs
1112
from slack_bolt.request.async_request import AsyncBoltRequest
1213
from slack_bolt.response import BoltResponse
1314
from slack_bolt.error import BoltError
@@ -265,6 +266,15 @@ async def async_process( # type:ignore[return]
265266
if listeners is not None:
266267
for listener in listeners:
267268
if listener is not None and await listener.async_matches(req=req, resp=resp):
269+
middleware_resp, next_was_not_called = await listener.run_async_middleware(req=req, resp=resp)
270+
if next_was_not_called:
271+
if middleware_resp is not None:
272+
return middleware_resp
273+
# The listener middleware didn't call next().
274+
# This means the listener is not for this incoming request.
275+
continue
276+
if middleware_resp is not None:
277+
resp = middleware_resp
268278
return await listener_runner.run(
269279
request=req,
270280
response=resp,
@@ -291,6 +301,7 @@ def build_listener(
291301
return listener_or_functions
292302
elif isinstance(listener_or_functions, list):
293303
middleware = middleware if middleware else []
304+
middleware.insert(0, AsyncAttachingAssistantKwargs(self.thread_context_store))
294305
functions = listener_or_functions
295306
ack_function = functions.pop(0)
296307

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from typing import Optional, Callable, Awaitable
2+
3+
from slack_bolt.context.assistant.async_assistant_utilities import AsyncAssistantUtilities
4+
from slack_bolt.context.assistant.thread_context_store.async_store import AsyncAssistantThreadContextStore
5+
from slack_bolt.middleware.async_middleware import AsyncMiddleware
6+
from slack_bolt.request.async_request import AsyncBoltRequest
7+
from slack_bolt.request.payload_utils import to_event
8+
from slack_bolt.response import BoltResponse
9+
10+
11+
class AsyncAttachingAssistantKwargs(AsyncMiddleware):
12+
13+
thread_context_store: Optional[AsyncAssistantThreadContextStore]
14+
15+
def __init__(self, thread_context_store: Optional[AsyncAssistantThreadContextStore]):
16+
self.thread_context_store = thread_context_store
17+
18+
async def async_process(
19+
self,
20+
*,
21+
req: AsyncBoltRequest,
22+
resp: BoltResponse,
23+
next: Callable[[], Awaitable[BoltResponse]],
24+
) -> Optional[BoltResponse]:
25+
event = to_event(req.body)
26+
if event is not None:
27+
assistant = AsyncAssistantUtilities(
28+
payload=event,
29+
context=req.context,
30+
thread_context_store=self.thread_context_store,
31+
)
32+
req.context["say"] = assistant.say
33+
req.context["set_status"] = assistant.set_status
34+
req.context["set_title"] = assistant.set_title
35+
req.context["set_suggested_prompts"] = assistant.set_suggested_prompts
36+
req.context["get_thread_context"] = assistant.get_thread_context
37+
req.context["save_thread_context"] = assistant.save_thread_context
38+
return await next()
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from typing import Optional, Callable
2+
3+
from slack_bolt.context.assistant.assistant_utilities import AssistantUtilities
4+
from slack_bolt.context.assistant.thread_context_store.store import AssistantThreadContextStore
5+
from slack_bolt.middleware import Middleware
6+
from slack_bolt.request.payload_utils import to_event
7+
from slack_bolt.request.request import BoltRequest
8+
from slack_bolt.response.response import BoltResponse
9+
10+
11+
class AttachingAssistantKwargs(Middleware):
12+
13+
thread_context_store: Optional[AssistantThreadContextStore]
14+
15+
def __init__(self, thread_context_store: Optional[AssistantThreadContextStore]):
16+
self.thread_context_store = thread_context_store
17+
18+
def process(self, *, req: BoltRequest, resp: BoltResponse, next: Callable[[], BoltResponse]) -> Optional[BoltResponse]:
19+
event = to_event(req.body)
20+
if event is not None:
21+
assistant = AssistantUtilities(
22+
payload=event,
23+
context=req.context,
24+
thread_context_store=self.thread_context_store,
25+
)
26+
req.context["say"] = assistant.say
27+
req.context["set_status"] = assistant.set_status
28+
req.context["set_title"] = assistant.set_title
29+
req.context["set_suggested_prompts"] = assistant.set_suggested_prompts
30+
req.context["get_thread_context"] = assistant.get_thread_context
31+
req.context["save_thread_context"] = assistant.save_thread_context
32+
return next()

slack_bolt/request/internals.py

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from urllib.parse import parse_qsl, parse_qs
44

55
from slack_bolt.context import BoltContext
6-
from slack_bolt.request.payload_utils import is_assistant_event
76

87

98
def parse_query(query: Optional[Union[str, Dict[str, str], Dict[str, Sequence[str]]]]) -> Dict[str, Sequence[str]]:
@@ -215,33 +214,17 @@ def extract_channel_id(payload: Dict[str, Any]) -> Optional[str]:
215214

216215

217216
def extract_thread_ts(payload: Dict[str, Any]) -> Optional[str]:
218-
# This utility initially supports only the use cases for AI assistants, but it may be fine to add more patterns.
219-
# That said, note that thread_ts is always required for assistant threads, but it's not for channels.
220-
# Thus, blindly setting this thread_ts to say utility can break existing apps' behaviors.
221-
#
222-
# The BoltAgent class handles non-assistant thread_ts separately by reading from the event directly,
223-
# allowing it to work correctly without affecting say() behavior.
224-
if is_assistant_event(payload):
225-
event = payload["event"]
226-
if (
227-
event.get("assistant_thread") is not None
228-
and event["assistant_thread"].get("channel_id") is not None
229-
and event["assistant_thread"].get("thread_ts") is not None
230-
):
231-
# assistant_thread_started, assistant_thread_context_changed
232-
# "assistant_thread" property can exist for message event without channel_id and thread_ts
233-
# Thus, the above if check verifies these properties exist
234-
return event["assistant_thread"]["thread_ts"]
235-
elif event.get("channel") is not None:
236-
if event.get("thread_ts") is not None:
237-
# message in an assistant thread
238-
return event["thread_ts"]
239-
elif event.get("message", {}).get("thread_ts") is not None:
240-
# message_changed
241-
return event["message"]["thread_ts"]
242-
elif event.get("previous_message", {}).get("thread_ts") is not None:
243-
# message_deleted
244-
return event["previous_message"]["thread_ts"]
217+
thread_ts = payload.get("thread_ts")
218+
if thread_ts is not None:
219+
return thread_ts
220+
if payload.get("event") is not None:
221+
return extract_thread_ts(payload["event"])
222+
if isinstance(payload.get("assistant_thread"), dict):
223+
return extract_thread_ts(payload["assistant_thread"])
224+
if isinstance(payload.get("message"), dict):
225+
return extract_thread_ts(payload["message"])
226+
if isinstance(payload.get("previous_message"), dict):
227+
return extract_thread_ts(payload["previous_message"])
245228
return None
246229

247230

tests/scenario_tests/test_events_assistant.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
from time import sleep
2+
from typing import Callable
23

34
from slack_sdk.web import WebClient
45

56
from slack_bolt import App, BoltRequest, Assistant, Say, SetSuggestedPrompts, SetStatus, BoltContext
7+
from slack_bolt.middleware import Middleware
8+
from slack_bolt.request import BoltRequest as BoltRequestType
9+
from slack_bolt.response import BoltResponse
610
from tests.mock_web_api_server import (
711
setup_mock_web_api_server,
812
cleanup_mock_web_api_server,
@@ -44,6 +48,7 @@ def assert_target_called():
4448
def start_thread(say: Say, set_suggested_prompts: SetSuggestedPrompts, context: BoltContext):
4549
assert context.channel_id == "D111"
4650
assert context.thread_ts == "1726133698.626339"
51+
assert say.thread_ts == context.thread_ts
4752
say("Hi, how can I help you today?")
4853
set_suggested_prompts(prompts=[{"title": "What does SLACK stand for?", "message": "What does SLACK stand for?"}])
4954
set_suggested_prompts(
@@ -61,6 +66,7 @@ def handle_thread_context_changed(context: BoltContext):
6166
def handle_user_message(say: Say, set_status: SetStatus, context: BoltContext):
6267
assert context.channel_id == "D111"
6368
assert context.thread_ts == "1726133698.626339"
69+
assert say.thread_ts == context.thread_ts
6470
try:
6571
set_status("is typing...")
6672
say("Here you are!")
@@ -102,6 +108,86 @@ def handle_user_message(say: Say, set_status: SetStatus, context: BoltContext):
102108
response = app.dispatch(request)
103109
assert response.status == 404
104110

111+
def test_assistant_threads_with_custom_listener_middleware(self):
112+
app = App(client=self.web_client)
113+
assistant = Assistant()
114+
115+
state = {"called": False, "middleware_called": False}
116+
117+
def assert_target_called():
118+
count = 0
119+
while state["called"] is False and count < 20:
120+
sleep(0.1)
121+
count += 1
122+
assert state["called"] is True
123+
state["called"] = False
124+
125+
class TestMiddleware(Middleware):
126+
def process(self, *, req: BoltRequestType, resp: BoltResponse, next: Callable[[], BoltResponse]):
127+
state["middleware_called"] = True
128+
# Verify assistant utilities are available
129+
assert req.context.get("set_status") is not None
130+
assert req.context.get("set_title") is not None
131+
assert req.context.get("set_suggested_prompts") is not None
132+
assert req.context.get("get_thread_context") is not None
133+
assert req.context.get("save_thread_context") is not None
134+
return next()
135+
136+
@assistant.thread_started(middleware=[TestMiddleware()])
137+
def start_thread(say: Say, set_suggested_prompts: SetSuggestedPrompts, context: BoltContext):
138+
assert context.channel_id == "D111"
139+
assert context.thread_ts == "1726133698.626339"
140+
assert say.thread_ts == context.thread_ts
141+
say("Hi, how can I help you today?")
142+
set_suggested_prompts(prompts=[{"title": "What does SLACK stand for?", "message": "What does SLACK stand for?"}])
143+
state["called"] = True
144+
145+
@assistant.user_message(middleware=[TestMiddleware()])
146+
def handle_user_message(say: Say, set_status: SetStatus, context: BoltContext):
147+
assert context.channel_id == "D111"
148+
assert context.thread_ts == "1726133698.626339"
149+
assert say.thread_ts == context.thread_ts
150+
set_status("is typing...")
151+
say("Here you are!")
152+
state["called"] = True
153+
154+
app.assistant(assistant)
155+
156+
request = BoltRequest(body=thread_started_event_body, mode="socket_mode")
157+
response = app.dispatch(request)
158+
assert response.status == 200
159+
assert_target_called()
160+
assert state["middleware_called"] is True
161+
state["middleware_called"] = False
162+
163+
request = BoltRequest(body=user_message_event_body, mode="socket_mode")
164+
response = app.dispatch(request)
165+
assert response.status == 200
166+
assert_target_called()
167+
assert state["middleware_called"] is True
168+
169+
def test_assistant_threads_custom_middleware_can_short_circuit(self):
170+
app = App(client=self.web_client)
171+
assistant = Assistant()
172+
173+
state = {"handler_called": False}
174+
175+
class BlockingMiddleware(Middleware):
176+
def process(self, *, req: BoltRequestType, resp: BoltResponse, next: Callable[[], BoltResponse]):
177+
# Intentionally not calling next() to short-circuit
178+
return BoltResponse(status=200)
179+
180+
@assistant.thread_started(middleware=[BlockingMiddleware()])
181+
def start_thread(say: Say, context: BoltContext):
182+
state["handler_called"] = True
183+
184+
app.assistant(assistant)
185+
186+
request = BoltRequest(body=thread_started_event_body, mode="socket_mode")
187+
response = app.dispatch(request)
188+
assert response.status == 200
189+
assert state["handler_called"] is False
190+
105191

106192
def build_payload(event: dict) -> dict:
107193
return {

0 commit comments

Comments
 (0)