Skip to content

Commit 9f1d72b

Browse files
committed
feat(bec_emitter, scan_bundler): forward device progress as scan progress
1 parent 67b553a commit 9f1d72b

3 files changed

Lines changed: 378 additions & 27 deletions

File tree

bec_server/bec_server/scan_bundler/bec_emitter.py

Lines changed: 103 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import threading
44
import time
55
from queue import Queue
6-
from typing import TYPE_CHECKING
6+
from typing import TYPE_CHECKING, Any, cast
77

88
from bec_lib import messages
99
from bec_lib.endpoints import MessageEndpoints
@@ -14,6 +14,8 @@
1414
logger = bec_logger.logger
1515

1616
if TYPE_CHECKING:
17+
from bec_lib.redis_connector import MessageObject
18+
1719
from .scan_bundler import ScanBundler
1820

1921

@@ -22,6 +24,7 @@ def __init__(self, scan_bundler: ScanBundler) -> None:
2224
super().__init__(scan_bundler.connector)
2325
self._send_buffer = Queue()
2426
self.scan_bundler = scan_bundler
27+
self._device_progress_subscriptions: dict[str, dict[str, Any]] = {}
2528
self._buffered_connector_thread = None
2629
self._buffered_publisher_stop_event = threading.Event()
2730
self._start_buffered_connector()
@@ -96,7 +99,8 @@ def _send_bec_scan_point(self, scan_id: str, point_id: int) -> None:
9699
MessageEndpoints.scan_segment(),
97100
MessageEndpoints.public_scan_segment(scan_id=scan_id, point_id=point_id),
98101
)
99-
self._update_scan_progress(scan_id, point_id)
102+
if not self._has_device_progress_subscription(scan_id):
103+
self._update_scan_progress(scan_id, point_id)
100104

101105
def _update_scan_progress(self, scan_id: str, point_id: int, done=False) -> None:
102106
if scan_id not in self.scan_bundler.sync_storage:
@@ -107,18 +111,37 @@ def _update_scan_progress(self, scan_id: str, point_id: int, done=False) -> None
107111
info = self.scan_bundler.sync_storage[scan_id]["info"]
108112

109113
num_monitored_readouts = info.get("num_monitored_readouts", info.get("num_points", 0))
114+
value = point_id + 1
115+
max_value = num_monitored_readouts or point_id + 1
116+
self.send_scan_progress(scan_id, value=value, max_value=max_value, done=done)
117+
118+
def send_scan_progress(self, scan_id: str, value: float, max_value: float, done=False) -> None:
119+
"""
120+
Send a scan progress update.
110121
122+
Args:
123+
scan_id (str): The ID of the scan.
124+
value (float): The current progress value.
125+
max_value (float): The maximum progress value.
126+
done (bool): Whether the scan is done.
127+
"""
128+
storage = self.scan_bundler.sync_storage.get(scan_id)
129+
if not storage:
130+
return
131+
info = storage["info"]
111132
msg = messages.ProgressMessage(
112-
value=point_id + 1,
113-
max_value=num_monitored_readouts or point_id + 1,
133+
value=value,
134+
max_value=max_value,
114135
done=done,
115136
metadata={
116137
"scan_id": scan_id,
117138
"RID": info.get("RID", ""),
118139
"queue_id": info.get("queue_id", ""),
119-
"status": self.scan_bundler.sync_storage[scan_id]["status"],
140+
"status": storage["status"],
120141
},
121142
)
143+
storage["last_progress_sent"] = msg
144+
logger.info(f"Emitting progress for scan {scan_id}: {value}/{max_value} (done={done})")
122145
self.scan_bundler.connector.set_and_publish(MessageEndpoints.scan_progress(), msg)
123146

124147
def _send_baseline(self, scan_id: str) -> None:
@@ -141,29 +164,95 @@ def _send_baseline(self, scan_id: str) -> None:
141164
pipe.execute()
142165

143166
def on_scan_status_update(self, status_msg: messages.ScanStatusMessage):
167+
sb = self.scan_bundler
168+
if status_msg.scan_id not in sb.sync_storage:
169+
logger.warning(
170+
f"Cannot update scan progress: Scan {status_msg.scan_id} not found in sync storage."
171+
)
172+
return
173+
144174
if status_msg.status == "open":
145-
# No need to update progress for an open scan. This is handled by the scan point emit.
175+
# Update progress subscription:
176+
# - If the scan report instruction contains "scan_progress", we simply emit
177+
# progress updates as they come in.
178+
# - If the scan report instruction contains "device_progress", we subscribe
179+
# to the progress of the first device and use that as the progress for the whole scan.
180+
self._update_device_progress_subscription(status_msg.scan_id)
146181
return
147182

148183
num_points = max(status_msg.info.get("num_points", 0) - 1, 0)
149-
num_monitored_readouts = status_msg.info.get("num_monitored_readouts", num_points)
184+
num_monitored_readouts = status_msg.info.get("num_monitored_readouts")
185+
if num_monitored_readouts is not None:
186+
num_monitored_readouts = max(num_monitored_readouts - 1, 0)
187+
else:
188+
num_monitored_readouts = num_points
150189
if status_msg.status == "closed":
151-
self._update_scan_progress(status_msg.scan_id, num_monitored_readouts, done=True)
152-
return
190+
if not self._has_device_progress_subscription(status_msg.scan_id):
191+
self._update_scan_progress(status_msg.scan_id, num_monitored_readouts, done=True)
192+
return
153193

154-
sb = self.scan_bundler
155-
if status_msg.scan_id not in sb.sync_storage:
156-
logger.warning(
157-
f"Cannot update scan progress: Scan {status_msg.scan_id} not found in sync storage."
158-
)
194+
self._unregister_device_progress_subscription(status_msg.scan_id)
195+
self._emit_last_progress(status_msg.scan_id)
159196
return
197+
198+
# Scan is not open or closed but instead in ["aborted", "halted", "user_completed"]
160199
storage = sb.sync_storage[status_msg.scan_id]
200+
if self._has_device_progress_subscription(status_msg.scan_id):
201+
self._unregister_device_progress_subscription(status_msg.scan_id)
202+
self._emit_last_progress(status_msg.scan_id)
203+
return
161204
sent_vals = storage.get("sent", {0}) or {0}
162205
max_point = max(sent_vals)
163206
self._update_scan_progress(status_msg.scan_id, max_point, done=True)
164207

208+
def on_cleanup(self, scan_id: str):
209+
self._unregister_device_progress_subscription(scan_id)
210+
165211
def shutdown(self):
166212
if self._buffered_connector_thread:
167213
self._buffered_publisher_stop_event.set()
168214
self._buffered_connector_thread.join()
169215
self._buffered_connector_thread = None
216+
217+
#############################################################
218+
################# Device Progress Helpers ###################
219+
#############################################################
220+
221+
def _update_device_progress_subscription(self, scan_id: str):
222+
sb = self.scan_bundler
223+
instructions = sb.scan_report_instructions.get(scan_id, [])
224+
if self._has_device_progress_subscription(scan_id):
225+
return
226+
for instruction in instructions:
227+
if "device_progress" in instruction:
228+
device = instruction["device_progress"][0]
229+
sub = {
230+
"topics": MessageEndpoints.device_progress(device=device),
231+
"cb": lambda msg_obj, _scan_id=scan_id: self._on_device_progress(
232+
msg_obj, _scan_id
233+
),
234+
}
235+
self._device_progress_subscriptions[scan_id] = sub
236+
self.connector.register(**sub)
237+
return
238+
239+
def _emit_last_progress(self, scan_id: str):
240+
storage = self.scan_bundler.sync_storage.get(scan_id, {})
241+
msg = storage.get("last_progress_sent")
242+
value = msg.value if msg else 0
243+
max_value = msg.max_value if msg else 0
244+
self.send_scan_progress(scan_id, value=value, max_value=max_value, done=True)
245+
246+
def _on_device_progress(self, msg_obj: MessageObject, scan_id: str):
247+
msg = cast(messages.ProgressMessage, msg_obj.value)
248+
if msg.metadata.get("scan_id") != scan_id:
249+
return
250+
self.send_scan_progress(scan_id, value=msg.value, max_value=msg.max_value, done=msg.done)
251+
252+
def _has_device_progress_subscription(self, scan_id: str) -> bool:
253+
return scan_id in self._device_progress_subscriptions
254+
255+
def _unregister_device_progress_subscription(self, scan_id: str) -> None:
256+
sub_info = self._device_progress_subscriptions.pop(scan_id, None)
257+
if sub_info:
258+
self.connector.unregister(**sub_info)

bec_server/bec_server/scan_bundler/scan_bundler.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import traceback
77
from collections.abc import Callable
88
from concurrent.futures import ThreadPoolExecutor
9-
from typing import TYPE_CHECKING
9+
from typing import TYPE_CHECKING, cast
1010

1111
from bec_lib import messages
1212
from bec_lib.bec_service import BECService
@@ -17,7 +17,7 @@
1717
from .bec_emitter import BECEmitter
1818

1919
if TYPE_CHECKING:
20-
from bec_lib.redis_connector import RedisConnector
20+
from bec_lib.redis_connector import MessageObject, RedisConnector
2121

2222

2323
logger = bec_logger.logger
@@ -35,19 +35,23 @@ def __init__(self, config, connector_cls: type[RedisConnector]) -> None:
3535
name="device_read_register",
3636
)
3737
self.connector.register(MessageEndpoints.scan_status(), cb=self._scan_status_callback)
38-
3938
self.sync_storage = {}
4039
self.monitored_devices = {}
4140
self.baseline_devices = {}
4241
self.device_storage = {}
4342
self.readout_priority = {}
43+
self.scan_queue: messages.ScanQueueStatusMessage | None = None
44+
self.scan_report_instructions: dict[str, list] = {}
4445
self.storage_initialized = set()
4546
self.executor = ThreadPoolExecutor(max_workers=4)
4647
self.executor_tasks = collections.deque(maxlen=100)
4748
self.scan_id_history = collections.deque(maxlen=10)
4849
self._lock = threading.Lock()
4950
self._emitter = []
5051
self._initialize_emitters()
52+
self.connector.register(
53+
MessageEndpoints.scan_queue_status(), cb=self.on_scan_queue_status_update
54+
)
5155
self.status = messages.BECStatus.RUNNING
5256

5357
def _initialize_emitters(self):
@@ -95,6 +99,34 @@ def handle_scan_status_message(self, msg: messages.ScanStatusMessage) -> None:
9599
self._scan_status_modification(msg)
96100
self.run_emitter("on_scan_status_update", msg)
97101

102+
def on_scan_queue_status_update(self, msg_obj: MessageObject):
103+
"""
104+
Update the scan_report_instructions based on the active request block
105+
in the scan queue status message.
106+
107+
Args:
108+
msg_obj (MessageObject): The message object containing the scan queue status update.
109+
"""
110+
status_msg = cast(messages.ScanQueueStatusMessage, msg_obj.value)
111+
for scan_queue_status in status_msg.queue.values():
112+
if not scan_queue_status.info:
113+
continue
114+
info = scan_queue_status.info[0]
115+
active_request_block = info.active_request_block
116+
if not active_request_block:
117+
continue
118+
scan_id = active_request_block.scan_id
119+
if scan_id is None:
120+
continue
121+
report_instructions = active_request_block.report_instructions
122+
if not report_instructions:
123+
continue
124+
125+
self.scan_report_instructions[scan_id] = report_instructions
126+
logger.debug(
127+
f"Updated report instructions for scan_id {scan_id}: {report_instructions}"
128+
)
129+
98130
def _scan_status_modification(self, msg: messages.ScanStatusMessage):
99131
status = msg.content.get("status")
100132
if status not in ["closed", "aborted", "paused", "halted", "user_completed"]:
@@ -358,17 +390,18 @@ def cleanup_storage(self):
358390
remove_scan_ids.append(scan_id)
359391

360392
for scan_id in remove_scan_ids:
393+
self.run_emitter("on_cleanup", scan_id)
361394
for storage in [
362395
"sync_storage",
363396
"monitored_devices",
364397
"baseline_devices",
365398
"readout_priority",
399+
"scan_report_instructions",
366400
]:
367401
try:
368402
getattr(self, storage).pop(scan_id)
369403
except KeyError:
370404
logger.warning(f"Failed to remove {scan_id} from {storage}.")
371-
self.run_emitter("on_cleanup", scan_id)
372405
self.storage_initialized.remove(scan_id)
373406

374407
def _send_scan_point(self, scan_id, point_id) -> None:

0 commit comments

Comments
 (0)