Skip to content

Commit 7c460c1

Browse files
AndrewJGautAndrewJGaut
andauthored
Use websockets for json communication (#4490)
* first commit * Fix bugs * Fix more bugsz * slight change * minor bug fixes * It's working quite well now * Update to solve that last bug. I think it should work now * bug fix that was causing messages to be send to tall worker threads * Fix unittests issue and fix formatting * Minor change to increase robustness of message sending * Some more formatting changes * Add in worker auth and use wss rather than ws for websocket URLs to make it secure. Not yet tested b/c on plane and internet isn't good enough to run CodaLab (and download images and such) * Added in server auth with secret. Aslo still need to test (still on plane) * Fixed issues and got auth working. Now, I'll work on returning error codes so that we have proper tests * Add in tests for authentication functionality (for worker and server) * Slight cleanup to data sending code * Adding in ssl certification * Add in SSL stuff for worker; still testing on dev * Revert "Add in SSL stuff for worker; still testing on dev" This reverts commit 4eb3d7b. * Revert "Adding in ssl certification" This reverts commit cbd1505. * Fixed formatting * Very minor formatting change to ignore one line for MyPy * Another minor formatting change * add exponential backoff to see if that fixes dev issue * Added code to actually detect worker disconnections now so that some websockets will be invalidated * Make sockets get looped over in random order to help distribute load * Clean up ws-server and delete a Dataclass I was using previously * a few more minor changes * Make websocket locks more robust and improve error messaging * More permissible retries in case of other errors (e.g. like 1013). With exponential backoff, it's still not very aggressive at all * minor change to have a different error message if worker doesn't yet have any sockets open with ws-server * Rename send_json and send_json_message_with_sock * Rearrange worker_model to minimize diff * Final changes * Fix formatting * Minor changes to get auth working again and to robustly return an error if the client tries to connect to an invalid path. * Merge in master and make some minor changes --------- Co-authored-by: AndrewJGaut <stanford@DNa1c4533.SUNet>
1 parent 528db0f commit 7c460c1

14 files changed

Lines changed: 349 additions & 166 deletions

File tree

codalab/bin/ws_server.py

Lines changed: 119 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,75 +1,157 @@
1-
# Main entry point for CodaLab cl-ws-server.
1+
# Main entry point to the CodaLab Websocket Server.
2+
# The Websocket Server handles communication between the REST server and workers.
23
import argparse
34
import asyncio
5+
from collections import defaultdict
46
import logging
7+
import os
8+
import random
59
import re
10+
import time
611
from typing import Any, Dict
712
import websockets
13+
import threading
814

9-
logger = logging.getLogger(__name__)
10-
logger.setLevel(logging.WARNING)
11-
logging.basicConfig(format='%(asctime)s %(message)s %(pathname)s %(lineno)d')
12-
13-
worker_to_ws: Dict[str, Any] = {}
15+
from codalab.lib.codalab_manager import CodaLabManager
1416

1517

16-
async def rest_server_handler(websocket):
17-
"""Handles routes of the form: /main. This route is called by the rest-server
18-
whenever a worker needs to be pinged (to ask it to check in). The body of the
19-
message is the worker id to ping. This function sends a message to the worker
20-
with that worker id through an appropriate websocket.
18+
class TimedLock:
19+
"""A lock that gets automatically released after timeout_seconds.
2120
"""
22-
# Got a message from the rest server.
23-
worker_id = await websocket.recv()
24-
logger.warning(f"Got a message from the rest server, to ping worker: {worker_id}.")
2521

26-
try:
27-
worker_ws = worker_to_ws[worker_id]
28-
await worker_ws.send(worker_id)
29-
except KeyError:
30-
logger.error(f"Websocket not found for worker: {worker_id}")
22+
def __init__(self, timeout_seconds: float = 60):
23+
self._lock = threading.Lock()
24+
self._time_since_locked: float
25+
self._timeout: float = timeout_seconds
3126

27+
def acquire(self, blocking=True, timeout=-1):
28+
acquired = self._lock.acquire(blocking, timeout)
29+
if acquired:
30+
self._time_since_locked = time.time()
31+
return acquired
32+
33+
def locked(self):
34+
return self._lock.locked()
35+
36+
def release(self):
37+
self._lock.release()
38+
39+
def timeout(self):
40+
return time.time() - self._time_since_locked > self._timeout
41+
42+
def release_if_timeout(self):
43+
if self.locked() and self.timeout():
44+
self.release()
3245

33-
async def worker_handler(websocket, worker_id):
34-
"""Handles routes of the form: /worker/{id}. This route is called when
35-
a worker first connects to the ws-server, creating a connection that can
36-
be used to ask the worker to check-in later.
37-
"""
38-
# runs on worker connect
39-
worker_to_ws[worker_id] = websocket
40-
logger.warning(f"Connected to worker {worker_id}!")
4146

47+
worker_to_ws: Dict[str, Dict[str, Any]] = defaultdict(
48+
dict
49+
) # Maps worker ID to socket ID to websocket
50+
worker_to_lock: Dict[str, Dict[str, TimedLock]] = defaultdict(
51+
dict
52+
) # Maps worker ID to socket ID to lock
53+
ACK = b'a'
54+
logger = logging.getLogger(__name__)
55+
manager = CodaLabManager()
56+
bundle_model = manager.model()
57+
worker_model = manager.worker_model()
58+
server_secret = os.getenv("CODALAB_SERVER_SECRET")
59+
60+
61+
async def send_to_worker_handler(server_websocket, worker_id):
62+
"""Handles routes of the form: /send_to_worker/{worker_id}. This route is called by
63+
the rest-server or bundle-manager when either wants to send a message/stream to the worker.
64+
"""
65+
# Authenticate server.
66+
received_secret = await server_websocket.recv()
67+
if received_secret != server_secret:
68+
logger.warning("Server unable to authenticate.")
69+
await server_websocket.close(1008, "Server unable to authenticate.")
70+
return
71+
72+
# Check if any websockets available
73+
if worker_id not in worker_to_ws or len(worker_to_ws[worker_id]) == 0:
74+
logger.warning(f"No websockets currently available for worker {worker_id}")
75+
await server_websocket.close(
76+
1011, f"No websockets currently available for worker {worker_id}"
77+
)
78+
return
79+
80+
# Send message from server to worker.
81+
for socket_id, worker_websocket in random.sample(
82+
worker_to_ws[worker_id].items(), len(worker_to_ws[worker_id])
83+
):
84+
if worker_to_lock[worker_id][socket_id].acquire(blocking=False):
85+
data = await server_websocket.recv()
86+
await worker_websocket.send(data)
87+
await server_websocket.send(ACK)
88+
worker_to_lock[worker_id][socket_id].release()
89+
return
90+
91+
logger.warning(f"All websockets for worker {worker_id} are currently busy.")
92+
await server_websocket.close(1011, f"All websockets for worker {worker_id} are currently busy.")
93+
94+
95+
async def worker_connection_handler(websocket: Any, worker_id: str, socket_id: str) -> None:
96+
"""Handles routes of the form: /worker_connect/{worker_id}/{socket_id}.
97+
This route is called when a worker first connects to the ws-server, creating
98+
a connection that can be used to ask the worker to check-in later.
99+
"""
100+
# Authenticate worker.
101+
access_token = await websocket.recv()
102+
user_id = worker_model.get_user_id_for_worker(worker_id=worker_id)
103+
authenticated = bundle_model.access_token_exists_for_user(
104+
'codalab_worker_client', user_id, access_token # TODO: Avoid hard-coding this if possible.
105+
)
106+
logger.error(f"AUTHENTICATED: {authenticated}")
107+
if not authenticated:
108+
logger.warning(f"Thread {socket_id} for worker {worker_id} unable to authenticate.")
109+
await websocket.close(
110+
1008, f"Thread {socket_id} for worker {worker_id} unable to authenticate."
111+
)
112+
return
113+
114+
# Establish a connection with worker and keep it alive.
115+
worker_to_ws[worker_id][socket_id] = websocket
116+
worker_to_lock[worker_id][socket_id] = TimedLock()
117+
logger.warning(f"Worker {worker_id} connected; has {len(worker_to_ws[worker_id])} connections")
42118
while True:
43119
try:
44120
await asyncio.wait_for(websocket.recv(), timeout=60)
121+
worker_to_lock[worker_id][
122+
socket_id
123+
].release_if_timeout() # Failsafe in case not released
45124
except asyncio.futures.TimeoutError:
46125
pass
47126
except websockets.exceptions.ConnectionClosed:
48-
logger.error(f"Socket connection closed with worker {worker_id}.")
127+
logger.warning(f"Socket connection closed with worker {worker_id}.")
49128
break
50-
51-
52-
ROUTES = (
53-
(r'^.*/main$', rest_server_handler),
54-
(r'^.*/worker/(.+)$', worker_handler),
55-
)
129+
del worker_to_ws[worker_id][socket_id]
130+
del worker_to_lock[worker_id][socket_id]
131+
logger.warning(f"Worker {worker_id} now has {len(worker_to_ws[worker_id])} connections")
56132

57133

58134
async def ws_handler(websocket, *args):
59135
"""Handler for websocket connections. Routes websockets to the appropriate
60136
route handler defined in ROUTES."""
61-
logger.warning(f"websocket handler, path: {websocket.path}.")
137+
ROUTES = (
138+
(r'^.*/send_to_worker/(.+)$', send_to_worker_handler),
139+
(r'^.*/worker_connect/(.+)/(.+)$', worker_connection_handler),
140+
)
141+
logger.info(f"websocket handler, path: {websocket.path}.")
62142
for (pattern, handler) in ROUTES:
63143
match = re.match(pattern, websocket.path)
64144
if match:
65145
return await handler(websocket, *match.groups())
66-
assert False
146+
return await websocket.close(1011, f"Path {websocket.path} is not valid.")
67147

68148

69149
async def async_main():
70150
"""Main function that runs the websocket server."""
71151
parser = argparse.ArgumentParser()
72-
parser.add_argument('--port', help='Port to run the server on.', type=int, required=True)
152+
parser.add_argument(
153+
'--port', help='Port to run the server on.', type=int, required=False, default=2901
154+
)
73155
args = parser.parse_args()
74156
logging.debug(f"Running ws-server on 0.0.0.0:{args.port}")
75157
async with websockets.serve(ws_handler, "0.0.0.0", args.port):

codalab/lib/codalab_manager.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,11 @@ def ws_server(self):
248248
ws_port = self.config['ws-server']['ws_port']
249249
return f"ws://ws-server:{ws_port}"
250250

251+
@property # type: ignore
252+
@cached
253+
def server_secret(self):
254+
return os.getenv("CODALAB_SERVER_SECRET")
255+
251256
@property # type: ignore
252257
@cached
253258
def worker_socket_dir(self):
@@ -380,7 +385,9 @@ def model(self):
380385

381386
@cached
382387
def worker_model(self):
383-
return WorkerModel(self.model().engine, self.worker_socket_dir, self.ws_server)
388+
return WorkerModel(
389+
self.model().engine, self.worker_socket_dir, self.ws_server, self.server_secret
390+
)
384391

385392
@cached
386393
def upload_manager(self):

codalab/lib/download_manager.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def _get_target_info_within_bundle(self, target, depth):
137137
read_args = {'type': 'get_target_info', 'depth': depth}
138138
self._send_read_message(worker, response_socket_id, target, read_args)
139139
with closing(self._worker_model.start_listening(response_socket_id)) as sock:
140-
result = self._worker_model.get_json_message(sock, 60)
140+
result = self._worker_model.recv_json_message_with_unix_socket(sock, 60)
141141
if result is None: # dead workers are a fact of life now
142142
logging.info('Unable to reach worker, bundle state {}'.format(bundle_state))
143143
raise NotFoundError(
@@ -365,9 +365,7 @@ def _send_read_message(self, worker, response_socket_id, target, read_args):
365365
'path': target.subpath,
366366
'read_args': read_args,
367367
}
368-
if not self._worker_model.send_json_message(
369-
worker['socket_id'], worker['worker_id'], message, 60
370-
): # dead workers are a fact of life now
368+
if not self._worker_model.send_json_message(message, worker['worker_id']):
371369
logging.info('Unable to reach worker')
372370

373371
def _send_netcat_message(self, worker, response_socket_id, uuid, port, message):
@@ -378,21 +376,19 @@ def _send_netcat_message(self, worker, response_socket_id, uuid, port, message):
378376
'port': port,
379377
'message': message,
380378
}
381-
if not self._worker_model.send_json_message(
382-
worker['socket_id'], worker['worker_id'], message, 60
383-
): # dead workers are a fact of life now
379+
if not self._worker_model.send_json_message(message, worker['worker_id']):
384380
logging.info('Unable to reach worker')
385381

386382
def _get_read_response_stream(self, response_socket_id):
387383
with closing(self._worker_model.start_listening(response_socket_id)) as sock:
388-
header_message = self._worker_model.get_json_message(sock, 60)
384+
header_message = self._worker_model.recv_json_message_with_unix_socket(sock, 60)
389385
precondition(header_message is not None, 'Unable to reach worker')
390386
if 'error_code' in header_message:
391387
raise http_error_to_exception(
392388
header_message['error_code'], header_message['error_message']
393389
)
394390

395-
fileobj = self._worker_model.get_stream(sock, 60)
391+
fileobj = self._worker_model.recv_stream(sock, 60)
396392
precondition(fileobj is not None, 'Unable to reach worker')
397393
return fileobj
398394

codalab/model/bundle_model.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2881,7 +2881,7 @@ def get_oauth2_token(self, access_token=None, refresh_token=None):
28812881

28822882
return OAuth2Token(self, **row)
28832883

2884-
def find_oauth2_token(self, client_id, user_id, expires_after):
2884+
def find_oauth2_token(self, client_id, user_id, expires_after=datetime.datetime.utcnow()):
28852885
with self.engine.begin() as connection:
28862886
row = connection.execute(
28872887
select([oauth2_token])
@@ -2900,6 +2900,25 @@ def find_oauth2_token(self, client_id, user_id, expires_after):
29002900

29012901
return OAuth2Token(self, **row)
29022902

2903+
def access_token_exists_for_user(self, client_id: str, user_id: str, access_token: str) -> bool:
2904+
"""Check that the provided access_token exists in the database for the provided user_id.
2905+
"""
2906+
with self.engine.begin() as connection:
2907+
row = connection.execute(
2908+
select([oauth2_token])
2909+
.where(
2910+
and_(
2911+
oauth2_token.c.client_id == client_id,
2912+
oauth2_token.c.user_id == user_id,
2913+
oauth2_token.c.access_token == access_token,
2914+
oauth2_token.c.expires > datetime.datetime.utcnow(),
2915+
)
2916+
)
2917+
.limit(1)
2918+
).fetchone()
2919+
2920+
return row is not None
2921+
29032922
def save_oauth2_token(self, token):
29042923
with self.engine.begin() as connection:
29052924
result = connection.execute(oauth2_token.insert().values(token.columns))

0 commit comments

Comments
 (0)