11from concurrent .futures import FIRST_COMPLETED , ProcessPoolExecutor , as_completed , wait
22from dataclasses import dataclass
3+ import os
34from systemds .scuro import Modality
45from systemds .scuro .drsearch .node_scheduler import MemoryAwareNodeScheduler
56from systemds .scuro .drsearch .representation_dag import (
2021 AggregatedRepresentation ,
2122)
2223from systemds .scuro .representations .unimodal import UnimodalRepresentation
24+ from systemds .scuro .utils .checkpointing import CheckpointManager
25+ from pympler import asizeof
2326
2427
2528class RefCountResultCache :
@@ -95,19 +98,28 @@ def _execute_task_worker(task: Any, data: Any, gpu_id: Optional[int]) -> Dict[st
9598 return {"scores" : scores , "task_time" : end - start }
9699
97100
101+ # TODO: add a checkpoint manager only to the node executor, maybe get the name from outside to distinguish between unimodal and multimodal checkpoint managers
102+ # we can exclude all dag nodes that are loaded through an existing checkpoint and therefore speedup the further execution
98103class NodeExecutor :
99104 def __init__ (
100105 self ,
101106 dags : List [RepresentationDag ],
102107 modalities : List [Modality ],
103108 tasks : List [Any ],
109+ checkpoint_manager : Optional [CheckpointManager ] = None ,
104110 max_num_workers : int = - 1 ,
105111 ):
106112 available_total_cpu = float (psutil .virtual_memory ().available )
107113 self .dags = dags
108114 self .scheduler = MemoryAwareNodeScheduler (
109115 dags , modalities , tasks , available_total_cpu
110116 )
117+ self .checkpoint_manager = CheckpointManager (
118+ checkpoint_dir = os .getcwd (),
119+ prefix = "node_executor_checkpoint_" ,
120+ checkpoint_every = 1 ,
121+ resume = False ,
122+ )
111123 self .max_num_workers = (
112124 min (mp .cpu_count (), max_num_workers )
113125 if max_num_workers != - 1
@@ -185,8 +197,12 @@ def submit_new_ready_nodes():
185197 self .scheduler .add_failed_node (node_id )
186198 continue
187199
188- self .scheduler . complete_node ( node_id )
200+ before_bytes = self ._result_cache_size_bytes ( )
189201 self ._manage_result_cache (node_id , result )
202+ after_bytes = self ._result_cache_size_bytes ()
203+ self .scheduler .update_cpu_memory_in_use (after_bytes - before_bytes )
204+ self .scheduler .complete_node (node_id )
205+
190206 node = self .scheduler .mapping [node_id ]
191207 if self ._is_task_node (node ):
192208 task_results [node_id ].task_time = result ["task_time" ]
@@ -199,10 +215,17 @@ def submit_new_ready_nodes():
199215 task_results [node_id ].test_score = result ["scores" ][
200216 2
201217 ].average_scores
218+ self .checkpoint_manager .increment (node_id )
219+ self .checkpoint_manager .checkpoint_if_due (task_results )
202220 submit_new_ready_nodes ()
203221
204222 return list (task_results .values ())
205223
224+ def _result_cache_size_bytes (self ) -> int :
225+ return asizeof .asizeof (self .result_cache .cache ) + asizeof .asizeof (
226+ self .result_cache .ref_count
227+ )
228+
206229 def _manage_result_cache (self , node_id : str , result : Any ):
207230 parent_node_id = self .scheduler .get_valid_parent (node_id )
208231 if parent_node_id is not None :
0 commit comments