@@ -63,10 +63,26 @@ def rank(self):
6363 self ._rank = parallel .get_dp_rank ()
6464 return self ._rank
6565
66- def setup_data_iter (self ,
67- dataset : "tud.Dataset" ) -> "LoomDataIter" :
68- raise NotImplementedError
69-
66+ def _setup_train_data_iter (self , train_dataset : "tud.Dataset" ):
67+ self .train_data_iter = self .setup_data_iter (train_dataset )
68+ return self .train_data_iter
69+
70+ def _setup_val_data_iter (self , val_dataset : "tud.Dataset" ):
71+ self .val_data_iter = self .setup_data_iter (val_dataset )
72+ return self .val_data_iter
73+
74+
75+ def config_loomDataModule_method (self , datamodule : "LoomDataModule" ):
76+ try :
77+ self .loomDataModule_load_ckpt (None , None )
78+ except NotImplementedError : ...
79+ except Exception as e :
80+ datamodule .load_ckpt = self .loomDataModule_load_ckpt
81+ try :
82+ self .loomDataModule_save_ckpt (None , None )
83+ except NotADirectoryError : ...
84+ except Exception as e :
85+ datamodule .load_ckpt = self .loomDataModule_load_ckpt
7086
7187
7288 def loomDataModule_save_ckpt (self , save_dir : str , tag : str ):
@@ -75,12 +91,11 @@ def loomDataModule_save_ckpt(self, save_dir: str, tag: str):
7591 def loomDataModule_load_ckpt (self , saved_dir : str , tag : str ):
7692 raise NotImplementedError
7793
94+ def setup_data_iter (self ,
95+ dataset : "tud.Dataset" ) -> "LoomDataIter" :
96+ raise NotImplementedError
97+
7898
79- def config_loomDataModule_method (self , datamodule : "LoomDataModule" ):
80- datamodule .save_ckpt = self .loomDataModule_save_ckpt
81- datamodule .load_ckpt = self .loomDataModule_load_ckpt
82-
83-
8499
85100# TBD
86101class TrainStrategy :
0 commit comments