55import logging
66import os
77import os .path as osp
8- from typing import Any , List , Optional
8+ from typing import Any , Iterable , List , Optional
99
1010import mmcv
1111import yaml
1414from nptyping import NDArray , Shape , UInt8
1515from packaging .version import Version
1616from ymir_exc import result_writer as rw
17+ from ymir_exc .util import get_merged_config
1718
1819BBOX = NDArray [Shape ['*,4' ], Any ]
1920CV_IMAGE = NDArray [Shape ['*,*,3' ], UInt8 ]
2021
2122
22- def modify_mmdet_config ( mmdet_cfg : Config , ymir_cfg : edict ) -> Config :
23+ def modify_mmcv_config ( mmcv_cfg : Config , ymir_cfg : edict ) -> None :
2324 """
2425 useful for training process
2526 - modify dataset config
2627 - modify model output channel
2728 - modify epochs, checkpoint, tensorboard config
2829 """
def recursive_modify_attribute(mmcv_cfg: Config, attribute_key: str, attribute_value: Any):
    """
    Recursively set every ``attribute_key`` entry in ``mmcv_cfg`` to ``attribute_value``.

    The config is modified in place; nothing is returned. The following
    locations are covered:
    1. mmcv_cfg.attribute_key
    2. mmcv_cfg.xxx.xxx.xxx.attribute_key (nested sections)
    3. mmcv_cfg.xxx[i].attribute_key (i=0, 1, 2 ...)
    4. mmcv_cfg.xxx[i].xxx.xxx[j].attribute_key

    Args:
        mmcv_cfg: config node to walk; iterating it must yield its keys and
            indexing it with a key must return the stored value.
        attribute_key: name of the entry to overwrite wherever it appears.
        attribute_value: value written to every matching entry.
    """
    # NOTE(review): nested sections of an mmcv Config are expected to come back
    # as dict-based ConfigDict objects rather than Config instances, so a
    # Config-only check would never descend into them -- confirm against the
    # installed mmcv version.
    node_types = (Config, dict)

    for key in mmcv_cfg:
        value = mmcv_cfg[key]
        if key == attribute_key:
            # overwriting an existing key does not change the key set, so it is
            # safe to do while iterating
            mmcv_cfg[key] = attribute_value
        elif isinstance(value, node_types):
            recursive_modify_attribute(value, attribute_key, attribute_value)
        elif isinstance(value, Iterable) and not isinstance(value, (str, bytes)):
            # list/tuple of sub-configs, e.g. cascade heads; strings are
            # iterable too but can never contain a config node
            for item in value:
                if isinstance(item, node_types):
                    recursive_modify_attribute(item, attribute_key, attribute_value)
47+
2948 # modify dataset config
3049 ymir_ann_files = dict (train = ymir_cfg .ymir .input .training_index_file ,
3150 val = ymir_cfg .ymir .input .val_index_file ,
@@ -35,8 +54,12 @@ def modify_mmdet_config(mmdet_cfg: Config, ymir_cfg: edict) -> Config:
3554 # so set smaller samples_per_gpu for validation
3655 samples_per_gpu = ymir_cfg .param .samples_per_gpu
3756 workers_per_gpu = ymir_cfg .param .workers_per_gpu
38- mmdet_cfg .data .samples_per_gpu = samples_per_gpu
39- mmdet_cfg .data .workers_per_gpu = workers_per_gpu
57+ mmcv_cfg .data .samples_per_gpu = samples_per_gpu
58+ mmcv_cfg .data .workers_per_gpu = workers_per_gpu
59+
60+ # modify model output channel
61+ num_classes = len (ymir_cfg .param .class_names )
62+ recursive_modify_attribute (mmcv_cfg .model , 'num_classes' , num_classes )
4063
4164 for split in ['train' , 'val' , 'test' ]:
4265 ymir_dataset_cfg = dict (type = 'YmirDataset' ,
@@ -47,7 +70,7 @@ def modify_mmdet_config(mmdet_cfg: Config, ymir_cfg: edict) -> Config:
4770 data_root = ymir_cfg .ymir .input .root_dir ,
4871 filter_empty_gt = False )
4972 # modify dataset config for `split`
50- mmdet_dataset_cfg = mmdet_cfg .data .get (split , None )
73+ mmdet_dataset_cfg = mmcv_cfg .data .get (split , None )
5174 if mmdet_dataset_cfg is None :
5275 continue
5376
@@ -63,33 +86,60 @@ def modify_mmdet_config(mmdet_cfg: Config, ymir_cfg: edict) -> Config:
6386 else :
6487 raise Exception (f'unsupported source dataset type { src_dataset_type } ' )
6588
66- # modify model output channel
67- mmdet_model_cfg = mmdet_cfg .model .bbox_head
68- mmdet_model_cfg .num_classes = len (ymir_cfg .param .class_names )
69-
7089 # modify epochs, checkpoint, tensorboard config
7190 if ymir_cfg .param .get ('max_epochs' , None ):
72- mmdet_cfg .runner .max_epochs = ymir_cfg .param .max_epochs
73- mmdet_cfg .checkpoint_config ['out_dir' ] = ymir_cfg .ymir .output .models_dir
91+ mmcv_cfg .runner .max_epochs = int ( ymir_cfg .param .max_epochs )
92+ mmcv_cfg .checkpoint_config ['out_dir' ] = ymir_cfg .ymir .output .models_dir
7493 tensorboard_logger = dict (type = 'TensorboardLoggerHook' , log_dir = ymir_cfg .ymir .output .tensorboard_dir )
75- if len (mmdet_cfg .log_config ['hooks' ]) <= 1 :
76- mmdet_cfg .log_config ['hooks' ].append (tensorboard_logger )
94+ if len (mmcv_cfg .log_config ['hooks' ]) <= 1 :
95+ mmcv_cfg .log_config ['hooks' ].append (tensorboard_logger )
7796 else :
78- mmdet_cfg .log_config ['hooks' ][1 ].update (tensorboard_logger )
97+ mmcv_cfg .log_config ['hooks' ][1 ].update (tensorboard_logger )
7998
99+ # TODO save only the best top-k model weight files.
80100 # modify evaluation and interval
81- interval = max (1 , mmdet_cfg .runner .max_epochs // 30 )
82- mmdet_cfg .evaluation .interval = interval
83- mmdet_cfg .evaluation .metric = ymir_cfg .param .get ('metric' , 'bbox' )
101+ val_interval : int = int (ymir_cfg .param .get ('val_interval' , 1 ))
102+ if val_interval > 0 :
103+ val_interval = min (val_interval , mmcv_cfg .runner .max_epochs )
104+ else :
105+ val_interval = 1
106+
107+ mmcv_cfg .evaluation .interval = val_interval
108+ mmcv_cfg .evaluation .metric = ymir_cfg .param .get ('metric' , 'bbox' )
109+
110+ # save best top-k model weights files
111+ # max_keep_ckpts <= 0 # save all checkpoints
112+ max_keep_ckpts : int = int (ymir_cfg .param .get ('max_keep_checkpoints' , 1 ))
113+ mmcv_cfg .checkpoint_config .interval = mmcv_cfg .evaluation .interval
114+ mmcv_cfg .checkpoint_config .max_keep_ckpts = max_keep_ckpts
115+
84116 # TODO Whether to evaluating the AP for each class
85117 # mmdet_cfg.evaluation.classwise = True
86118
87119 # fix DDP error
88- mmdet_cfg .find_unused_parameters = True
89- return mmdet_cfg
120+ mmcv_cfg .find_unused_parameters = True
121+
122+ # set work dir
123+ mmcv_cfg .work_dir = ymir_cfg .ymir .output .models_dir
124+
125+ args_options = ymir_cfg .param .get ("args_options" , '' )
126+ cfg_options = ymir_cfg .param .get ("cfg_options" , '' )
127+
128+ # auto load offered weight file if not set by user!
129+ if (args_options .find ('--resume-from' ) == - 1 and args_options .find ('--load-from' ) == - 1
130+ and cfg_options .find ('load_from' ) == - 1 and cfg_options .find ('resume_from' ) == - 1 ): # noqa: E129
131+
132+ weight_file = get_best_weight_file (ymir_cfg )
133+ if weight_file :
134+ if cfg_options :
135+ cfg_options += f' load_from={ weight_file } '
136+ else :
137+ cfg_options = f'load_from={ weight_file } '
138+ else :
139+ logging .warning ('no weight file used for training!' )
90140
91141
92- def get_weight_file (cfg : edict ) -> str :
142+ def get_best_weight_file (cfg : edict ) -> str :
93143 """
94144 return the weight file path by priority
95145 find weight file in cfg.param.pretrained_model_params or cfg.param.model_params_path
@@ -118,6 +168,7 @@ def get_weight_file(cfg: edict) -> str:
118168 if cfg .ymir .run_training :
119169 weight_files = [f for f in glob .glob ('/weights/**/*' , recursive = True ) if f .endswith (('.pth' , '.pt' ))]
120170
171+ # load pretrained model weight for yolox only
121172 model_name_splits = osp .basename (cfg .param .config_file ).split ('_' )
122173 if len (weight_files ) > 0 and model_name_splits [0 ] == 'yolox' :
123174 yolox_weight_files = [
@@ -145,6 +196,30 @@ def write_ymir_training_result(last: bool = False, key_score: Optional[float] =
145196 _write_ancient_ymir_training_result (key_score )
146197
147198
def get_topk_checkpoints(files: List[str], k: int) -> List[str]:
    """
    Select the newest ``k`` checkpoint files from ``files``.

    Only paths ending in ``.pth`` or ``.pt`` are considered. They are split
    into two groups by basename prefix: ``best_*`` and ``epoch_*`` / ``iter_*``.
    Each group is ordered by ``os.path.getctime`` (newest first) and at most
    ``k`` paths are taken from it, so the result holds up to ``2 * k`` paths
    with the ``best_*`` group first. Checkpoints with any other prefix are
    not returned.

    Nothing is deleted from disk; this function only chooses which paths to keep.

    Args:
        files: candidate file paths. Every selected-group path is passed to
            ``os.path.getctime``, so it must be resolvable from the current
            working directory (absolute paths are safest).
        k: maximum number of files kept per group. Values <= 0 select
            nothing; callers that want to keep every checkpoint should not
            call this function.

    Returns:
        The kept checkpoint paths, ``best_*`` files followed by
        ``epoch_*`` / ``iter_*`` files, each group newest first.

    Raises:
        OSError: if a checkpoint path in one of the two groups does not exist.
    """
    if k <= 0:
        # a negative stop index would slice "all but the oldest |k|" files,
        # which is never what a top-k selection means
        return []

    def newest_first(prefixes: tuple) -> List[str]:
        # sorted() on an empty list is already [], no special case needed
        group = [
            f for f in files
            if f.endswith(('.pth', '.pt')) and osp.basename(f).startswith(prefixes)
        ]
        return sorted(group, key=osp.getctime, reverse=True)

    # slicing past the end of a list is safe, so short groups are returned whole
    return newest_first(('best_',))[:k] + newest_first(('epoch_', 'iter_'))[:k]
220+
221+
222+ # TODO save topk checkpoints, fix invalid stage due to delete checkpoint
148223def _write_latest_ymir_training_result (last : bool = False , key_score : Optional [float ] = None ):
149224 if key_score :
150225 logging .info (f'key_score is { key_score } ' )
@@ -165,6 +240,11 @@ def _write_latest_ymir_training_result(last: bool = False, key_score: Optional[f
165240
166241 if last :
167242 # save all output file
243+ ymir_cfg = get_merged_config ()
244+ max_keep_checkpoints = int (ymir_cfg .param .get ('max_keep_checkpoints' , 1 ))
245+ if max_keep_checkpoints > 0 :
246+ topk_checkpoints = get_topk_checkpoints (result_files , max_keep_checkpoints )
247+ result_files = [f for f in result_files if not f .endswith (('.pth' , '.pt' ))] + topk_checkpoints
168248 rw .write_model_stage (files = result_files , mAP = float (map ), stage_name = 'last' )
169249 else :
170250 # save newest weight file in format epoch_xxx.pth or iter_xxx.pth
@@ -201,13 +281,17 @@ def _write_ancient_ymir_training_result(key_score: Optional[float] = None):
201281 # eval_result may be empty dict {}.
202282 map = eval_result .get ('bbox_mAP_50' , 0 )
203283
204- WORK_DIR = os .getenv ('YMIR_MODELS_DIR' )
205- if WORK_DIR is None or not osp .isdir (WORK_DIR ):
206- raise Exception (f'please set valid environment variable YMIR_MODELS_DIR, invalid directory { WORK_DIR } ' )
284+ ymir_cfg = get_merged_config ()
285+ WORK_DIR = ymir_cfg .ymir .output .models_dir
207286
208287 # assert only one model config file in work_dir
209288 result_files = [osp .basename (f ) for f in glob .glob (osp .join (WORK_DIR , '*' )) if osp .basename (f ) != 'result.yaml' ]
210289
290+ max_keep_checkpoints = int (ymir_cfg .param .get ('max_keep_checkpoints' , 1 ))
291+ if max_keep_checkpoints > 0 :
292+ topk_checkpoints = get_topk_checkpoints (result_files , max_keep_checkpoints )
293+ result_files = [f for f in result_files if not f .endswith (('.pth' , '.pt' ))] + topk_checkpoints
294+
211295 training_result_file = osp .join (WORK_DIR , 'result.yaml' )
212296 if osp .exists (training_result_file ):
213297 with open (training_result_file , 'r' ) as f :
0 commit comments