11import asyncio
22import json
3- try :
4- from unittest .mock import AsyncMock , MagicMock
5- except ImportError :
6- from mock import AsyncMock , MagicMock # type: ignore
3+ from unittest .mock import MagicMock
74
85import pytest
96from slack_sdk .web .async_client import AsyncWebClient
2017from tests .utils import remove_os_env_temporarily , restore_os_env
2118
2219
20+ def _make_async_chat_stream_mock ():
21+ mock_stream = MagicMock (spec = AsyncChatStream )
22+ call_tracker = MagicMock ()
23+
24+ async def fake_chat_stream (** kwargs ):
25+ call_tracker (** kwargs )
26+ return mock_stream
27+
28+ return fake_chat_stream , call_tracker , mock_stream
29+
30+
2331class TestAsyncEventsAgent :
2432 valid_token = "xoxb-valid"
2533 mock_api_server_base_url = "http://localhost:8888"
@@ -68,7 +76,7 @@ async def handle_mention(agent: AsyncBoltAgent, context: AsyncBoltContext):
6876 async def test_agent_chat_stream_uses_context_defaults (self ):
6977 """AsyncBoltAgent.chat_stream() passes context defaults to AsyncWebClient.chat_stream()."""
7078 client = MagicMock (spec = AsyncWebClient )
71- client .chat_stream = AsyncMock ( return_value = MagicMock ( spec = AsyncChatStream ) )
79+ client .chat_stream , call_tracker , _ = _make_async_chat_stream_mock ( )
7280
7381 agent = AsyncBoltAgent (
7482 client = client ,
@@ -79,7 +87,7 @@ async def test_agent_chat_stream_uses_context_defaults(self):
7987 )
8088 stream = await agent .chat_stream ()
8189
82- client . chat_stream .assert_called_once_with (
90+ call_tracker .assert_called_once_with (
8391 channel = "C111" ,
8492 thread_ts = "1234567890.123456" ,
8593 recipient_team_id = "T111" ,
@@ -91,7 +99,7 @@ async def test_agent_chat_stream_uses_context_defaults(self):
9199 async def test_agent_chat_stream_overrides_context_defaults (self ):
92100 """Explicit kwargs to chat_stream() override context defaults."""
93101 client = MagicMock (spec = AsyncWebClient )
94- client .chat_stream = AsyncMock ( return_value = MagicMock ( spec = AsyncChatStream ) )
102+ client .chat_stream , call_tracker , _ = _make_async_chat_stream_mock ( )
95103
96104 agent = AsyncBoltAgent (
97105 client = client ,
@@ -107,7 +115,7 @@ async def test_agent_chat_stream_overrides_context_defaults(self):
107115 recipient_user_id = "U999" ,
108116 )
109117
110- client . chat_stream .assert_called_once_with (
118+ call_tracker .assert_called_once_with (
111119 channel = "C999" ,
112120 thread_ts = "9999999999.999999" ,
113121 recipient_team_id = "T999" ,
@@ -119,7 +127,7 @@ async def test_agent_chat_stream_overrides_context_defaults(self):
119127 async def test_agent_chat_stream_passes_extra_kwargs (self ):
120128 """Extra kwargs are forwarded to AsyncWebClient.chat_stream()."""
121129 client = MagicMock (spec = AsyncWebClient )
122- client .chat_stream = AsyncMock ( return_value = MagicMock ( spec = AsyncChatStream ) )
130+ client .chat_stream , call_tracker , _ = _make_async_chat_stream_mock ( )
123131
124132 agent = AsyncBoltAgent (
125133 client = client ,
@@ -130,7 +138,7 @@ async def test_agent_chat_stream_passes_extra_kwargs(self):
130138 )
131139 await agent .chat_stream (buffer_size = 512 )
132140
133- client . chat_stream .assert_called_once_with (
141+ call_tracker .assert_called_once_with (
134142 channel = "C111" ,
135143 thread_ts = "1234567890.123456" ,
136144 recipient_team_id = "T111" ,
0 commit comments