diff --git a/openstl/api/exp.py b/openstl/api/exp.py index 77f233b5..8325690d 100644 --- a/openstl/api/exp.py +++ b/openstl/api/exp.py @@ -92,13 +92,13 @@ def _get_data(self, dataloaders=None): return BaseDataModule(train_loader, vali_loader, test_loader) def train(self): - self.trainer.fit(self.method, self.data) + self.trainer.fit(self.method, self.data.train_loader) def test(self): if self.args.test == True: ckpt = torch.load(osp.join(self.save_dir, 'checkpoints', 'best.ckpt')) self.method.load_state_dict(ckpt['state_dict']) - self.trainer.test(self.method, self.data) + self.trainer.test(self.method, self.data.test_loader) def display_method_info(self, args): """Plot the basic infomation of supported methods""" @@ -142,4 +142,4 @@ def display_method_info(self, args): fps = 'Throughputs of {}: {:.3f}\n'.format(args.method, fps) else: fps = '' - return info, flops, fps, dash_line \ No newline at end of file + return info, flops, fps, dash_line diff --git a/openstl/utils/config_utils.py b/openstl/utils/config_utils.py index b5714449..71b36588 100644 --- a/openstl/utils/config_utils.py +++ b/openstl/utils/config_utils.py @@ -5,6 +5,7 @@ import sys import ast from importlib import import_module +import os ''' Thanks the code from https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py wrote by Open-MMLab. @@ -58,8 +59,10 @@ def _substitute_predefined_vars(filename, temp_config_name): regexp = r'\{\{\s*' + str(key) + r'\s*\}\}' value = value.replace('\\', '/') config_file = re.sub(regexp, value, config_file) - with open(temp_config_name, 'w') as tmp_config_file: - tmp_config_file.write(config_file) + # with open(temp_config_name, 'w') as tmp_config_file: + # tmp_config_file.write(config_file) + temp_config_file = open(temp_config_name, 'w') + temp_config_file.write(config_file) @staticmethod def _file2dict(filename, use_predefined_variables=True): @@ -71,7 +74,7 @@ def _file2dict(filename, use_predefined_variables=True): with tempfile.TemporaryDirectory() as temp_config_dir: temp_config_file = tempfile.NamedTemporaryFile( - dir=temp_config_dir, suffix=fileExtname) + dir=temp_config_dir, suffix=fileExtname, delete=False) temp_config_name = osp.basename(temp_config_file.name) # Substitute predefined variables @@ -96,6 +99,7 @@ def _file2dict(filename, use_predefined_variables=True): del sys.modules[temp_module_name] # close temp file temp_config_file.close() + os.remove(temp_config_file.name) return cfg_dict @staticmethod diff --git a/openstl/utils/main_utils.py b/openstl/utils/main_utils.py index 6374c831..6ec90850 100644 --- a/openstl/utils/main_utils.py +++ b/openstl/utils/main_utils.py @@ -128,13 +128,15 @@ def get_batch_size(H, W): def load_config(filename:str = None): """load and print config""" - print('loading config from ' + filename + ' ...') + print('loading config from ' + filename) try: configfile = Config(filename=filename) config = configfile._cfg_dict - except (FileNotFoundError, IOError): - config = dict() - print('warning: fail to load the config!') + except Exception as e: + raise Exception('error in load the config:{}'.format(e)) + # except (FileNotFoundError, IOError): + # config = dict() + # print('warning: fail to load the config!') return config @@ -177,4 +179,4 @@ def get_dist_info() -> Tuple[int, int]: else: rank = 0 world_size = 1 - return rank, world_size \ No newline at end of file + return rank, world_size diff --git a/tools/train.py b/tools/train.py index 8df10a6b..e665a39f 100644 --- a/tools/train.py +++ b/tools/train.py @@ -29,10 +29,10 @@ config[attribute] = default_values[attribute] print('>'*35 + ' training ' + '<'*35) - exp = BaseExperiment(args) + exp = BaseExperiment(args, strategy='auto') rank, _ = get_dist_info() exp.train() if rank == 0: print('>'*35 + ' testing ' + '<'*35) - mse = exp.test() \ No newline at end of file + mse = exp.test()