Skip to content

Commit 03ec110

Browse files
add updated memory estimators
1 parent e282a65 commit 03ec110

28 files changed

Lines changed: 1496 additions & 408 deletions

src/main/python/systemds/scuro/dataloader/image_loader.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ class ImageStats:
3636
max_channels: int
3737
num_instances: int
3838
output_shape: tuple
39+
average_width: int
40+
average_height: int
41+
average_channels: int
3942

4043

4144
class ImageLoader(BaseLoader):
@@ -79,7 +82,9 @@ def get_stats(self, source_path: str):
7982
max_height = 0
8083
max_channels = 0
8184
num_instances = 0
82-
85+
average_width = 0
86+
average_height = 0
87+
average_channels = 0
8388
for file in self.indices:
8489
path = os.path.join(source_path, f"{file}{self._ext}")
8590
# if self.chunk_size is None:
@@ -98,10 +103,19 @@ def get_stats(self, source_path: str):
98103
max_height = max(max_height, height)
99104
max_channels = max(max_channels, channels)
100105
num_instances += 1
106+
average_width += width
107+
average_height += height
108+
average_channels += channels
109+
average_width = average_width / num_instances
110+
average_height = average_height / num_instances
111+
average_channels = average_channels / num_instances
101112
return ImageStats(
102113
max_width,
103114
max_height,
104115
max_channels,
105116
num_instances,
106-
(max_width, max_height, max_channels),
117+
(average_width, average_height, average_channels),
118+
average_width,
119+
average_height,
120+
average_channels,
107121
)

src/main/python/systemds/scuro/dataloader/json_loader.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ class JSONStats:
3333
num_instances: int
3434
max_length: int
3535
avg_length: float
36-
output_shape: Tuple[int, int]
36+
max_words: int
37+
avg_words: float
38+
output_shape: Tuple[int]
3739

3840

3941
class JSONLoader(BaseLoader):
@@ -74,6 +76,8 @@ def get_stats(self, source_path: str):
7476
num_instances = 0
7577
max_length = 0
7678
avg_length = 0
79+
max_words = 0
80+
avg_words = 0
7781
if os.path.isfile(source_path):
7882
with open(source_path) as f:
7983
json_file = json.load(f)
@@ -96,10 +100,20 @@ def get_stats(self, source_path: str):
96100
text = " ".join(text) if isinstance(text, list) else text
97101
num_instances += 1
98102
max_length = max(max_length, len(text)) # number of characters
103+
max_words = max(max_words, len(text.split()))
104+
avg_words += len(text.split())
99105
avg_length += len(text)
100106

101107
avg_length /= num_instances
102-
return JSONStats(num_instances, max_length, avg_length, (max_length,))
108+
avg_words /= num_instances
109+
return JSONStats(
110+
num_instances,
111+
max_length,
112+
avg_length,
113+
max_words,
114+
avg_words,
115+
(max_length,),
116+
)
103117

104118
def estimate_peak_memory_bytes(self) -> dict:
105119
s = self.stats

src/main/python/systemds/scuro/dataloader/text_loader.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ class TextStats:
3131
num_instances: int
3232
max_length: int
3333
avg_length: float
34+
max_words: int
35+
avg_words: float
3436
output_shape: tuple
3537

3638

@@ -63,6 +65,8 @@ def get_stats(self, source_path: str):
6365
num_instances = 0
6466
max_length = 0
6567
avg_length = 0
68+
max_words = 0
69+
avg_words = 0
6670
for file in os.listdir(source_path):
6771
self.file_sanity_check(source_path + file)
6872
with open(source_path + file) as text_file:
@@ -74,5 +78,10 @@ def get_stats(self, source_path: str):
7478
num_instances += 1
7579
max_length = max(max_length, length)
7680
avg_length += length
81+
max_words = max(max_words, len(line.split()))
82+
avg_words += len(line.split())
7783
avg_length /= num_instances
78-
return TextStats(num_instances, max_length, avg_length, (max_length,))
84+
avg_words /= num_instances
85+
return TextStats(
86+
num_instances, max_length, avg_length, max_words, avg_words, (max_length,)
87+
)

0 commit comments

Comments
 (0)