Skip to content

Commit f32e66a

Browse files
committed
update for ddp infer support
1 parent 3464c51 commit f32e66a

1 file changed

Lines changed: 8 additions & 5 deletions

File tree

det-mmdetection-tmi/ymir_infer.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def main():
117117
tbar = tqdm(images_rank)
118118
else:
119119
tbar = images_rank
120-
infer_result = dict()
120+
infer_result_list = []
121121
model = YmirModel(cfg)
122122

123123
# write infer result
@@ -132,16 +132,19 @@ def main():
132132
if WORLD_SIZE > 1 and idx < max_barrier_times:
133133
dist.barrier()
134134

135-
infer_result[asset_path] = [ann for ann in raw_anns if ann.score >= conf_threshold]
135+
infer_result_list.append((asset_path, [ann for ann in raw_anns if ann.score >= conf_threshold]))
136136

137137
if idx % monitor_gap == 0:
138138
write_ymir_monitor_process(cfg, task='infer', naive_stage_percent=idx / N, stage=YmirStage.TASK)
139139

140140
if WORLD_SIZE > 1:
141-
infer_result = collect_results_gpu(infer_result, len(images))
141+
dist.barrier()
142+
infer_result_list = collect_results_gpu(infer_result_list, len(images))
142143

143-
rw.write_infer_result(infer_result=infer_result)
144-
write_ymir_monitor_process(cfg, task='infer', naive_stage_percent=1.0, stage=YmirStage.POSTPROCESS)
144+
if RANK in [0, -1]:
145+
infer_result_dict = {k: v for k, v in infer_result_list}
146+
rw.write_infer_result(infer_result=infer_result_dict)
147+
write_ymir_monitor_process(cfg, task='infer', naive_stage_percent=1.0, stage=YmirStage.POSTPROCESS)
145148
return 0
146149

147150

0 commit comments

Comments
 (0)