Skip to content

Commit 3b4ec66

Browse files
committed
wip unifikace
1 parent 7624d59 commit 3b4ec66

5 files changed

Lines changed: 583 additions & 276 deletions

File tree

Lines changed: 388 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,388 @@
1+
from __future__ import annotations
2+
3+
import enum
4+
import time
5+
from dataclasses import dataclass
6+
from typing import Literal
7+
8+
import numpy as np
9+
from bec_lib import messages
10+
from bec_lib.endpoints import MessageEndpoints
11+
from qtpy.QtCore import QObject, QTimer, Signal
12+
13+
14+
class ProgressSource(enum.Enum):
15+
"""
16+
Enum to define the source of the progress.
17+
"""
18+
19+
SCAN_PROGRESS = "scan_progress"
20+
DEVICE_PROGRESS = "device_progress"
21+
22+
23+
@dataclass(frozen=True)
24+
class ProgressSnapshot:
25+
source: ProgressSource
26+
value: float
27+
max_value: float
28+
done: bool
29+
status: Literal["open", "paused", "aborted", "halted", "closed"]
30+
device: str | None = None
31+
scan_id: str | None = None
32+
scan_number: int | None = None
33+
rid: str | None = None
34+
is_new_scan: bool = False
35+
36+
37+
class ProgressTask(QObject):
38+
"""
39+
Class to store progress information.
40+
Inspired by https://github.com/Textualize/rich/blob/master/rich/progress.py
41+
"""
42+
43+
def __init__(
44+
self, parent: QObject | None, value: float = 0, max_value: float = 0, done: bool = False
45+
):
46+
super().__init__(parent=parent)
47+
self.start_time = time.monotonic()
48+
self.done = done
49+
self.value = value
50+
self.max_value = max_value
51+
self._elapsed_time = 0
52+
53+
self.timer = QTimer(self)
54+
self.timer.timeout.connect(self.update_elapsed_time)
55+
self.timer.start(1000)
56+
57+
def update(self, value: float, max_value: float, done: bool = False):
58+
"""
59+
Update the progress.
60+
"""
61+
self.max_value = max_value
62+
self.done = done
63+
self.value = value
64+
if done:
65+
self.timer.stop()
66+
67+
def update_elapsed_time(self):
68+
"""
69+
Update the time estimates. This is called every second by a QTimer.
70+
"""
71+
self._elapsed_time = max(0.0, time.monotonic() - self.start_time)
72+
73+
@property
74+
def percentage(self) -> float:
75+
"""float: Get progress of task as a percentage. If a None total was set, returns 0"""
76+
if not self.max_value:
77+
return 0.0
78+
completed = (self.value / self.max_value) * 100.0
79+
completed = min(100.0, max(0.0, completed))
80+
return completed
81+
82+
@property
83+
def speed(self) -> float:
84+
"""Get the estimated speed in steps per second."""
85+
if self._elapsed_time == 0:
86+
return 0.0
87+
88+
return self.value / self._elapsed_time
89+
90+
@property
91+
def frequency(self) -> float:
92+
"""Get the estimated frequency in steps per second."""
93+
if self.speed == 0:
94+
return 0.0
95+
return 1 / self.speed
96+
97+
@property
98+
def time_elapsed(self) -> str:
99+
return self._format_time(int(self._elapsed_time))
100+
101+
@property
102+
def remaining(self) -> float:
103+
"""Get the estimated remaining steps."""
104+
if self.done:
105+
return 0.0
106+
remaining = self.max_value - self.value
107+
return remaining
108+
109+
@property
110+
def time_remaining(self) -> str:
111+
"""
112+
Get the estimated remaining time in the format HH:MM:SS.
113+
"""
114+
if self.done or not self.speed or not self.remaining:
115+
return self._format_time(0)
116+
estimate = int(np.round(self.remaining / self.speed))
117+
118+
return self._format_time(estimate)
119+
120+
def _format_time(self, seconds: float) -> str:
121+
"""
122+
Format the time in seconds to a string in the format HH:MM:SS.
123+
"""
124+
return f"{seconds // 3600:02}:{(seconds // 60) % 60:02}:{seconds % 60:02}"
125+
126+
127+
class BECProgressTracker(QObject):
128+
"""
129+
Shared backend for BEC scan and device progress messages.
130+
"""
131+
132+
progress_started = Signal(object)
133+
progress_updated = Signal(object)
134+
progress_finished = Signal(object)
135+
progress_cleared = Signal()
136+
source_changed = Signal(object)
137+
138+
def __init__(self, bec_dispatcher, parent: QObject | None = None):
139+
super().__init__(parent=parent)
140+
self.bec_dispatcher = bec_dispatcher
141+
self._progress_source: ProgressSource | None = None
142+
self._progress_device: str | None = None
143+
self._queue_connected = False
144+
self.task: ProgressTask | None = None
145+
self.scan_number: int | None = None
146+
self._active_scan_id: str | None = None
147+
self._active_rid: str | None = None
148+
149+
@property
150+
def progress_source(self) -> ProgressSource | None:
151+
return self._progress_source
152+
153+
@property
154+
def progress_device(self) -> str | None:
155+
return self._progress_device
156+
157+
@property
158+
def active_scan_id(self) -> str | None:
159+
return self._active_scan_id
160+
161+
@property
162+
def active_rid(self) -> str | None:
163+
return self._active_rid
164+
165+
def start(
166+
self,
167+
*,
168+
source: ProgressSource | None = ProgressSource.SCAN_PROGRESS,
169+
device: str | None = None,
170+
connect_queue: bool = False,
171+
refresh_queue: bool = False,
172+
) -> None:
173+
if source is not None:
174+
self.set_progress_source(source, device=device)
175+
if connect_queue:
176+
self.connect_to_queue()
177+
if refresh_queue:
178+
self.refresh_queue()
179+
180+
def connect_to_queue(self) -> None:
181+
if self._queue_connected:
182+
return
183+
self.bec_dispatcher.connect_slot(self.on_queue_update, MessageEndpoints.scan_queue_status())
184+
self._queue_connected = True
185+
186+
def refresh_queue(self) -> None:
187+
connector = getattr(self.bec_dispatcher.client, "connector", None)
188+
if connector is None:
189+
return
190+
msg = connector.get(MessageEndpoints.scan_queue_status())
191+
if msg is None:
192+
return
193+
self.on_queue_update(msg.content, msg.metadata)
194+
195+
def set_progress_source(self, source: ProgressSource, device: str | None = None) -> None:
196+
if source == ProgressSource.DEVICE_PROGRESS and not device:
197+
return
198+
if self._progress_source == source and self._progress_device == device:
199+
self.source_changed.emit(self.current_snapshot(value=0, max_value=100, done=False))
200+
return
201+
202+
self._disconnect_progress_source()
203+
self._progress_source = source
204+
self._progress_device = None if source == ProgressSource.SCAN_PROGRESS else device
205+
self.bec_dispatcher.connect_slot(self.on_progress_update, self._progress_endpoint())
206+
self.source_changed.emit(self.current_snapshot(value=0, max_value=100, done=False))
207+
208+
def _disconnect_progress_source(self) -> None:
209+
if self._progress_source is None:
210+
return
211+
self.bec_dispatcher.disconnect_slot(self.on_progress_update, self._progress_endpoint())
212+
self._progress_source = None
213+
self._progress_device = None
214+
215+
def _progress_endpoint(self):
216+
if self._progress_source == ProgressSource.SCAN_PROGRESS:
217+
return MessageEndpoints.scan_progress()
218+
return MessageEndpoints.device_progress(device=self._progress_device)
219+
220+
def current_snapshot(
221+
self,
222+
*,
223+
value: float,
224+
max_value: float,
225+
done: bool,
226+
status: Literal["open", "paused", "aborted", "halted", "closed"] = "open",
227+
is_new_scan: bool = False,
228+
) -> ProgressSnapshot:
229+
source = self._progress_source or ProgressSource.SCAN_PROGRESS
230+
return ProgressSnapshot(
231+
source=source,
232+
value=value,
233+
max_value=max_value,
234+
done=done,
235+
status=status,
236+
device=self._progress_device,
237+
scan_id=self._active_scan_id,
238+
scan_number=self.scan_number,
239+
rid=self._active_rid,
240+
is_new_scan=is_new_scan,
241+
)
242+
243+
def _start_task(self, scan_id: str | None, rid: str | None = None) -> None:
244+
if self.task is not None:
245+
self.task.timer.stop()
246+
self.task.deleteLater()
247+
self.task = ProgressTask(parent=self)
248+
self._active_scan_id = scan_id
249+
self._active_rid = rid
250+
self.progress_started.emit(self.current_snapshot(value=0, max_value=100, done=False))
251+
252+
def clear_task(self, *, emit_finished: bool = True) -> None:
253+
if self.task is None:
254+
self._active_scan_id = None
255+
self._active_rid = None
256+
self.progress_cleared.emit()
257+
return
258+
self.task.timer.stop()
259+
self.task.deleteLater()
260+
self.task = None
261+
self._active_scan_id = None
262+
self._active_rid = None
263+
self.progress_cleared.emit()
264+
if emit_finished:
265+
self.progress_finished.emit(self.current_snapshot(value=0, max_value=100, done=True))
266+
267+
def on_queue_update(self, msg_content: dict, metadata: dict):
268+
queue_info = self._extract_active_queue_info(msg_content)
269+
if queue_info is None:
270+
self.clear_task()
271+
self.set_progress_source(ProgressSource.SCAN_PROGRESS)
272+
return
273+
274+
active_request_block = queue_info.active_request_block
275+
scan_id = active_request_block.scan_id or str(active_request_block.scan_number)
276+
rid = getattr(active_request_block, "RID", None) or metadata.get("RID")
277+
if self.task is None or self._active_scan_id != scan_id:
278+
self._start_task(scan_id, rid=rid)
279+
else:
280+
self._active_rid = rid or self._active_rid
281+
282+
self.scan_number = active_request_block.scan_number
283+
self.source_changed.emit(self.current_snapshot(value=0, max_value=100, done=False))
284+
285+
report_instructions = active_request_block.report_instructions
286+
if not report_instructions:
287+
return
288+
289+
instruction = report_instructions[0]
290+
if "scan_progress" in instruction:
291+
self.set_progress_source(ProgressSource.SCAN_PROGRESS)
292+
elif "device_progress" in instruction:
293+
if not instruction["device_progress"]:
294+
return
295+
self.set_progress_source(
296+
ProgressSource.DEVICE_PROGRESS, device=instruction["device_progress"][0]
297+
)
298+
299+
def _extract_active_queue_info(self, msg_content: dict):
300+
if "queue" not in msg_content:
301+
return None
302+
if "primary" not in msg_content["queue"]:
303+
return None
304+
primary_queue = msg_content.get("queue").get("primary")
305+
if primary_queue is None or not isinstance(primary_queue, messages.ScanQueueStatus):
306+
return None
307+
primary_queue_info = primary_queue.info
308+
if len(primary_queue_info) == 0:
309+
return None
310+
scan_info = primary_queue_info[0]
311+
if scan_info is None or scan_info.active_request_block is None:
312+
return None
313+
if scan_info.status.lower() != "running":
314+
return None
315+
return scan_info
316+
317+
def on_progress_update(self, msg_content: dict, metadata: dict):
318+
if self._progress_source is None:
319+
return
320+
self.process_progress_message(self._progress_source, msg_content, metadata)
321+
322+
def process_progress_message(
323+
self,
324+
source: ProgressSource,
325+
msg_content: dict,
326+
metadata: dict,
327+
*,
328+
device: str | None = None,
329+
) -> ProgressSnapshot | None:
330+
done = msg_content.get("done", False)
331+
value = msg_content.get("value", 0)
332+
max_value = msg_content.get("max_value", 100)
333+
if done:
334+
value = max_value
335+
status: Literal["open", "paused", "aborted", "halted", "closed"] = metadata.get(
336+
"status", "open"
337+
)
338+
scan_id = metadata.get("scan_id") or metadata.get("RID")
339+
rid = metadata.get("RID")
340+
is_new_scan = False
341+
previous_scan_id = self._active_scan_id
342+
previous_rid = self._active_rid
343+
identity_changed = (
344+
(scan_id is not None and scan_id != previous_scan_id)
345+
or (rid is not None and rid != previous_rid)
346+
or (previous_scan_id is None and previous_rid is None)
347+
)
348+
349+
if self.task is None:
350+
self._start_task(scan_id, rid=rid)
351+
is_new_scan = identity_changed
352+
elif scan_id is not None and scan_id != self._active_scan_id:
353+
self._start_task(scan_id, rid=rid)
354+
is_new_scan = True
355+
elif rid is not None and rid != self._active_rid:
356+
self._start_task(scan_id or self._active_scan_id, rid=rid)
357+
is_new_scan = True
358+
359+
if self.task is None:
360+
return None
361+
362+
self.task.update(value, max_value, done)
363+
progress_device = device or self._progress_device
364+
snapshot = ProgressSnapshot(
365+
source=source,
366+
value=value,
367+
max_value=max_value,
368+
done=done,
369+
status=status,
370+
device=progress_device if source == ProgressSource.DEVICE_PROGRESS else None,
371+
scan_id=self._active_scan_id,
372+
scan_number=self.scan_number,
373+
rid=self._active_rid,
374+
is_new_scan=is_new_scan,
375+
)
376+
self.progress_updated.emit(snapshot)
377+
if done:
378+
self.clear_task()
379+
return snapshot
380+
381+
def cleanup(self) -> None:
382+
self.clear_task(emit_finished=False)
383+
self._disconnect_progress_source()
384+
if self._queue_connected:
385+
self.bec_dispatcher.disconnect_slot(
386+
self.on_queue_update, MessageEndpoints.scan_queue_status()
387+
)
388+
self._queue_connected = False

0 commit comments

Comments
 (0)