|
1 | | -# Main entry point for CodaLab cl-ws-server. |
| 1 | +# Main entry point to the CodaLab Websocket Server. |
| 2 | +# The Websocket Server handles communication between the REST server and workers. |
2 | 3 | import argparse |
3 | 4 | import asyncio |
| 5 | +from collections import defaultdict |
4 | 6 | import logging |
| 7 | +import os |
| 8 | +import random |
5 | 9 | import re |
| 10 | +import time |
6 | 11 | from typing import Any, Dict |
7 | 12 | import websockets |
| 13 | +import threading |
8 | 14 |
|
9 | | -logger = logging.getLogger(__name__) |
10 | | -logger.setLevel(logging.WARNING) |
11 | | -logging.basicConfig(format='%(asctime)s %(message)s %(pathname)s %(lineno)d') |
12 | | - |
13 | | -worker_to_ws: Dict[str, Any] = {} |
| 15 | +from codalab.lib.codalab_manager import CodaLabManager |
14 | 16 |
|
15 | 17 |
|
16 | | -async def rest_server_handler(websocket): |
17 | | - """Handles routes of the form: /main. This route is called by the rest-server |
18 | | - whenever a worker needs to be pinged (to ask it to check in). The body of the |
19 | | - message is the worker id to ping. This function sends a message to the worker |
20 | | - with that worker id through an appropriate websocket. |
class TimedLock:
    """A lock that can be force-released once it has been held longer than a timeout.

    Wraps a ``threading.Lock`` and records when it was last acquired. Nothing
    releases the lock in the background: a caller has to invoke
    ``release_if_timeout()`` periodically for an over-held lock to be freed.
    """

    def __init__(self, timeout_seconds: float = 60):
        """
        :param timeout_seconds: how long the lock may stay held before
            ``release_if_timeout()`` is allowed to release it.
        """
        self._lock = threading.Lock()
        # Start at +inf so that a lock that was never acquired reports "not timed
        # out" instead of raising AttributeError on an unset attribute.
        self._time_since_locked: float = float("inf")
        self._timeout: float = timeout_seconds

    def acquire(self, blocking=True, timeout=-1):
        """Acquire the underlying lock; same arguments and return value as
        ``threading.Lock.acquire``. Records the acquisition time on success."""
        acquired = self._lock.acquire(blocking, timeout)
        if acquired:
            # Monotonic clock: elapsed-time checks must not be affected by
            # wall-clock adjustments.
            self._time_since_locked = time.monotonic()
        return acquired

    def locked(self):
        """Return True if the underlying lock is currently held."""
        return self._lock.locked()

    def release(self):
        """Release the underlying lock (raises RuntimeError if it is not held)."""
        self._lock.release()

    def timeout(self):
        """Return True if more than the configured timeout has elapsed since the
        most recent successful ``acquire()``. Returns False if the lock has never
        been acquired."""
        return time.monotonic() - self._time_since_locked > self._timeout

    def release_if_timeout(self):
        """Release the lock if it is held and has exceeded its timeout.

        :return: True if the lock was released by this call, False otherwise.
        """
        if self.locked() and self.timeout():
            self.release()
            return True
        return False
32 | 45 |
|
33 | | -async def worker_handler(websocket, worker_id): |
34 | | - """Handles routes of the form: /worker/{id}. This route is called when |
35 | | - a worker first connects to the ws-server, creating a connection that can |
36 | | - be used to ask the worker to check-in later. |
37 | | - """ |
38 | | - # runs on worker connect |
39 | | - worker_to_ws[worker_id] = websocket |
40 | | - logger.warning(f"Connected to worker {worker_id}!") |
41 | 46 |
|
# Registry of live worker connections: worker ID -> socket ID -> websocket.
worker_to_ws: Dict[str, Dict[str, Any]] = defaultdict(dict)

# One lock per registered connection: worker ID -> socket ID -> lock.
worker_to_lock: Dict[str, Dict[str, TimedLock]] = defaultdict(dict)

# Acknowledgement sent back to the server once its message has been forwarded.
ACK = b'a'

logger = logging.getLogger(__name__)

# Models used to authenticate incoming worker connections.
manager = CodaLabManager()
bundle_model = manager.model()
worker_model = manager.worker_model()

# Shared secret that servers must present before sending to a worker
# (None when the environment variable is not set).
server_secret = os.environ.get("CODALAB_SERVER_SECRET")
| 59 | + |
| 60 | + |
async def send_to_worker_handler(server_websocket, worker_id):
    """Handles routes of the form: /send_to_worker/{worker_id}. This route is called by
    the rest-server or bundle-manager when either wants to send a message/stream to the worker.

    Protocol, as seen from the server side of the connection:
      1. The first message must be the shared server secret.
      2. The second message is the payload; it is forwarded over one idle
         connection of the target worker.
      3. ACK is sent back once the payload has been forwarded.
    The server connection is closed with code 1008 on failed authentication and
    with code 1011 when the worker has no registered or no idle connection.

    :param server_websocket: websocket connection opened by the server.
    :param worker_id: ID of the worker the payload is addressed to.
    """
    import hmac  # Local import: only needed for the constant-time secret check below.

    # Authenticate server. compare_digest avoids leaking how much of the secret
    # matched through response timing. An unset secret rejects every caller.
    received_secret = await server_websocket.recv()
    authenticated = (
        server_secret is not None
        and isinstance(received_secret, str)
        and hmac.compare_digest(received_secret.encode("utf-8"), server_secret.encode("utf-8"))
    )
    if not authenticated:
        logger.warning("Server unable to authenticate.")
        await server_websocket.close(1008, "Server unable to authenticate.")
        return

    # Check if any websockets available. Using .get() avoids inserting an empty
    # entry into the defaultdict for unknown worker IDs.
    sockets = worker_to_ws.get(worker_id)
    if not sockets:
        logger.warning(f"No websockets currently available for worker {worker_id}")
        await server_websocket.close(
            1011, f"No websockets currently available for worker {worker_id}"
        )
        return

    # Try the worker's connections in random order. The items are copied into a
    # list first: random.sample() on a dict view is rejected by newer Python
    # versions, and the registry may change while this coroutine is suspended.
    candidates = list(sockets.items())
    random.shuffle(candidates)
    locks = worker_to_lock.get(worker_id, {})
    for socket_id, worker_websocket in candidates:
        # Keep a direct reference to the lock so it can still be released if the
        # worker disconnects (and its registry entry is deleted) mid-transfer.
        lock = locks.get(socket_id)
        if lock is None or not lock.acquire(blocking=False):
            continue
        try:
            # Send message from server to worker.
            data = await server_websocket.recv()
            await worker_websocket.send(data)
            await server_websocket.send(ACK)
        finally:
            # Always release, even if recv/send raised. The lock may already have
            # been force-released by the timeout failsafe, hence the check.
            if lock.locked():
                lock.release()
        return

    logger.warning(f"All websockets for worker {worker_id} are currently busy.")
    await server_websocket.close(1011, f"All websockets for worker {worker_id} are currently busy.")
| 93 | + |
| 94 | + |
async def worker_connection_handler(websocket: Any, worker_id: str, socket_id: str) -> None:
    """Handles routes of the form: /worker_connect/{worker_id}/{socket_id}.
    This route is called when a worker connects to the ws-server. After the worker
    authenticates, the connection is registered so that send_to_worker_handler can
    forward messages over it, and it is kept open until the worker disconnects.

    The first message on the connection must be the worker's access token; the
    connection is closed with code 1008 if the token is not accepted.

    :param websocket: websocket connection opened by the worker.
    :param worker_id: ID of the connecting worker.
    :param socket_id: ID distinguishing this connection from the worker's others.
    """
    # Authenticate worker.
    access_token = await websocket.recv()
    user_id = worker_model.get_user_id_for_worker(worker_id=worker_id)
    authenticated = bundle_model.access_token_exists_for_user(
        'codalab_worker_client', user_id, access_token  # TODO: Avoid hard-coding this if possible.
    )
    logger.debug(f"Socket {socket_id} for worker {worker_id} authenticated: {authenticated}")
    if not authenticated:
        logger.warning(f"Thread {socket_id} for worker {worker_id} unable to authenticate.")
        await websocket.close(
            1008, f"Thread {socket_id} for worker {worker_id} unable to authenticate."
        )
        return

    # Establish a connection with worker and keep it alive.
    lock = TimedLock()
    worker_to_ws[worker_id][socket_id] = websocket
    worker_to_lock[worker_id][socket_id] = lock
    logger.warning(f"Worker {worker_id} connected; has {len(worker_to_ws[worker_id])} connections")
    try:
        while True:
            try:
                await asyncio.wait_for(websocket.recv(), timeout=60)
            except asyncio.TimeoutError:
                pass
            except websockets.exceptions.ConnectionClosed:
                logger.warning(f"Socket connection closed with worker {worker_id}.")
                break
            # Failsafe in case not released. Runs on every wake-up, including the
            # idle timeout path, so a stuck lock is freed even if the worker
            # never sends anything.
            lock.release_if_timeout()
    finally:
        # Unregister on every exit path (including unexpected errors and
        # cancellation). Only remove the entries if they still belong to this
        # connection; a newer connection may have reused the same socket_id.
        if worker_to_ws.get(worker_id, {}).get(socket_id) is websocket:
            del worker_to_ws[worker_id][socket_id]
            worker_to_lock.get(worker_id, {}).pop(socket_id, None)
        remaining = len(worker_to_ws.get(worker_id, {}))
        if remaining == 0:
            # Drop empty per-worker dicts so the registries do not grow forever.
            worker_to_ws.pop(worker_id, None)
            worker_to_lock.pop(worker_id, None)
        logger.warning(f"Worker {worker_id} now has {remaining} connections")
56 | 132 |
|
57 | 133 |
|
async def ws_handler(websocket, *args):
    """Entry point for every incoming websocket connection.

    Matches the connection's path against the known route patterns and hands
    the websocket, plus the groups captured from the path, to the first handler
    whose pattern matches. A connection whose path matches no route is closed
    with code 1011.
    """
    route_table = (
        (r'^.*/send_to_worker/(.+)$', send_to_worker_handler),
        (r'^.*/worker_connect/(.+)/(.+)$', worker_connection_handler),
    )
    logger.info(f"websocket handler, path: {websocket.path}.")
    for route_regex, route_handler in route_table:
        found = re.match(route_regex, websocket.path)
        if found is None:
            continue
        # The captured path segments become the handler's positional arguments.
        return await route_handler(websocket, *found.groups())
    # No route matched.
    return await websocket.close(1011, f"Path {websocket.path} is not valid.")
67 | 147 |
|
68 | 148 |
|
69 | 149 | async def async_main(): |
70 | 150 | """Main function that runs the websocket server.""" |
71 | 151 | parser = argparse.ArgumentParser() |
72 | | - parser.add_argument('--port', help='Port to run the server on.', type=int, required=True) |
| 152 | + parser.add_argument( |
| 153 | + '--port', help='Port to run the server on.', type=int, required=False, default=2901 |
| 154 | + ) |
73 | 155 | args = parser.parse_args() |
74 | 156 | logging.debug(f"Running ws-server on 0.0.0.0:{args.port}") |
75 | 157 | async with websockets.serve(ws_handler, "0.0.0.0", args.port): |
|
0 commit comments