Skip to content

Commit af7bea2

Browse files
committed
test: confirm async implementation matches sync
1 parent ee25536 commit af7bea2

File tree

2 files changed

+151
-0
lines changed

2 files changed

+151
-0
lines changed
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import json
2+
import re
3+
import unittest
4+
from urllib.parse import parse_qs, urlparse
5+
6+
from slack_sdk.web.async_client import AsyncWebClient
7+
from tests.mock_web_api_server import cleanup_mock_web_api_server, setup_mock_web_api_server
8+
from tests.slack_sdk.web.mock_web_api_handler import MockHandler
9+
from tests.slack_sdk_async.helpers import async_test
10+
11+
12+
class ChatStreamMockHandler(MockHandler):
13+
"""Extended mock handler that captures request bodies for chat stream methods"""
14+
15+
def _handle(self):
16+
try:
17+
# put_nowait is common between Queue & asyncio.Queue, it does not need to be awaited
18+
self.server.queue.put_nowait(self.path)
19+
20+
# Standard auth and validation from parent
21+
if self.is_valid_token() and self.is_valid_user_agent():
22+
token = self.headers["authorization"].split(" ")[1]
23+
parsed_path = urlparse(self.path)
24+
len_header = self.headers.get("Content-Length") or 0
25+
content_len = int(len_header)
26+
post_body = self.rfile.read(content_len)
27+
request_body = None
28+
if post_body:
29+
try:
30+
post_body = post_body.decode("utf-8")
31+
if post_body.startswith("{"):
32+
request_body = json.loads(post_body)
33+
else:
34+
request_body = {k: v[0] for k, v in parse_qs(post_body).items()}
35+
except UnicodeDecodeError:
36+
pass
37+
else:
38+
if parsed_path and parsed_path.query:
39+
request_body = {k: v[0] for k, v in parse_qs(parsed_path.query).items()}
40+
41+
# Store request body for chat stream endpoints
42+
if self.path in ["/chat.startStream", "/chat.appendStream", "/chat.stopStream"] and request_body:
43+
if not hasattr(self.server, "chat_stream_requests"):
44+
self.server.chat_stream_requests = {}
45+
self.server.chat_stream_requests[self.path] = {
46+
"token": token,
47+
**request_body,
48+
}
49+
50+
# Load response file
51+
pattern = str(token).split("xoxb-", 1)[1]
52+
with open(f"tests/slack_sdk_fixture/web_response_{pattern}.json") as file:
53+
body = json.load(file)
54+
55+
else:
56+
body = self.invalid_auth
57+
58+
if not body:
59+
body = self.not_found
60+
61+
self.send_response(200)
62+
self.set_common_headers()
63+
self.wfile.write(json.dumps(body).encode("utf-8"))
64+
self.wfile.close()
65+
66+
except Exception as e:
67+
self.logger.error(str(e), exc_info=True)
68+
raise
69+
70+
71+
class TestAsyncChatStream(unittest.TestCase):
72+
def setUp(self):
73+
setup_mock_web_api_server(self, ChatStreamMockHandler)
74+
self.client = AsyncWebClient(
75+
token="xoxb-chat_stream_test",
76+
base_url="http://localhost:8888",
77+
)
78+
79+
def tearDown(self):
80+
cleanup_mock_web_api_server(self)
81+
82+
pattern_for_language = re.compile("python/(\\S+)", re.IGNORECASE)
83+
pattern_for_package_identifier = re.compile("slackclient/(\\S+)")
84+
85+
@async_test
86+
async def test_streams_a_short_message(self):
87+
streamer = await self.client.chat_stream(
88+
channel="C0123456789",
89+
thread_ts="123.000",
90+
recipient_team_id="T0123456789",
91+
recipient_user_id="U0123456789",
92+
)
93+
await streamer.append(markdown_text="nice!")
94+
await streamer.stop()
95+
96+
self.assertEqual(self.received_requests.get("/chat.startStream", 0), 1)
97+
self.assertEqual(self.received_requests.get("/chat.appendStream", 0), 0)
98+
self.assertEqual(self.received_requests.get("/chat.stopStream", 0), 1)
99+
100+
if hasattr(self.thread.server, "chat_stream_requests"):
101+
start_request = self.thread.server.chat_stream_requests.get("/chat.startStream", {})
102+
self.assertEqual(start_request.get("channel"), "C0123456789")
103+
self.assertEqual(start_request.get("thread_ts"), "123.000")
104+
self.assertEqual(start_request.get("recipient_team_id"), "T0123456789")
105+
self.assertEqual(start_request.get("recipient_user_id"), "U0123456789")
106+
107+
stop_request = self.thread.server.chat_stream_requests.get("/chat.stopStream", {})
108+
self.assertEqual(stop_request.get("channel"), "C0123456789")
109+
self.assertEqual(stop_request.get("ts"), "123.123")
110+
self.assertEqual(stop_request.get("markdown_text"), "nice!")
111+
112+
@async_test
113+
async def test_streams_a_long_message(self):
114+
streamer = await self.client.chat_stream(
115+
buffer_size=5,
116+
channel="C0123456789",
117+
recipient_team_id="T0123456789",
118+
recipient_user_id="U0123456789",
119+
thread_ts="123.000",
120+
)
121+
await streamer.append(markdown_text="**this messag")
122+
await streamer.append(markdown_text="e is")
123+
await streamer.append(markdown_text=" bold!")
124+
await streamer.append(markdown_text="*")
125+
await streamer.stop(markdown_text="*", token="xoxb-chat_stream_test_token")
126+
127+
self.assertEqual(self.received_requests.get("/chat.startStream", 0), 1)
128+
self.assertEqual(self.received_requests.get("/chat.appendStream", 0), 1)
129+
self.assertEqual(self.received_requests.get("/chat.stopStream", 0), 1)
130+
131+
if hasattr(self.thread.server, "chat_stream_requests"):
132+
start_request = self.thread.server.chat_stream_requests.get("/chat.startStream", {})
133+
self.assertEqual(start_request.get("channel"), "C0123456789")
134+
self.assertEqual(start_request.get("thread_ts"), "123.000")
135+
self.assertEqual(start_request.get("markdown_text"), "**this messag")
136+
self.assertEqual(start_request.get("recipient_team_id"), "T0123456789")
137+
self.assertEqual(start_request.get("recipient_user_id"), "U0123456789")
138+
139+
append_request = self.thread.server.chat_stream_requests.get("/chat.appendStream", {})
140+
self.assertEqual(append_request.get("channel"), "C0123456789")
141+
self.assertEqual(append_request.get("markdown_text"), "e is bold!")
142+
self.assertEqual(append_request.get("ts"), "123.123")
143+
144+
stop_request = self.thread.server.chat_stream_requests.get("/chat.stopStream", {})
145+
self.assertEqual(stop_request.get("channel"), "C0123456789")
146+
self.assertEqual(stop_request.get("markdown_text"), "**")
147+
self.assertEqual(stop_request.get("token"), "xoxb-chat_stream_test_token")
148+
self.assertEqual(stop_request.get("ts"), "123.123")
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"ok": true
3+
}

0 commit comments

Comments
 (0)