Skip to content

Commit d0781f0

Browse files
committed
update DataStrategy
1 parent 3f0ba85 commit d0781f0

3 files changed

Lines changed: 39 additions & 17 deletions

File tree

loomtrain/core/data/dataloader/iter.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,22 @@ def __next__(self):
5454
def __iter__(self):
5555
for epoch in range(self.num_epochs):
5656
self._current_epoch = epoch
57-
if epoch < self.start_epoch:continue
58-
self.load_sampler_state(epoch, self.start_epoch)
57+
if epoch < self.consumed_epoch:continue
58+
self.set_sampler_state(epoch, self.consumed_epoch, self.consumed_samples)
5959
yield from iter(super().__iter__())
60-
6160

62-
def set_state(self, start_epoch: int, consumed_samples = 0):
63-
self.start_epoch = start_epoch
61+
62+
63+
64+
def set_sampler_state(self, current_epoch: int, consumed_epoch:int, consumed_samples):
65+
raise NotImplementedError
66+
67+
68+
def set_state(self, consumed_epoch: int, consumed_samples = 0):
69+
self.consumed_epoch = consumed_epoch
6470
self.consumed_samples = consumed_samples
6571

6672

6773
def get_state(self) -> dict:
68-
...
74+
raise NotImplementedError
75+

loomtrain/core/datamodule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,10 +210,10 @@ def _setup_dataset(self):
210210

211211

212212
def _setup_train_data_iter(self):
213-
self.train_data_iter = self.strategy.setup_data_iter(self.train_dataset)
213+
self.train_data_iter = self.strategy._setup_train_data_iter(self.train_dataset)
214214

215215
def _setup_val_data_iter(self):
216-
self.val_data_iter = self.strategy.setup_data_iter(self.val_dataset)
216+
self.val_data_iter = self.strategy._setup_val_data_iter(self.val_dataset)
217217

218218

219219
def get_saved_sub_dir(self): return "data_iter"

loomtrain/core/strategy.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,26 @@ def rank(self):
6363
self._rank = parallel.get_dp_rank()
6464
return self._rank
6565

66-
def setup_data_iter(self,
67-
dataset: "tud.Dataset") -> "LoomDataIter":
68-
raise NotImplementedError
69-
66+
def _setup_train_data_iter(self, train_dataset: "tud.Dataset"):
67+
self.train_data_iter = self.setup_data_iter(train_dataset)
68+
return self.train_data_iter
69+
70+
def _setup_val_data_iter(self, val_dataset: "tud.Dataset"):
71+
self.val_data_iter = self.setup_data_iter(val_dataset)
72+
return self.val_data_iter
73+
74+
75+
def config_loomDataModule_method(self, datamodule: "LoomDataModule"):
76+
try:
77+
self.loomDataModule_load_ckpt(None, None)
78+
except NotImplementedError: ...
79+
except Exception as e:
80+
datamodule.load_ckpt = self.loomDataModule_load_ckpt
81+
try:
82+
self.loomDataModule_save_ckpt(None, None)
83+
except NotADirectoryError: ...
84+
except Exception as e:
85+
datamodule.load_ckpt = self.loomDataModule_load_ckpt
7086

7187

7288
def loomDataModule_save_ckpt(self, save_dir: str, tag: str):
@@ -75,12 +91,11 @@ def loomDataModule_save_ckpt(self, save_dir: str, tag: str):
7591
def loomDataModule_load_ckpt(self, saved_dir: str, tag: str):
7692
raise NotImplementedError
7793

94+
def setup_data_iter(self,
95+
dataset: "tud.Dataset") -> "LoomDataIter":
96+
raise NotImplementedError
97+
7898

79-
def config_loomDataModule_method(self, datamodule: "LoomDataModule"):
80-
datamodule.save_ckpt = self.loomDataModule_save_ckpt
81-
datamodule.load_ckpt = self.loomDataModule_load_ckpt
82-
83-
8499

85100
# TBD
86101
class TrainStrategy:

0 commit comments

Comments
 (0)