11import argparse
2+ import os
23import os .path as osp
34import sys
45import warnings
56from typing import Any , List
67
78import cv2
89import numpy as np
10+ import torch .distributed as dist
911from easydict import EasyDict as edict
1012from mmcv import DictAction
11- from mmdet .apis import inference_detector , init_detector
12- from mmdet .utils .util_ymir import get_best_weight_file
13+ from mmcv .runner import init_dist
1314from tqdm import tqdm
14- from ymir_exc import dataset_reader as dr
15- from ymir_exc import env
1615from ymir_exc import result_writer as rw
17- from ymir_exc .util import YmirStage , get_merged_config , write_ymir_monitor_process
16+ from ymir_exc .util import (YmirStage , get_merged_config , write_ymir_monitor_process )
17+
18+ from mmdet .apis import inference_detector , init_detector
19+ from mmdet .apis .test import collect_results_gpu
20+ from mmdet .utils .util_ymir import get_best_weight_file
21+
22+ LOCAL_RANK = int (os .getenv ('LOCAL_RANK' , - 1 )) # https://pytorch.org/docs/stable/elastic/run.html
23+ RANK = int (os .getenv ('RANK' , - 1 ))
24+ WORLD_SIZE = int (os .getenv ('WORLD_SIZE' , 1 ))
1825
1926
2027def parse_option (cfg_options : str ) -> dict :
@@ -80,8 +87,9 @@ def __init__(self, cfg: edict):
8087 cfg_options = parse_option (options ) if options else None
8188
8289 # current infer can only use one gpu!!!
83- gpu_ids = cfg .param .get ('gpu_id' , '0' )
84- gpu_id = gpu_ids .split (',' )[0 ]
90+ # gpu_ids = cfg.param.get('gpu_id', '0')
91+ # gpu_id = gpu_ids.split(',')[0]
92+ gpu_id = max (0 , RANK )
8593 # build the model from a config file and a checkpoint file
8694 self .model = init_detector (config_file , checkpoint_file , device = f'cuda:{ gpu_id } ' , cfg_options = cfg_options )
8795
@@ -90,26 +98,47 @@ def infer(self, img):
9098
9199
92100def main ():
101+ if LOCAL_RANK != - 1 :
102+ init_dist (launcher = 'pytorch' , backend = "nccl" if dist .is_nccl_available () else "gloo" )
103+
93104 cfg = get_merged_config ()
94105
95- N = dr .items_count (env .DatasetType .CANDIDATE )
106+ with open (cfg .ymir .input .candidate_index_file , 'r' ) as f :
107+ images = [line .strip () for line in f .readlines ()]
108+
109+ max_barrier_times = len (images ) // WORLD_SIZE
110+ if RANK == - 1 :
111+ N = len (images )
112+ tbar = tqdm (images )
113+ else :
114+ images_rank = images [RANK ::WORLD_SIZE ]
115+ N = len (images_rank )
116+ if RANK == 0 :
117+ tbar = tqdm (images_rank )
118+ else :
119+ tbar = images_rank
96120 infer_result = dict ()
97121 model = YmirModel (cfg )
98- idx = - 1
99122
100123 # write infer result
101124 monitor_gap = max (1 , N // 100 )
102125 conf_threshold = float (cfg .param .conf_threshold )
103- for asset_path , _ in tqdm ( dr . item_paths ( dataset_type = env . DatasetType . CANDIDATE ) ):
126+ for idx , asset_path in enumerate ( tbar ):
104127 img = cv2 .imread (asset_path )
105128 result = model .infer (img )
106129 raw_anns = mmdet_result_to_ymir (result , cfg .param .class_names )
107130
131+ # batch-level sync, avoid 30min time-out error
132+ if WORLD_SIZE > 1 and idx < max_barrier_times :
133+ dist .barrier ()
134+
108135 infer_result [asset_path ] = [ann for ann in raw_anns if ann .score >= conf_threshold ]
109- idx += 1
110136
111137 if idx % monitor_gap == 0 :
112- write_ymir_monitor_process (cfg , task = 'infer' , naive_stage_percent = idx / N , stage = YmirStage .TASK )
138+ write_ymir_monitor_process (cfg , task = 'infer' , naive_stage_percent = idx / N , stage = YmirStage .TASK )
139+
140+ if WORLD_SIZE > 1 :
141+ infer_result = collect_results_gpu (infer_result , len (images ))
113142
114143 rw .write_infer_result (infer_result = infer_result )
115144 write_ymir_monitor_process (cfg , task = 'infer' , naive_stage_percent = 1.0 , stage = YmirStage .POSTPROCESS )
0 commit comments