Skip to content

Commit 04a59ab

Browse files
authored
Fixing 415 - Converting the typing function to async (#446)
* WIP: simplifying how callbacks are executed and making ping/typing methods asynchronous. * Cleaning up RTM changes and updating tests * More clean up * More code clean up * Removing comment
2 parents bebc13e + ebd4ddc commit 04a59ab

6 files changed

Lines changed: 53 additions & 88 deletions

File tree

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,5 @@ coverage.xml
2424
.cache
2525
.pytest_cache/
2626
.python-version
27-
pip
27+
pip
28+
.mypy_cache/

slack/rtm/client.py

Lines changed: 31 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,10 @@
44
import os
55
import logging
66
import random
7-
import json
87
import collections
9-
import functools
108
import inspect
119
import signal
12-
import concurrent.futures
13-
from typing import Optional, Callable
10+
from typing import Optional, Callable, DefaultDict
1411
from ssl import SSLContext
1512

1613
# ThirdParty Imports
@@ -69,13 +66,12 @@ class RTMClient(object):
6966
Example:
7067
```python
7168
import os
72-
import slack
69+
from slack import RTMClient
7370
74-
@slack.RTMClient.run_on(event='message')
71+
@RTMClient.run_on(event="message")
7572
def say_hello(**payload):
7673
data = payload['data']
7774
web_client = payload['web_client']
78-
rtm_client = payload['rtm_client']
7975
if 'Hello' in data['text']:
8076
channel_id = data['channel']
8177
thread_ts = data['ts']
@@ -88,7 +84,7 @@ def say_hello(**payload):
8884
)
8985
9086
slack_token = os.environ["SLACK_API_TOKEN"]
91-
rtm_client = slack.RTMClient(token=slack_token)
87+
rtm_client = RTMClient(token=slack_token)
9288
rtm_client.start()
9389
```
9490
@@ -102,7 +98,7 @@ def say_hello(**payload):
10298
removed at anytime.
10399
"""
104100

105-
_callbacks = collections.defaultdict(list)
101+
_callbacks: DefaultDict = collections.defaultdict(list)
106102

107103
def __init__(
108104
self,
@@ -141,11 +137,7 @@ def run_on(*, event: str):
141137
"""A decorator to store and link a callback to an event."""
142138

143139
def decorator(callback):
144-
@functools.wraps(callback)
145-
def decorator_wrapper():
146-
RTMClient.on(event=event, callback=callback)
147-
148-
return decorator_wrapper()
140+
RTMClient.on(event=event, callback=callback)
149141

150142
return decorator
151143

@@ -196,7 +188,7 @@ def start(self) -> asyncio.Future:
196188

197189
future = asyncio.ensure_future(self._connect_and_read(), loop=self._event_loop)
198190

199-
if self.run_async or self._event_loop.is_running():
191+
if self.run_async:
200192
return future
201193

202194
return self._event_loop.run_until_complete(future)
@@ -231,17 +223,19 @@ def send_over_websocket(self, *, payload: dict):
231223
Raises:
232224
SlackClientNotConnectedError: Websocket connection is closed.
233225
"""
226+
return asyncio.ensure_future(self._send_json(payload))
227+
228+
async def _send_json(self, payload):
234229
if self._websocket is None or self._event_loop is None:
235230
raise client_err.SlackClientNotConnectedError(
236231
"Websocket connection is closed."
237232
)
238233
if "id" not in payload:
239234
payload["id"] = self._next_msg_id()
240-
asyncio.ensure_future(
241-
self._websocket.send_str(json.dumps(payload)), loop=self._event_loop
242-
)
243235

244-
def ping(self):
236+
return await self._websocket.send_json(payload)
237+
238+
async def ping(self):
245239
"""Sends a ping message over the websocket to Slack.
246240
247241
Not all web browsers support the WebSocket ping spec,
@@ -251,9 +245,9 @@ def ping(self):
251245
SlackClientNotConnectedError: Websocket connection is closed.
252246
"""
253247
payload = {"id": self._next_msg_id(), "type": "ping"}
254-
self.send_over_websocket(payload=payload)
248+
await self._send_json(payload=payload)
255249

256-
def typing(self, *, channel: str):
250+
async def typing(self, *, channel: str):
257251
"""Sends a typing indicator to the specified channel.
258252
259253
This indicates that this app is currently
@@ -266,7 +260,7 @@ def typing(self, *, channel: str):
266260
SlackClientNotConnectedError: Websocket connection is closed.
267261
"""
268262
payload = {"id": self._next_msg_id(), "type": "typing", "channel": channel}
269-
self.send_over_websocket(payload=payload)
263+
await self._send_json(payload=payload)
270264

271265
@staticmethod
272266
def _validate_callback(callback):
@@ -307,9 +301,9 @@ def _next_msg_id(self):
307301
return self._last_message_id
308302

309303
async def _connect_and_read(self):
310-
"""Retreives and connects to Slack's RTM API.
304+
"""Retreives the WS url and connects to Slack's RTM API.
311305
312-
Makes an authenticated call to Slack's RTM API to retrieve
306+
Makes an authenticated call to Slack's Web API to retrieve
313307
a websocket URL. Then connects to the message server and
314308
reads event messages as they come in.
315309
@@ -338,15 +332,15 @@ async def _connect_and_read(self):
338332
) as websocket:
339333
self._logger.debug("The Websocket connection has been opened.")
340334
self._websocket = websocket
341-
self._dispatch_event(event="open", data=data)
335+
await self._dispatch_event(event="open", data=data)
342336
await self._read_messages()
343337
except (
344338
client_err.SlackClientNotConnectedError,
345339
client_err.SlackApiError,
346340
# TODO: Catch websocket exceptions thrown by aiohttp.
347341
) as exception:
348342
self._logger.debug(str(exception))
349-
self._dispatch_event(event="error", data=exception)
343+
await self._dispatch_event(event="error", data=exception)
350344
if self.auto_reconnect and not self._stopped:
351345
await self._wait_exponentially(exception)
352346
continue
@@ -366,11 +360,11 @@ async def _read_messages(self):
366360
if message.type == aiohttp.WSMsgType.TEXT:
367361
payload = message.json()
368362
event = payload.pop("type", "Unknown")
369-
self._dispatch_event(event, data=payload)
363+
await self._dispatch_event(event, data=payload)
370364
elif message.type == aiohttp.WSMsgType.ERROR:
371365
break
372366

373-
def _dispatch_event(self, event, data=None):
367+
async def _dispatch_event(self, event, data=None):
374368
"""Dispatches the event and executes any associated callbacks.
375369
376370
Note: To prevent the app from crashing due to callback errors. We
@@ -399,52 +393,19 @@ def _dispatch_event(self, event, data=None):
399393
# Don't run callbacks if client was stopped unless they're close/error callbacks.
400394
break
401395

402-
if self.run_async:
403-
self._execute_callback_async(callback, data)
396+
if inspect.iscoroutinefunction(callback):
397+
await callback(
398+
rtm_client=self, web_client=self._web_client, data=data
399+
)
404400
else:
405-
self._execute_callback(callback, data)
401+
callback(rtm_client=self, web_client=self._web_client, data=data)
406402
except Exception as err:
407403
name = callback.__name__
408404
module = callback.__module__
409405
msg = f"When calling '#{name}()' in the '{module}' module the following error was raised: {err}"
410406
self._logger.error(msg)
411407
raise
412408

413-
def _execute_callback_async(self, callback, data):
414-
"""Execute the callback asynchronously.
415-
416-
If the callback is not a coroutine, convert it.
417-
418-
Note: The WebClient passed into the callback is running in "async" mode.
419-
This means all responses will be futures.
420-
"""
421-
if asyncio.iscoroutine(callback):
422-
asyncio.ensure_future(
423-
callback(rtm_client=self, web_client=self._web_client, data=data)
424-
)
425-
else:
426-
asyncio.ensure_future(
427-
asyncio.coroutine(callback)(
428-
rtm_client=self, web_client=self._web_client, data=data
429-
)
430-
)
431-
432-
def _execute_callback(self, callback, data):
433-
"""Execute the callback in another thread. Wait for and return the results."""
434-
web_client = WebClient(
435-
token=self.token, base_url=self.base_url, ssl=self.ssl, proxy=self.proxy
436-
)
437-
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
438-
# Execute the callback on a separate thread,
439-
future = executor.submit(
440-
callback, rtm_client=self, web_client=web_client, data=data
441-
)
442-
443-
while future.running():
444-
pass
445-
446-
future.result()
447-
448409
async def _retreive_websocket_info(self):
449410
"""Retreives the WebSocket info from Slack.
450411
@@ -491,7 +452,7 @@ async def _wait_exponentially(self, exception, max_wait_time=300):
491452
"""Wait exponentially longer for each connection attempt.
492453
493454
Calculate the number of seconds to wait and then add
494-
a random number of milliseconds to avoid coincendental
455+
a random number of milliseconds to avoid coincidental
495456
synchronized client retries. Wait up to the maximium amount
496457
of wait time specified via 'max_wait_time'. However,
497458
if Slack returned how long to wait use that.
@@ -512,4 +473,6 @@ def _close_websocket(self):
512473
if callable(close_method):
513474
asyncio.ensure_future(close_method(), loop=self._event_loop)
514475
self._websocket = None
515-
self._dispatch_event(event="close")
476+
asyncio.ensure_future(
477+
self._dispatch_event(event="close"), loop=self._event_loop
478+
)

slack/web/base_client.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
import sys
77
import logging
88
import asyncio
9+
from typing import Optional, Union
910
import inspect
1011

1112
# ThirdParty Imports
1213
import aiohttp
14+
from aiohttp import FormData
1315

1416
# Internal Imports
1517
from slack.web.slack_response import SlackResponse
@@ -25,7 +27,7 @@ def __init__(
2527
token,
2628
base_url=BASE_URL,
2729
timeout=30,
28-
loop=None,
30+
loop: Optional[asyncio.AbstractEventLoop] = None,
2931
ssl=None,
3032
proxy=None,
3133
run_async=False,
@@ -43,19 +45,19 @@ def __init__(
4345

4446
def _set_event_loop(self):
4547
if self.run_async:
46-
self._event_loop = asyncio.get_event_loop()
48+
return asyncio.get_event_loop()
4749
else:
4850
loop = asyncio.new_event_loop()
4951
asyncio.set_event_loop(loop)
50-
self._event_loop = loop
52+
return loop
5153

5254
def api_call(
5355
self,
5456
api_method: str,
5557
*,
5658
http_verb: str = "POST",
5759
files: dict = None,
58-
data: dict = None,
60+
data: Union[dict, FormData] = None,
5961
params: dict = None,
6062
json: dict = None,
6163
):
@@ -99,15 +101,15 @@ def api_call(
99101
"Authorization": "Bearer {}".format(self.token),
100102
}
101103
if files is not None:
102-
form_data = aiohttp.FormData()
104+
form_data = FormData()
103105
for k, v in files.items():
104106
if isinstance(v, str):
105107
with open(v, "rb") as fd:
106108
form_data.add_field(k, fd)
107109
else:
108110
form_data.add_field(k, v)
109111

110-
if data is not None:
112+
if isinstance(data, dict):
111113
for k, v in data.items():
112114
form_data.add_field(k, str(v))
113115

@@ -123,7 +125,7 @@ def api_call(
123125
}
124126

125127
if self._event_loop is None:
126-
self._set_event_loop()
128+
self._event_loop = self._set_event_loop()
127129

128130
future = asyncio.ensure_future(
129131
self._send(http_verb=http_verb, api_url=api_url, req_args=req_args),
@@ -196,7 +198,6 @@ async def _request(self, *, http_verb, api_url, req_args):
196198
"""
197199
if self.session and not self.session.closed:
198200
async with self.session.request(http_verb, api_url, **req_args) as res:
199-
self._logger.debug("Ran the request with existing session.")
200201
return {
201202
"data": await res.json(),
202203
"headers": res.headers,
@@ -206,7 +207,6 @@ async def _request(self, *, http_verb, api_url, req_args):
206207
loop=self._event_loop, timeout=aiohttp.ClientTimeout(total=self.timeout)
207208
) as session:
208209
async with session.request(http_verb, api_url, **req_args) as res:
209-
self._logger.debug("Ran the request with a new session.")
210210
return {
211211
"data": await res.json(),
212212
"headers": res.headers,

slack/web/client.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -656,10 +656,9 @@ def files_upload(
656656

657657
if file:
658658
return self.api_call("files.upload", files={"file": file}, data=kwargs)
659-
elif content:
660-
data = kwargs.copy()
661-
data.update({"content": content})
662-
return self.api_call("files.upload", data=data)
659+
data = kwargs.copy()
660+
data.update({"content": content})
661+
return self.api_call("files.upload", data=data)
663662

664663
def groups_archive(self, *, channel: str, **kwargs) -> SlackResponse:
665664
"""Archives a private channel.

tests/rtm/test_rtm_client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import collections
33
import unittest
44
from unittest import mock
5+
import asyncio
56

67
# Internal Imports
78
import slack
@@ -65,7 +66,8 @@ def invalid_cb():
6566

6667
def test_send_over_websocket_raises_when_not_connected(self):
6768
with self.assertRaises(e.SlackClientError) as context:
68-
self.client.send_over_websocket(payload={})
69+
loop = asyncio.get_event_loop()
70+
loop.run_until_complete(self.client.send_over_websocket(payload={}))
6971

7072
expected_error = "Websocket connection is closed."
7173
error = str(context.exception)

tests/rtm/test_rtm_client_functional.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ async def websocket_handler(self, request):
4141
async for msg in ws:
4242
await ws.send_json({"type": "message", "message_sent": msg.json()})
4343
finally:
44-
request.app["websockets"].discard(ws)
44+
request.app["websockets"].remove(ws)
4545
return ws
4646

4747
async def on_shutdown(self, app):
@@ -170,9 +170,9 @@ def check_message(**payload):
170170

171171
def test_ping_sends_expected_message(self, mock_rtm_response):
172172
@slack.RTMClient.run_on(event="open")
173-
def ping_message(**payload):
173+
async def ping_message(**payload):
174174
rtm_client = payload["rtm_client"]
175-
rtm_client.ping()
175+
await rtm_client.ping()
176176

177177
@slack.RTMClient.run_on(event="message")
178178
def check_message(**payload):
@@ -185,9 +185,9 @@ def check_message(**payload):
185185

186186
def test_typing_sends_expected_message(self, mock_rtm_response):
187187
@slack.RTMClient.run_on(event="open")
188-
def typing_message(**payload):
188+
async def typing_message(**payload):
189189
rtm_client = payload["rtm_client"]
190-
rtm_client.typing(channel="C01234567")
190+
await rtm_client.typing(channel="C01234567")
191191

192192
@slack.RTMClient.run_on(event="message")
193193
def check_message(**payload):

0 commit comments

Comments
 (0)