33import threading
44import time
55from queue import Queue
6- from typing import TYPE_CHECKING
6+ from typing import TYPE_CHECKING , Any , cast
77
88from bec_lib import messages
99from bec_lib .endpoints import MessageEndpoints
1414logger = bec_logger .logger
1515
1616if TYPE_CHECKING :
17+ from bec_lib .redis_connector import MessageObject
18+
1719 from .scan_bundler import ScanBundler
1820
1921
@@ -22,6 +24,7 @@ def __init__(self, scan_bundler: ScanBundler) -> None:
2224 super ().__init__ (scan_bundler .connector )
2325 self ._send_buffer = Queue ()
2426 self .scan_bundler = scan_bundler
27+ self ._device_progress_subscriptions : dict [str , dict [str , Any ]] = {}
2528 self ._buffered_connector_thread = None
2629 self ._buffered_publisher_stop_event = threading .Event ()
2730 self ._start_buffered_connector ()
@@ -96,7 +99,8 @@ def _send_bec_scan_point(self, scan_id: str, point_id: int) -> None:
9699 MessageEndpoints .scan_segment (),
97100 MessageEndpoints .public_scan_segment (scan_id = scan_id , point_id = point_id ),
98101 )
99- self ._update_scan_progress (scan_id , point_id )
102+ if not self ._has_device_progress_subscription (scan_id ):
103+ self ._update_scan_progress (scan_id , point_id )
100104
101105 def _update_scan_progress (self , scan_id : str , point_id : int , done = False ) -> None :
102106 if scan_id not in self .scan_bundler .sync_storage :
@@ -107,18 +111,37 @@ def _update_scan_progress(self, scan_id: str, point_id: int, done=False) -> None
107111 info = self .scan_bundler .sync_storage [scan_id ]["info" ]
108112
109113 num_monitored_readouts = info .get ("num_monitored_readouts" , info .get ("num_points" , 0 ))
114+ value = point_id + 1
115+ max_value = num_monitored_readouts or point_id + 1
116+ self .send_scan_progress (scan_id , value = value , max_value = max_value , done = done )
117+
118+ def send_scan_progress (self , scan_id : str , value : float , max_value : float , done = False ) -> None :
119+ """
120+ Send a scan progress update.
110121
122+ Args:
123+ scan_id (str): The ID of the scan.
124+ value (float): The current progress value.
125+ max_value (float): The maximum progress value.
126+ done (bool): Whether the scan is done.
127+ """
128+ storage = self .scan_bundler .sync_storage .get (scan_id )
129+ if not storage :
130+ return
131+ info = storage ["info" ]
111132 msg = messages .ProgressMessage (
112- value = point_id + 1 ,
113- max_value = num_monitored_readouts or point_id + 1 ,
133+ value = value ,
134+ max_value = max_value ,
114135 done = done ,
115136 metadata = {
116137 "scan_id" : scan_id ,
117138 "RID" : info .get ("RID" , "" ),
118139 "queue_id" : info .get ("queue_id" , "" ),
119- "status" : self . scan_bundler . sync_storage [ scan_id ] ["status" ],
140+ "status" : storage ["status" ],
120141 },
121142 )
143+ storage ["last_progress_sent" ] = msg
144+ logger .info (f"Emitting progress for scan { scan_id } : { value } /{ max_value } (done={ done } )" )
122145 self .scan_bundler .connector .set_and_publish (MessageEndpoints .scan_progress (), msg )
123146
124147 def _send_baseline (self , scan_id : str ) -> None :
@@ -141,29 +164,95 @@ def _send_baseline(self, scan_id: str) -> None:
141164 pipe .execute ()
142165
143166 def on_scan_status_update (self , status_msg : messages .ScanStatusMessage ):
167+ sb = self .scan_bundler
168+ if status_msg .scan_id not in sb .sync_storage :
169+ logger .warning (
170+ f"Cannot update scan progress: Scan { status_msg .scan_id } not found in sync storage."
171+ )
172+ return
173+
144174 if status_msg .status == "open" :
145- # No need to update progress for an open scan. This is handled by the scan point emit.
175+ # Update progress subscription:
176+ # - If the scan report instruction contains "scan_progress", we simply emit
177+ # progress updates as they come in.
178+ # - If the scan report instruction contains "device_progress", we subscribe
179+ # to the progress of the first device and use that as the progress for the whole scan.
180+ self ._update_device_progress_subscription (status_msg .scan_id )
146181 return
147182
148183 num_points = max (status_msg .info .get ("num_points" , 0 ) - 1 , 0 )
149- num_monitored_readouts = status_msg .info .get ("num_monitored_readouts" , num_points )
184+ num_monitored_readouts = status_msg .info .get ("num_monitored_readouts" )
185+ if num_monitored_readouts is not None :
186+ num_monitored_readouts = max (num_monitored_readouts - 1 , 0 )
187+ else :
188+ num_monitored_readouts = num_points
150189 if status_msg .status == "closed" :
151- self ._update_scan_progress (status_msg .scan_id , num_monitored_readouts , done = True )
152- return
190+ if not self ._has_device_progress_subscription (status_msg .scan_id ):
191+ self ._update_scan_progress (status_msg .scan_id , num_monitored_readouts , done = True )
192+ return
153193
154- sb = self .scan_bundler
155- if status_msg .scan_id not in sb .sync_storage :
156- logger .warning (
157- f"Cannot update scan progress: Scan { status_msg .scan_id } not found in sync storage."
158- )
194+ self ._unregister_device_progress_subscription (status_msg .scan_id )
195+ self ._emit_last_progress (status_msg .scan_id )
159196 return
197+
198+ # Scan is not open or closed but instead in ["aborted", "halted", "user_completed"]
160199 storage = sb .sync_storage [status_msg .scan_id ]
200+ if self ._has_device_progress_subscription (status_msg .scan_id ):
201+ self ._unregister_device_progress_subscription (status_msg .scan_id )
202+ self ._emit_last_progress (status_msg .scan_id )
203+ return
161204 sent_vals = storage .get ("sent" , {0 }) or {0 }
162205 max_point = max (sent_vals )
163206 self ._update_scan_progress (status_msg .scan_id , max_point , done = True )
164207
208+ def on_cleanup (self , scan_id : str ):
209+ self ._unregister_device_progress_subscription (scan_id )
210+
165211 def shutdown (self ):
166212 if self ._buffered_connector_thread :
167213 self ._buffered_publisher_stop_event .set ()
168214 self ._buffered_connector_thread .join ()
169215 self ._buffered_connector_thread = None
216+
217+ #############################################################
218+ ################# Device Progress Helpers ###################
219+ #############################################################
220+
221+ def _update_device_progress_subscription (self , scan_id : str ):
222+ sb = self .scan_bundler
223+ instructions = sb .scan_report_instructions .get (scan_id , [])
224+ if self ._has_device_progress_subscription (scan_id ):
225+ return
226+ for instruction in instructions :
227+ if "device_progress" in instruction :
228+ device = instruction ["device_progress" ][0 ]
229+ sub = {
230+ "topics" : MessageEndpoints .device_progress (device = device ),
231+ "cb" : lambda msg_obj , _scan_id = scan_id : self ._on_device_progress (
232+ msg_obj , _scan_id
233+ ),
234+ }
235+ self ._device_progress_subscriptions [scan_id ] = sub
236+ self .connector .register (** sub )
237+ return
238+
239+ def _emit_last_progress (self , scan_id : str ):
240+ storage = self .scan_bundler .sync_storage .get (scan_id , {})
241+ msg = storage .get ("last_progress_sent" )
242+ value = msg .value if msg else 0
243+ max_value = msg .max_value if msg else 0
244+ self .send_scan_progress (scan_id , value = value , max_value = max_value , done = True )
245+
246+ def _on_device_progress (self , msg_obj : MessageObject , scan_id : str ):
247+ msg = cast (messages .ProgressMessage , msg_obj .value )
248+ if msg .metadata .get ("scan_id" ) != scan_id :
249+ return
250+ self .send_scan_progress (scan_id , value = msg .value , max_value = msg .max_value , done = msg .done )
251+
252+ def _has_device_progress_subscription (self , scan_id : str ) -> bool :
253+ return scan_id in self ._device_progress_subscriptions
254+
255+ def _unregister_device_progress_subscription (self , scan_id : str ) -> None :
256+ sub_info = self ._device_progress_subscriptions .pop (scan_id , None )
257+ if sub_info :
258+ self .connector .unregister (** sub_info )
0 commit comments