Skip to content

Commit cba64c1

Browse files
committed
fix: override the token if provided in append or stop methods
1 parent 73e46fd commit cba64c1

File tree

3 files changed

+32
-17
lines changed

3 files changed

+32
-17
lines changed

slack_sdk/web/async_chat_stream.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def __init__(
5858
self._buffer = ""
5959
self._state = "starting"
6060
self._stream_ts: Optional[str] = None
61+
self._token: Optional[str] = kwargs.get("token")
6162
self._buffer_size = buffer_size
6263

6364
async def append(
@@ -80,6 +81,8 @@ async def append(
8081
"""
8182
if self._state == "completed":
8283
raise e.SlackRequestError(f"Cannot append to stream: stream state is {self._state}")
84+
if kwargs.get("token"):
85+
self._token = kwargs.pop("token")
8386
self._buffer += markdown_text
8487
if len(self._buffer) >= self._buffer_size:
8588
return await self._flush_buffer(**kwargs)
@@ -118,26 +121,25 @@ async def stop(
118121
"""
119122
if self._state == "completed":
120123
raise e.SlackRequestError(f"Cannot stop stream: stream state is {self._state}")
124+
if kwargs.get("token"):
125+
self._token = kwargs.pop("token")
121126
if markdown_text:
122127
self._buffer += markdown_text
123128
if not self._stream_ts:
124129
response = await self._client.chat_startStream(
125130
**self._stream_args,
126-
**kwargs,
131+
token=self._token,
127132
)
128133
if not response.get("ts"):
129134
raise e.SlackRequestError("Failed to stop stream: stream not started")
130135
self._stream_ts = str(response["ts"])
131136
self._state = "in_progress"
132-
133-
print(f"_stream_args: {self._stream_args}\n") # todo
134-
print(f"_buffer: {self._buffer}\n")
135-
136137
response = await self._client.chat_stopStream(
138+
token=self._token,
137139
channel=self._stream_args["channel"],
138140
ts=self._stream_ts,
139141
blocks=blocks,
140-
markdown_text=self._buffer + (markdown_text if markdown_text is not None else ""),
142+
markdown_text=self._buffer,
141143
metadata=metadata,
142144
**kwargs,
143145
)
@@ -149,17 +151,19 @@ async def _flush_buffer(self, **kwargs) -> AsyncSlackResponse:
149151
if not self._stream_ts:
150152
response = await self._client.chat_startStream(
151153
**self._stream_args,
152-
markdown_text=self._buffer,
154+
token=self._token,
153155
**kwargs,
156+
markdown_text=self._buffer,
154157
)
155158
self._stream_ts = response.get("ts")
156159
self._state = "in_progress"
157160
else:
158161
response = await self._client.chat_appendStream(
162+
token=self._token,
159163
channel=self._stream_args["channel"],
160164
ts=self._stream_ts,
161-
markdown_text=self._buffer,
162165
**kwargs,
166+
markdown_text=self._buffer,
163167
)
164168

165169
self._buffer = ""

slack_sdk/web/chat_stream.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def __init__(
5858
self._buffer = ""
5959
self._state = "starting"
6060
self._stream_ts: Optional[str] = None
61+
self._token: Optional[str] = kwargs.get("token")
6162
self._buffer_size = buffer_size
6263

6364
def append(
@@ -80,6 +81,8 @@ def append(
8081
"""
8182
if self._state == "completed":
8283
raise e.SlackRequestError(f"Cannot append to stream: stream state is {self._state}")
84+
if kwargs.get("token"):
85+
self._token = kwargs.pop("token")
8386
self._buffer += markdown_text
8487
if len(self._buffer) >= self._buffer_size:
8588
return self._flush_buffer(**kwargs)
@@ -118,18 +121,21 @@ def stop(
118121
"""
119122
if self._state == "completed":
120123
raise e.SlackRequestError(f"Cannot stop stream: stream state is {self._state}")
124+
if kwargs.get("token"):
125+
self._token = kwargs.pop("token")
121126
if markdown_text:
122127
self._buffer += markdown_text
123128
if not self._stream_ts:
124129
response = self._client.chat_startStream(
125130
**self._stream_args,
126-
**kwargs,
131+
token=self._token,
127132
)
128133
if not response.get("ts"):
129134
raise e.SlackRequestError("Failed to stop stream: stream not started")
130135
self._stream_ts = str(response["ts"])
131136
self._state = "in_progress"
132137
response = self._client.chat_stopStream(
138+
token=self._token,
133139
channel=self._stream_args["channel"],
134140
ts=self._stream_ts,
135141
blocks=blocks,
@@ -145,17 +151,19 @@ def _flush_buffer(self, **kwargs) -> SlackResponse:
145151
if not self._stream_ts:
146152
response = self._client.chat_startStream(
147153
**self._stream_args,
148-
markdown_text=self._buffer,
154+
token=self._token,
149155
**kwargs,
156+
markdown_text=self._buffer,
150157
)
151158
self._stream_ts = response.get("ts")
152159
self._state = "in_progress"
153160
else:
154161
response = self._client.chat_appendStream(
162+
token=self._token,
155163
channel=self._stream_args["channel"],
156164
ts=self._stream_ts,
157-
markdown_text=self._buffer,
158165
**kwargs,
166+
markdown_text=self._buffer,
159167
)
160168

161169
self._buffer = ""

tests/slack_sdk/web/test_chat_stream.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def _handle(self):
1818

1919
# Standard auth and validation from parent
2020
if self.is_valid_token() and self.is_valid_user_agent():
21+
token = self.headers["authorization"].split(" ")[1]
2122
parsed_path = urlparse(self.path)
2223
len_header = self.headers.get("Content-Length") or 0
2324
content_len = int(len_header)
@@ -40,13 +41,13 @@ def _handle(self):
4041
if self.path in ["/chat.startStream", "/chat.appendStream", "/chat.stopStream"] and request_body:
4142
if not hasattr(self.server, "chat_stream_requests"):
4243
self.server.chat_stream_requests = {}
43-
self.server.chat_stream_requests[self.path] = request_body
44-
45-
# Get token pattern for response file
46-
header = self.headers["authorization"]
47-
pattern = str(header).split("xoxb-", 1)[1]
44+
self.server.chat_stream_requests[self.path] = {
45+
"token": token,
46+
**request_body,
47+
}
4848

4949
# Load response file
50+
pattern = str(token).split("xoxb-", 1)[1]
5051
with open(f"tests/slack_sdk_fixture/web_response_{pattern}.json") as file:
5152
body = json.load(file)
5253

@@ -93,6 +94,7 @@ def test_streams_a_short_message(self):
9394
self.assertEqual(self.received_requests.get("/chat.startStream", 0), 1)
9495
self.assertEqual(self.received_requests.get("/chat.appendStream", 0), 0)
9596
self.assertEqual(self.received_requests.get("/chat.stopStream", 0), 1)
97+
9698
if hasattr(self.thread.server, "chat_stream_requests"):
9799
start_request = self.thread.server.chat_stream_requests.get("/chat.startStream", {})
98100
self.assertEqual(start_request.get("channel"), "C0123456789")
@@ -117,7 +119,7 @@ def test_streams_a_long_message(self):
117119
streamer.append(markdown_text="e is")
118120
streamer.append(markdown_text=" bold!")
119121
streamer.append(markdown_text="*")
120-
streamer.stop(markdown_text="*")
122+
streamer.stop(markdown_text="*", token="xoxb-chat_stream_test_token")
121123

122124
self.assertEqual(self.received_requests.get("/chat.startStream", 0), 1)
123125
self.assertEqual(self.received_requests.get("/chat.appendStream", 0), 1)
@@ -139,4 +141,5 @@ def test_streams_a_long_message(self):
139141
stop_request = self.thread.server.chat_stream_requests.get("/chat.stopStream", {})
140142
self.assertEqual(stop_request.get("channel"), "C0123456789")
141143
self.assertEqual(stop_request.get("markdown_text"), "**")
144+
self.assertEqual(stop_request.get("token"), "xoxb-chat_stream_test_token")
142145
self.assertEqual(stop_request.get("ts"), "123.123")

0 commit comments

Comments
 (0)