1414import torch .distributed as dist
1515import torch .utils .data as td
1616from easydict import EasyDict as edict
17+ from mining .util import YmirDataset , load_image_file
1718from tqdm import tqdm
19+ from utils .ymir_yolov5 import YmirYolov5
1820from ymir_exc import result_writer as rw
1921from ymir_exc .util import YmirStage , get_merged_config
2022
21- from mining .util import (YmirDataset , collate_fn_with_fake_ann , load_image_file , load_image_file_with_ann ,
22- update_consistency )
23- from utils .general import scale_coords
24- from utils .ymir_yolov5 import YmirYolov5
25-
2623LOCAL_RANK = int (os .getenv ('LOCAL_RANK' , - 1 )) # https://pytorch.org/docs/stable/elastic/run.html
2724RANK = int (os .getenv ('RANK' , - 1 ))
2825WORLD_SIZE = int (os .getenv ('WORLD_SIZE' , 1 ))
@@ -58,9 +55,7 @@ def run(ymir_cfg: edict, ymir_yolov5: YmirYolov5):
5855 pin_memory = ymir_yolov5 .pin_memory ,
5956 drop_last = False )
6057
61- results = []
6258 mining_results = dict ()
63- beta = 1.3
6459 dataset_size = len (images_rank )
6560 pbar = tqdm (origin_dataset_loader ) if RANK == 0 else origin_dataset_loader
6661 for idx , batch in enumerate (pbar ):
@@ -73,18 +68,16 @@ def run(ymir_cfg: edict, ymir_yolov5: YmirYolov5):
7368
7469 if RANK in [- 1 , 0 ]:
7570 ymir_yolov5 .write_monitor_logger (stage = YmirStage .TASK , p = idx * batch_size_per_gpu / dataset_size )
76- preprocess_image_shape = batch ['image' ].shape [2 :]
7771 for inner_idx , det in enumerate (pred ): # per image
78- result_per_image = []
7972 image_file = batch ['image_file' ][inner_idx ]
8073 if len (det ):
8174 conf = det [:, 4 ].data .cpu ().numpy ()
82- mining_results [image_file ] = - np .sum (conf * np .log2 (conf ))
75+ mining_results [image_file ] = - np .sum (conf * np .log2 (conf ))
8376 else :
8477 mining_results [image_file ] = - 10
8578 continue
8679
87- torch .save (mining_results , f'/out/mining_results_{ RANK } .pt' )
80+ torch .save (mining_results , f'/out/mining_results_{ max ( 0 , RANK ) } .pt' )
8881
8982
9083def main () -> int :
@@ -99,7 +92,7 @@ def main() -> int:
9992 run (ymir_cfg , ymir_yolov5 )
10093
10194 # wait all process to save the mining result
102- if LOCAL_RANK != - 1 :
95+ if WORLD_SIZE > 1 :
10396 dist .barrier ()
10497
10598 if RANK in [0 , - 1 ]:
@@ -112,10 +105,6 @@ def main() -> int:
112105 for img_file , score in result .items ():
113106 ymir_mining_result .append ((img_file , score ))
114107 rw .write_mining_result (mining_result = ymir_mining_result )
115-
116- if LOCAL_RANK != - 1 :
117- print (f'rank: { RANK } , start destroy process group' )
118- # dist.destroy_process_group()
119108 return 0
120109
121110
0 commit comments