-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy path__ws_wrapper.py
More file actions
348 lines (302 loc) · 12.7 KB
/
Copy path__ws_wrapper.py
File metadata and controls
348 lines (302 loc) · 12.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
import json
import queue
import threading
from collections.abc import Callable, Generator
from enum import Enum
from mistapi import APISession
from mistapi.__api_response import APIResponse as _APIResponse
from mistapi.__logger import logger as LOGGER
class TimerAction(Enum):
    """
    Operations understood by ``WebSocketWrapper._timeout_handler``.

    Every member's value is simply its name in lowercase.
    """

    # Arm a timer for the requested Timer type.
    START = "start"
    # Cancel the requested timer if one is running.
    STOP = "stop"
    # Cancel the requested timer (if running), then arm a fresh one.
    RESET = "reset"
class Timer(Enum):
    """
    Names the individual timers tracked by ``WebSocketWrapper``.

    A member's value is also the key used in ``WebSocketWrapper.timers``.
    """

    # Inactivity timer, restarted as accepted messages come in.
    TIMEOUT = "timeout"
    # Deadline for the first accepted message once the channel is subscribed.
    FIRST_MESSAGE_TIMEOUT = "first_message_timeout"
    # Overall limit, armed when the connection opens.
    MAX_DURATION = "max_duration"
class UtilResponse:
    """
    Encapsulates the response from device utility functions.

    Returned immediately by tool functions. When a WebSocket stream is
    involved, data is collected in the background. Use ``receive()``,
    ``wait()``, or the ``on_message`` callback to consume results.

    USAGE PATTERNS
    --------------
    Callback style (on_message passed at call time)::

        response = ex.ping(session, site_id, device_id, host="8.8.8.8",
                           on_message=lambda msg: print(msg))
        do_other_work()
        response.wait()
        print(response.ws_data)

    Generator style::

        response = ex.ping(session, site_id, device_id, host="8.8.8.8")
        for msg in response.receive():
            print(msg)

    Context manager::

        with ex.ping(session, site_id, device_id, host="8.8.8.8") as response:
            for msg in response.receive():
                print(msg)

    Async await::

        response = ex.ping(session, site_id, device_id, host="8.8.8.8")
        await response
        print(response.ws_data)
    """

    def __init__(
        self,
        api_response: _APIResponse,
    ) -> None:
        # Response of the REST call that triggered the utility function.
        self.trigger_api_response = api_response
        # True once a WebSocket stream has been attached to this response.
        self.ws_required: bool = False
        # Extracted payloads / raw events. A stream producer may swap these
        # for live lists that it keeps appending to while data streams in.
        self.ws_data: list[str] = []
        self.ws_raw_events: list[str] = []
        # Hand-off to receive(); ``None`` is the end-of-stream sentinel.
        self._queue: queue.Queue[str | None] = queue.Queue()
        # Set when collection is complete. Starts set because, until a stream
        # is attached, there is nothing to wait for.
        self._closed = threading.Event()
        self._closed.set()  # default: done (no WS to wait for)
        # Tears the stream down early; None while no stream is attached.
        self._disconnect_fn: Callable[[], None] | None = None

    @property
    def done(self) -> bool:
        """True if data collection is complete (or no WS was needed)."""
        return self._closed.is_set()

    def wait(self, timeout: float | None = None) -> "UtilResponse":
        """
        Block until data collection is complete. Returns self.

        PARAMS
        -----------
        timeout : float, optional
            Maximum seconds to wait; ``None`` waits indefinitely. Returning does
            not imply completion when a timeout is given; check ``done``.
        """
        self._closed.wait(timeout=timeout)
        return self

    def receive(self) -> Generator[str, None, None]:
        """
        Blocking generator that yields each processed message as it arrives.

        Mirrors ``_MistWebsocket.receive()``. Exits cleanly when the
        WebSocket connection closes or ``disconnect()`` is called. Once
        collection is complete, any still-queued messages are yielded and the
        generator returns immediately instead of waiting out a poll interval.
        """
        while True:
            if self._closed.is_set():
                # Completion is signalled only after the sentinel is queued, so
                # nothing more can arrive: drain without blocking, then stop.
                try:
                    item = self._queue.get_nowait()
                except queue.Empty:
                    break
            else:
                # Poll with a timeout so completion is noticed even if the
                # sentinel was already consumed by another receive() call.
                try:
                    item = self._queue.get(timeout=1)
                except queue.Empty:
                    continue
            if item is None:
                break
            yield item

    def disconnect(self) -> None:
        """Stop the WebSocket connection early (no-op when no stream is attached)."""
        if self._disconnect_fn:
            self._disconnect_fn()

    def __enter__(self) -> "UtilResponse":
        return self

    def __exit__(self, *args) -> None:
        # Leaving the ``with`` block always tears down the stream.
        self.disconnect()

    def __await__(self):
        """Allow ``result = await response`` in async contexts."""
        import asyncio

        async def _await_impl():
            # Event.wait() blocks, so run it off the event loop thread.
            await asyncio.to_thread(self._closed.wait)
            return self

        return _await_impl().__await__()
class WebSocketWrapper:
    """
    A wrapper class for managing WebSocket connections and events.

    This class provides a simplified interface for connecting to WebSocket
    channels, handling messages, and managing connection timeouts.

    Three ``threading.Timer`` guards are kept in ``self.timers`` (keyed by
    ``Timer`` value); each one calls ``ws.disconnect`` when it fires:

    * ``Timer.MAX_DURATION`` - started when the socket opens.
    * ``Timer.FIRST_MESSAGE_TIMEOUT`` - started on ``channel_subscribed`` and
      cancelled by the first accepted message.
    * ``Timer.TIMEOUT`` - restarted after every accepted message, so an idle
      stream gets closed.

    Incoming messages are filtered against the ``session`` / ``id`` values
    found in the trigger API response (see ``_extract_session_id``).
    """

    def __init__(
        self,
        apissession: APISession,
        util_response: UtilResponse,
        timeout: int = 10,
        max_duration: int = 60,
        on_message: Callable[[dict], None] | None = None,
        first_message_timeout: int = 30,
    ) -> None:
        """
        PARAMS
        -----------
        apissession : APISession
            Stored as ``self.apissession``.
        util_response : UtilResponse
            Receives the streamed data. When its
            ``trigger_api_response.data`` is a dict, the ``session`` and
            ``id`` values are used to filter incoming messages.
        timeout : int, default 10
            Seconds without an accepted message before disconnecting.
        max_duration : int, default 60
            Seconds after the socket opens before disconnecting.
        on_message : callable, optional
            Invoked with each extracted payload.
        first_message_timeout : int, default 30
            Seconds after channel subscription to wait for the first accepted
            message before disconnecting.
        """
        self.apissession = apissession
        self.util_response = util_response
        self.timers = {
            Timer.TIMEOUT.value: {
                "thread": None,
                "duration": timeout,
            },
            Timer.FIRST_MESSAGE_TIMEOUT.value: {
                "thread": None,
                "duration": first_message_timeout,
            },
            Timer.MAX_DURATION.value: {
                "thread": None,
                "duration": max_duration,
            },
        }
        self.received_messages = 0
        self.data = []
        self.raw_events = []
        self.ws = None
        self.session_id: str | None = None
        self.capture_id: str | None = None
        self._on_message_cb = on_message
        LOGGER.debug(
            "trigger response: %s", self.util_response.trigger_api_response.data
        )
        if self.util_response.trigger_api_response.data and isinstance(
            self.util_response.trigger_api_response.data, dict
        ):
            self.session_id = self.util_response.trigger_api_response.data.get(
                "session", None
            )
            self.capture_id = self.util_response.trigger_api_response.data.get(
                "id", None
            )
        LOGGER.debug("Extracted session_id: %s", self.session_id)
        LOGGER.debug("Extracted capture_id: %s", self.capture_id)

    def _on_open(self):
        """Start the max-duration timer once the connection is open."""
        LOGGER.info("WebSocket connection opened")
        self._timeout_handler(Timer.MAX_DURATION, TimerAction.START)

    def _on_close(self, code, msg):
        """Cancel all timers and signal end-of-stream to the UtilResponse."""
        LOGGER.info("WebSocket closed: %s - %s", code, msg)
        self._stop_all_timers()
        self._finish()

    def _finish(self) -> None:
        """Queue the end-of-stream sentinel, then mark the UtilResponse done."""
        # Order matters: receive() relies on the sentinel being queued before
        # _closed is set.
        self.util_response._queue.put(None)  # sentinel for receive()
        self.util_response._closed.set()  # signal completion

    ##########################################################################
    ## Helper methods for managing timers

    def _cancel_timer(self, timer_type: Timer) -> bool:
        """Cancel the given timer if it is running; return True if one was."""
        timer_info = self.timers[timer_type.value]
        thread = timer_info["thread"]
        if thread is None:
            return False
        LOGGER.debug("Stopping %s timer", timer_type.value)
        thread.cancel()
        timer_info["thread"] = None
        return True

    def _timeout_handler(self, timer_type: Timer, action: TimerAction):
        """
        Start, stop or reset one of the connection timers.

        Any running timer of ``timer_type`` is cancelled first for every
        action, so a repeated START cannot orphan a timer that would still
        fire ``ws.disconnect`` later.
        """
        timer_info = self.timers[timer_type.value]
        was_running = self._cancel_timer(timer_type)
        if action == TimerAction.STOP:
            if not was_running:
                # Only warn when explicitly stopping (not resetting) a non-active timer
                LOGGER.warning("%s timer is not active to stop", timer_type.value)
            return
        # START or RESET: arm a fresh timer.
        if not self.ws:
            LOGGER.warning(
                "WebSocket is not available to start %s timer", timer_type.value
            )
            return
        duration = timer_info["duration"]
        LOGGER.debug(
            "Starting %s timer with duration: %s seconds",
            timer_type.value,
            duration,
        )
        timer_info["thread"] = threading.Timer(duration, self.ws.disconnect)
        timer_info["thread"].start()

    def _stop_all_timers(self):
        """Cancel every running timer."""
        for timer_info in self.timers.values():
            if timer_info["thread"]:
                timer_info["thread"].cancel()
                timer_info["thread"] = None

    ##########################################################################
    ## WebSocket event handlers

    def _handle_message(self, msg):
        """
        Dispatch one incoming WebSocket message.

        A ``channel_subscribed`` event arms the first-message timer. Any other
        message accepted by ``_extract_session_id`` cancels that timer, has its
        payload extracted and published, and restarts the inactivity timer.
        """
        if isinstance(msg, dict) and msg.get("event") == "channel_subscribed":
            LOGGER.debug("channel_subscribed: %s", msg)
            # Start the first message timeout timer when the channel is successfully subscribed
            self._timeout_handler(Timer.FIRST_MESSAGE_TIMEOUT, TimerAction.START)
        elif self._extract_session_id(msg):
            # Only the first accepted message actually stops this timer;
            # cancelling silently avoids a "not active" warning per message.
            self._cancel_timer(Timer.FIRST_MESSAGE_TIMEOUT)
            LOGGER.debug("data: %s", msg)
            raw = self._extract_raw(msg)
            if raw:
                self.data.append(raw)
                self.util_response._queue.put(raw)  # feed receive() generator
                if self._on_message_cb:
                    self._on_message_cb(raw)
            self._timeout_handler(Timer.TIMEOUT, TimerAction.RESET)

    ##########################################################################
    ## Message processing and WebSocket connection management

    def _extract_session_id(self, message) -> bool:
        """
        Return True if ``message`` belongs to the expected session/capture.

        ``message`` may be a dict or a JSON string, and the identifying fields
        may be nested under ``{"event": "data", "data": ...}``. When neither a
        session_id nor a capture_id is known, every message is accepted.
        Otherwise only the IDs that are actually set are compared.
        """
        if not self.session_id and not self.capture_id:
            LOGGER.debug("No session_id or capture_id provided, accepting all messages")
            return True
        if isinstance(message, str):
            LOGGER.debug("Trying to decode message: %s", message)
            try:
                message = json.loads(message)
            except json.JSONDecodeError:
                LOGGER.warning("Failed to decode message as JSON: %s", message)
                return False
        if isinstance(message, dict):
            if message.get("event") == "data" and message.get("data"):
                LOGGER.debug(
                    "Checking nested data for session_id or capture_id: %s",
                    message["data"],
                )
                return self._extract_session_id(message["data"])
            # Skip unset IDs: otherwise ``message.get(key) == None`` would match
            # every message that lacks that key and defeat the filter.
            if self.session_id and message.get("session") == self.session_id:
                LOGGER.info(
                    "Message session_id matches expected session_id: %s",
                    self.session_id,
                )
                return True
            if self.capture_id and message.get("capture_id") == self.capture_id:
                LOGGER.info(
                    "Message capture_id matches expected capture_id: %s",
                    self.capture_id,
                )
                return True
        return False

    def _extract_raw(self, message):
        """
        Record ``message`` in ``raw_events`` and return its payload.

        The message is recorded exactly once, as received. The payload is the
        ``raw`` field (command output) or the ``pcap_dict`` field (pcap data),
        possibly nested under ``{"event": "data", "data": ...}``. Returns None
        when neither field is present or the message is not valid JSON.
        """
        self.raw_events.append(message)
        return self._parse_raw(message)

    def _parse_raw(self, message):
        """Recursively locate the ``raw`` / ``pcap_dict`` payload in ``message``."""
        event = message
        if isinstance(event, str):
            try:
                event = json.loads(event)
            except json.JSONDecodeError:
                LOGGER.warning("Failed to decode message as JSON: %s", message)
                return None
        if isinstance(event, dict):
            if event.get("event") == "data" and event.get("data"):
                return self._parse_raw(event["data"])
            if "raw" in event:
                self.received_messages += 1
                LOGGER.debug("Extracted raw message: %s", event["raw"])
                return event["raw"]
            if "pcap_dict" in event:
                self.received_messages += 1
                LOGGER.debug("Extracted pcap data: %s", event["pcap_dict"])
                return event["pcap_dict"]
        return None

    ##########################################################################
    ## WebSocket connection management

    def start(self, ws) -> UtilResponse:
        """
        Start the WS connection in the background and return immediately.

        The returned ``UtilResponse`` collects data as it streams in. Use
        ``response.receive()``, ``response.wait()``, or the ``on_message``
        callback to consume results.

        PARAMS
        -----------
        ws : _MistWebsocket
            An already-constructed WebSocket channel object.

        RAISES
        -----------
        Whatever ``ws.connect()`` raises. The UtilResponse is marked done
        first, so waiters are not left blocked.
        """
        self.ws = ws
        ws.on_message(self._handle_message)
        ws.on_error(lambda error: LOGGER.error("Error: %s", error))
        ws.on_close(self._on_close)
        ws.on_open(self._on_open)
        # Wire up UtilResponse before starting WS
        self.util_response.ws_required = True
        self.util_response.ws_data = self.data  # live list reference
        self.util_response.ws_raw_events = self.raw_events
        self.util_response._closed.clear()  # mark as "in progress"
        self.util_response._disconnect_fn = ws.disconnect
        try:
            ws.connect(run_in_background=True)  # non-blocking
        except Exception:
            # A close callback is not guaranteed after a failed connect();
            # release wait()/receive()/await before propagating the error.
            self._stop_all_timers()
            self._finish()
            raise
        return self.util_response