Skip to content

Commit e7c629b

Browse files
Add video modality to unimodal optimizer pipeline
1 parent 574a816 commit e7c629b

17 files changed

Lines changed: 874 additions & 274 deletions

src/main/python/systemds/scuro/dataloader/video_loader.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class VideoStats:
3737
max_height: int
3838
max_channels: int
3939
num_instances: int
40+
num_total_instances: int
4041

4142
@property
4243
def output_shape(self):
@@ -132,8 +133,20 @@ def get_stats(self, source_path: str):
132133
max_height = max(max_height, height)
133134
max_num_channels = max(max_num_channels, num_channels)
134135
num_instances += 1
136+
num_total_instances = num_instances
137+
num_instances = (
138+
min(num_instances, self.chunk_size)
139+
if self.chunk_size is not None
140+
else num_instances
141+
)
135142
return VideoStats(
136-
fps, max_length, max_width, max_height, max_num_channels, num_instances
143+
fps,
144+
max_length,
145+
max_width,
146+
max_height,
147+
max_num_channels,
148+
num_instances,
149+
num_total_instances,
137150
)
138151

139152
def estimate_peak_memory_bytes(self) -> dict:

src/main/python/systemds/scuro/drsearch/modality_shared_memory.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,3 +401,26 @@ def add_shared_memory_candidate(data: Any, resident_bytes: int = 0) -> bool:
401401
return data, shm.name, data_nbytes, resident_bytes
402402

403403
return None, None, 0, resident_bytes
404+
405+
406+
_SHARED_MEMORY_WRAPPER_TYPES = (
407+
SharedStringList,
408+
SharedGroupedArrayList,
409+
SharedArrayList,
410+
SharedNDArray,
411+
)
412+
413+
414+
def collect_shm_names_from_payload(data: Any) -> List[str]:
415+
if data is None:
416+
return []
417+
if isinstance(data, _SHARED_MEMORY_WRAPPER_TYPES):
418+
return [data.shm_name]
419+
if hasattr(data, "data"):
420+
return collect_shm_names_from_payload(data.data)
421+
if isinstance(data, (list, tuple)):
422+
names: List[str] = []
423+
for item in data:
424+
names.extend(collect_shm_names_from_payload(item))
425+
return names
426+
return []

0 commit comments

Comments
 (0)