Skip to content

Commit 50e4b0d

Browse files
include a chunk system for datasets
1 parent 42010e8 commit 50e4b0d

1 file changed

Lines changed: 34 additions & 2 deletions

File tree

nebula/core/datasets/nebuladataset.py

Lines changed: 34 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -129,6 +129,17 @@ def set_data(self, data, targets, data_opt=None, targets_opt=None):
129129
self.targets = targets[:main_count] + targets_opt[:opt_count]
130130
self.length = len(self.data)
131131

132+
indices = np.arange(self.length)
133+
np.random.shuffle(indices)
134+
if isinstance(self.data, np.ndarray):
135+
self.data = self.data[indices]
136+
else:
137+
self.data = [self.data[i] for i in indices]
138+
if isinstance(self.targets, np.ndarray):
139+
self.targets = self.targets[indices]
140+
else:
141+
self.targets = [self.targets[i] for i in indices]
142+
132143
except Exception as e:
133144
logging_training.exception(f"Error setting data: {e}")
134145

@@ -139,6 +150,9 @@ def load_partition(self, file, name):
139150
if typ == "pickle":
140151
logging_training.info(f"Loading pickled object from {name}")
141152
return pickle.loads(item[()].tobytes())
153+
elif typ == "pickle_bytes":
154+
logging_training.info(f"Loading compressed pickled bytes object from {name}")
155+
return pickle.loads(item[()])
142156
else:
143157
logging_training.warning(f"[NebulaPartitionHandler] Unknown type encountered: {typ} for item {name}")
144158
return item[()]
@@ -402,8 +416,26 @@ def save_partition(self, obj, file, name):
402416
try:
403417
logging.info(f"Saving pickled object of type {type(obj)}")
404418
pickled = pickle.dumps(obj)
405-
ds = file.create_dataset(name, data=np.void(pickled))
406-
ds.attrs["__type__"] = "pickle"
419+
420+
size_in_mb = len(pickled) / (1024 * 1024)
421+
logging.info(f"Pickled object size: {size_in_mb:.2f} MB")
422+
423+
if size_in_mb > 10:
424+
logging.info(f"Large object detected ({size_in_mb:.2f} MB). Using chunked storage with compression.")
425+
data = np.frombuffer(pickled, dtype=np.uint8)
426+
chunk_size = min(4 * 1024 * 1024, len(data) // 10)
427+
chunk_length = max(1, chunk_size // data.itemsize)
428+
ds = file.create_dataset(
429+
name,
430+
data=data,
431+
chunks=(chunk_length,),
432+
compression="lzf",
433+
shuffle=True,
434+
)
435+
ds.attrs["__type__"] = "pickle_bytes"
436+
else:
437+
ds = file.create_dataset(name, data=np.void(pickled))
438+
ds.attrs["__type__"] = "pickle"
407439
logging.info(f"Saved pickled object of type {type(obj)} to {name}")
408440
except Exception as e:
409441
logging.exception(f"Error saving object to HDF5: {e}")

0 commit comments

Comments
 (0)