Skip to content

Commit 796b431

Browse files
add memory estimation for cache
1 parent cac437e commit 796b431

3 files changed

Lines changed: 44 additions & 8 deletions

File tree

src/main/python/systemds/scuro/drsearch/node_executor.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from concurrent.futures import FIRST_COMPLETED, ProcessPoolExecutor, as_completed, wait
22
from dataclasses import dataclass
3+
import os
34
from systemds.scuro import Modality
45
from systemds.scuro.drsearch.node_scheduler import MemoryAwareNodeScheduler
56
from systemds.scuro.drsearch.representation_dag import (
@@ -20,6 +21,8 @@
2021
AggregatedRepresentation,
2122
)
2223
from systemds.scuro.representations.unimodal import UnimodalRepresentation
24+
from systemds.scuro.utils.checkpointing import CheckpointManager
25+
from pympler import asizeof
2326

2427

2528
class RefCountResultCache:
@@ -95,19 +98,28 @@ def _execute_task_worker(task: Any, data: Any, gpu_id: Optional[int]) -> Dict[st
9598
return {"scores": scores, "task_time": end - start}
9699

97100

101+
# TODO: add a checkpoint manager only to the node executor, maybe get the name from outside to distinguish between unimodal and multimodal checkpoint managers
102+
# we can exclude all DAG nodes that are loaded from an existing checkpoint and thereby speed up further execution
98103
class NodeExecutor:
99104
def __init__(
100105
self,
101106
dags: List[RepresentationDag],
102107
modalities: List[Modality],
103108
tasks: List[Any],
109+
checkpoint_manager: Optional[CheckpointManager] = None,
104110
max_num_workers: int = -1,
105111
):
106112
available_total_cpu = float(psutil.virtual_memory().available)
107113
self.dags = dags
108114
self.scheduler = MemoryAwareNodeScheduler(
109115
dags, modalities, tasks, available_total_cpu
110116
)
117+
self.checkpoint_manager = CheckpointManager(
118+
checkpoint_dir=os.getcwd(),
119+
prefix="node_executor_checkpoint_",
120+
checkpoint_every=1,
121+
resume=False,
122+
)
111123
self.max_num_workers = (
112124
min(mp.cpu_count(), max_num_workers)
113125
if max_num_workers != -1
@@ -185,8 +197,12 @@ def submit_new_ready_nodes():
185197
self.scheduler.add_failed_node(node_id)
186198
continue
187199

188-
self.scheduler.complete_node(node_id)
200+
before_bytes = self._result_cache_size_bytes()
189201
self._manage_result_cache(node_id, result)
202+
after_bytes = self._result_cache_size_bytes()
203+
self.scheduler.update_cpu_memory_in_use(after_bytes - before_bytes)
204+
self.scheduler.complete_node(node_id)
205+
190206
node = self.scheduler.mapping[node_id]
191207
if self._is_task_node(node):
192208
task_results[node_id].task_time = result["task_time"]
@@ -199,10 +215,17 @@ def submit_new_ready_nodes():
199215
task_results[node_id].test_score = result["scores"][
200216
2
201217
].average_scores
218+
self.checkpoint_manager.increment(node_id)
219+
self.checkpoint_manager.checkpoint_if_due(task_results)
202220
submit_new_ready_nodes()
203221

204222
return list(task_results.values())
205223

224+
def _result_cache_size_bytes(self) -> int:
225+
return asizeof.asizeof(self.result_cache.cache) + asizeof.asizeof(
226+
self.result_cache.ref_count
227+
)
228+
206229
def _manage_result_cache(self, node_id: str, result: Any):
207230
parent_node_id = self.scheduler.get_valid_parent(node_id)
208231
if parent_node_id is not None:

src/main/python/systemds/scuro/drsearch/node_scheduler.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,16 @@ def __init__(
5353
torch.cuda.device_count() if torch and torch.cuda.is_available() else 0
5454
)
5555
self.memory_stats = {
56-
"cpu_in_use": 0.0,
56+
"cpu_in_use": sum([self.node_resources[node][0] for node in self.leaves]),
5757
"gpu_in_use": {
5858
info["index"]: int(info["total_b"] - info["free_b"])
5959
for info in self.gpu_memory_info
6060
},
6161
}
6262

63+
def update_cpu_memory_in_use(self, delta_bytes: int):
64+
self.memory_stats["cpu_in_use"] += delta_bytes
65+
6366
def get_runnable(self) -> List[RepresentationNode]:
6467
runnable_nodes = self._get_runnable_nodes()
6568

@@ -173,7 +176,14 @@ def _get_pending_nodes(self) -> List[str]:
173176

174177
def _reserve_memory(self, node_id: str, gpu_id: int) -> bool:
175178
cpu_mem, gpu_mem = self.node_resources[node_id]
176-
179+
print(
180+
f"Reserving memory for node {node_id}: CPU {cpu_mem} , GPU {gpu_mem} - Total CPU {self.memory_stats['cpu_in_use']}"
181+
+ (
182+
f" , Total GPU {self.memory_stats['gpu_in_use'][gpu_id]}"
183+
if gpu_id is not None
184+
else ""
185+
)
186+
)
177187
self.memory_stats["cpu_in_use"] += cpu_mem
178188
if gpu_id is not None:
179189
self.memory_stats["gpu_in_use"][gpu_id] += gpu_mem

src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def optimize_parallel(self, n_workers=None):
220220
modality.modality_id, new_count
221221
)
222222
self._checkpoint_manager.checkpoint_if_due(
223-
self.operator_performance.results, "eval_count_by_modality"
223+
self.operator_performance.results,
224224
)
225225
except Exception as e:
226226
print(f"Error processing modality {modality.modality_id}: {e}")
@@ -229,7 +229,6 @@ def optimize_parallel(self, n_workers=None):
229229
traceback.print_exc()
230230
self._checkpoint_manager.save_checkpoint(
231231
self.operator_performance.results,
232-
"eval_count_by_modality",
233232
{},
234233
)
235234
continue
@@ -259,7 +258,7 @@ def optimize(self):
259258
new_count = self._count_results(local_result.results)
260259
self._checkpoint_manager.increment(modality.modality_id, new_count)
261260
self._checkpoint_manager.checkpoint_if_due(
262-
self.operator_performance.results, "eval_count_by_modality"
261+
self.operator_performance.results
263262
)
264263
if self.save_all_results:
265264
self.store_results(f"{modality.modality_id}_unimodal_results.pkl")
@@ -269,7 +268,7 @@ def optimize(self):
269268

270269
traceback.print_exc()
271270
self._checkpoint_manager.save_checkpoint(
272-
self.operator_performance.results, "eval_count_by_modality", {}
271+
self.operator_performance.results, {}
273272
)
274273
raise
275274

@@ -336,7 +335,11 @@ def _process_modality(self, modality, skip_remaining: int = 0, scheduler=None):
336335
expanded_dags = self._expand_dags_with_task_roots(dags)
337336

338337
node_executor = NodeExecutor(
339-
expanded_dags, [modality], self.tasks, self.max_num_workers
338+
expanded_dags,
339+
[modality],
340+
self.tasks,
341+
self._checkpoint_manager,
342+
self.max_num_workers,
340343
)
341344
task_results = node_executor.run()
342345

0 commit comments

Comments
 (0)