From 3f480f972b573a968e31deb4c8f44d0bcd9f5c82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anm=E5=8D=8A=E5=A4=8F?= <67933480+Anm-pinellia@users.noreply.github.com> Date: Thu, 30 May 2024 19:59:35 +0800 Subject: [PATCH 1/5] Update exp.py Fix the problem of current trainer cannot parser corresponding dataloaders. --- openstl/api/exp.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/openstl/api/exp.py b/openstl/api/exp.py index 77f233b5..8325690d 100644 --- a/openstl/api/exp.py +++ b/openstl/api/exp.py @@ -92,13 +92,13 @@ def _get_data(self, dataloaders=None): return BaseDataModule(train_loader, vali_loader, test_loader) def train(self): - self.trainer.fit(self.method, self.data) + self.trainer.fit(self.method, self.data.train_loader) def test(self): if self.args.test == True: ckpt = torch.load(osp.join(self.save_dir, 'checkpoints', 'best.ckpt')) self.method.load_state_dict(ckpt['state_dict']) - self.trainer.test(self.method, self.data) + self.trainer.test(self.method, self.data.test_loader) def display_method_info(self, args): """Plot the basic infomation of supported methods""" @@ -142,4 +142,4 @@ def display_method_info(self, args): fps = 'Throughputs of {}: {:.3f}\n'.format(args.method, fps) else: fps = '' - return info, flops, fps, dash_line \ No newline at end of file + return info, flops, fps, dash_line From ee35490297002e5c30842f5654b0086691d9baae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anm=E5=8D=8A=E5=A4=8F?= <67933480+Anm-pinellia@users.noreply.github.com> Date: Fri, 31 May 2024 11:24:40 +0800 Subject: [PATCH 2/5] Update config_utils.py modify some code to make it compatible with windows. Specifically, after opening a temporary file, windows system will remove it after close it. So it needs to modify some params during create the temporayfile. (e.g. delte=False, open and delete temporaryfile manually) --- openstl/utils/config_utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/openstl/utils/config_utils.py b/openstl/utils/config_utils.py index b5714449..34e88c33 100644 --- a/openstl/utils/config_utils.py +++ b/openstl/utils/config_utils.py @@ -5,6 +5,7 @@ import sys import ast from importlib import import_module +import os ''' Thanks the code from https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py wrote by Open-MMLab. @@ -58,8 +59,10 @@ def _substitute_predefined_vars(filename, temp_config_name): regexp = r'\{\{\s*' + str(key) + r'\s*\}\}' value = value.replace('\\', '/') config_file = re.sub(regexp, value, config_file) - with open(temp_config_name, 'w') as tmp_config_file: - tmp_config_file.write(config_file) + # with open(temp_config_name, 'w') as tmp_config_file: + # tmp_config_file.write(config_file) + temp_config_file = open(temp_config_name, 'w') + temp_config_file.write(config_file) @staticmethod def _file2dict(filename, use_predefined_variables=True): @@ -69,9 +72,9 @@ def _file2dict(filename, use_predefined_variables=True): if fileExtname not in ['.py']: raise IOError('Only py type are supported now!') - with tempfile.TemporaryDirectory() as temp_config_dir: + with tempfile.TemporaryDirectory(dir=r'G:\temp dirs\code_temp') as temp_config_dir: temp_config_file = tempfile.NamedTemporaryFile( - dir=temp_config_dir, suffix=fileExtname) + dir=temp_config_dir, suffix=fileExtname, delete=False) temp_config_name = osp.basename(temp_config_file.name) # Substitute predefined variables @@ -96,6 +99,7 @@ def _file2dict(filename, use_predefined_variables=True): del sys.modules[temp_module_name] # close temp file temp_config_file.close() + os.remove(temp_config_file.name) return cfg_dict @staticmethod From a79e7553f39e1305740797a7d53298be51d66980 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anm=E5=8D=8A=E5=A4=8F?= <67933480+Anm-pinellia@users.noreply.github.com> Date: Fri, 31 May 2024 11:26:11 +0800 Subject: [PATCH 3/5] Update config_utils.py modify some code to make it compatible with windows. Specifically, after opening a temporary file, windows system will remove it after close it. So it needs to modify some params during create the temporayfile. (e.g. delte=False, open and delete temporaryfile manually) --- openstl/utils/config_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openstl/utils/config_utils.py b/openstl/utils/config_utils.py index 34e88c33..71b36588 100644 --- a/openstl/utils/config_utils.py +++ b/openstl/utils/config_utils.py @@ -72,7 +72,7 @@ def _file2dict(filename, use_predefined_variables=True): if fileExtname not in ['.py']: raise IOError('Only py type are supported now!') - with tempfile.TemporaryDirectory(dir=r'G:\temp dirs\code_temp') as temp_config_dir: + with tempfile.TemporaryDirectory() as temp_config_dir: temp_config_file = tempfile.NamedTemporaryFile( dir=temp_config_dir, suffix=fileExtname, delete=False) temp_config_name = osp.basename(temp_config_file.name) From c5c9a6bdbdc6a56b6605b72b5f6910157fc3bfc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anm=E5=8D=8A=E5=A4=8F?= <67933480+Anm-pinellia@users.noreply.github.com> Date: Fri, 31 May 2024 11:30:10 +0800 Subject: [PATCH 4/5] Update main_utils.py make it easy to debug, instead of giving a print. (Config file should be prior than arguments, which can be seen in mmopenlab frameworks). --- openstl/utils/main_utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/openstl/utils/main_utils.py b/openstl/utils/main_utils.py index 6374c831..6ec90850 100644 --- a/openstl/utils/main_utils.py +++ b/openstl/utils/main_utils.py @@ -128,13 +128,15 @@ def get_batch_size(H, W): def load_config(filename:str = None): """load and print config""" - print('loading config from ' + filename + ' ...') + print('loading config from ' + filename) try: configfile = Config(filename=filename) config = configfile._cfg_dict - except (FileNotFoundError, IOError): - config = dict() - print('warning: fail to load the config!') + except Exception as e: + raise Exception('error in load the config:{}'.format(e)) + # except (FileNotFoundError, IOError): + # config = dict() + # print('warning: fail to load the config!') return config @@ -177,4 +179,4 @@ def get_dist_info() -> Tuple[int, int]: else: rank = 0 world_size = 1 - return rank, world_size \ No newline at end of file + return rank, world_size From 0384bd68527b614dae14295d74b0a10224583c12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anm=E5=8D=8A=E5=A4=8F?= <67933480+Anm-pinellia@users.noreply.github.com> Date: Fri, 31 May 2024 11:32:10 +0800 Subject: [PATCH 5/5] Update train.py To compatible with windows, default strategy should better to be auto rather than ddp. (windows only support gloo instead of nccl to manage gpus). --- tools/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/train.py b/tools/train.py index 8df10a6b..e665a39f 100644 --- a/tools/train.py +++ b/tools/train.py @@ -29,10 +29,10 @@ config[attribute] = default_values[attribute] print('>'*35 + ' training ' + '<'*35) - exp = BaseExperiment(args) + exp = BaseExperiment(args, strategy='auto') rank, _ = get_dist_info() exp.train() if rank == 0: print('>'*35 + ' testing ' + '<'*35) - mse = exp.test() \ No newline at end of file + mse = exp.test()