Skip to content

Commit 46a2439

Browse files
authored
Merge pull request #11 from modelai/nanfei
add entropy,random for yolov5
2 parents b294d3b + 59420c2 commit 46a2439

3 files changed

Lines changed: 287 additions & 0 deletions

File tree

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""
2+
Consistency-based Active Learning for Object Detection CVPR 2022 workshop
3+
official code: https://github.com/we1pingyu/CALD/blob/master/cald_train.py
4+
"""
5+
import sys
6+
from typing import Dict, List, Tuple
7+
8+
import cv2
9+
import numpy as np
10+
from easydict import EasyDict as edict
11+
from mining.data_augment import cutout, horizontal_flip, intersect, resize, rotate
12+
from nptyping import NDArray
13+
from scipy.stats import entropy
14+
from tqdm import tqdm
15+
from utils.ymir_yolov5 import BBOX, CV_IMAGE, YmirYolov5
16+
from ymir_exc import dataset_reader as dr
17+
from ymir_exc import env, monitor
18+
from ymir_exc import result_writer as rw
19+
from ymir_exc.util import YmirStage, get_merged_config, get_ymir_process
20+
21+
def split_result(result: NDArray) -> Tuple[BBOX, NDArray, NDArray]:
    """Split a detection array into boxes, confidences and class ids.

    Args:
        result: 2-D array with one row per detection. Columns 0-3 are read as
            the box, column 4 as the confidence and column 5 as the class id
            (the caller labels this layout ``xyxy,conf,cls``).

    Returns:
        ``(bboxes, conf, class_id)`` where ``bboxes`` is an ``(n, 4)`` int32
        array and ``conf`` / ``class_id`` are 1-D arrays of length ``n``.
        An empty ``result`` yields empty arrays of the same rank, so callers
        can handle both cases uniformly.
    """
    if len(result) == 0:
        # Keep conf/class_id 1-D so they match the column slices returned below.
        return (
            np.zeros(shape=(0, 4), dtype=np.int32),
            np.zeros(shape=(0,), dtype=np.float32),
            np.zeros(shape=(0,), dtype=np.int32),
        )

    bboxes = result[:, :4].astype(np.int32)
    conf = result[:, 4]
    class_id = result[:, 5]
    return bboxes, conf, class_id
32+
33+
class MiningEntropy(YmirYolov5):
    """Mining task that scores each candidate image from its detection confidences.

    The score of an image is ``-sum(conf * log2(conf))`` over all detections
    returned by ``self.predict(img, nms=False)``; images without detections
    get a fixed score of ``-10``.
    """

    def __init__(self, cfg: edict):
        """Initialise the model wrapper and the task bookkeeping used for progress reporting.

        Args:
            cfg: merged ymir configuration; ``cfg.ymir.run_mining`` and
                ``cfg.ymir.run_infer`` decide how many tasks share the progress range.
        """
        super().__init__(cfg)

        # With multiple tasks mining runs first and infer later, so mining is always task 0.
        self.task_idx = 0
        self.task_num = 2 if cfg.ymir.run_mining and cfg.ymir.run_infer else 1

    @staticmethod
    def _entropy_score(conf) -> float:
        """Return ``-sum(c * log2(c))`` over the strictly positive confidences in ``conf``."""
        # 0 * log2(0) evaluates to NaN in numpy; a zero confidence contributes nothing, so drop it.
        positive = conf[conf > 0]
        return float(-np.sum(positive * np.log2(positive)))

    def mining(self) -> List:
        """Score every candidate image.

        Returns:
            A list of ``(asset_path, score)`` tuples, one per candidate image,
            in the order produced by ``dr.item_paths``.
        """
        N = dr.items_count(env.DatasetType.CANDIDATE)
        monitor_gap = max(1, N // 1000)
        mining_result = []
        candidates = dr.item_paths(dataset_type=env.DatasetType.CANDIDATE)
        # enumerate() advances the counter for every image, including those without detections,
        # so the reported progress cannot stall on a run of empty images.
        for idx, (asset_path, _) in enumerate(tqdm(candidates)):
            img = cv2.imread(asset_path)
            # xyxy,conf,cls
            result = self.predict(img, nms=False)
            if len(result) == 0:
                # No detections: fixed score of -10. Presumably this ranks the image below any
                # scored one (entropy terms are >= 0 when confidences lie in (0, 1]) - TODO confirm.
                mining_result.append((asset_path, -10))
            else:
                _, conf, _ = split_result(result)
                mining_result.append((asset_path, self._entropy_score(conf)))

            if idx % monitor_gap == 0:
                percent = get_ymir_process(stage=YmirStage.TASK,
                                           p=idx / max(1, N),
                                           task_idx=self.task_idx,
                                           task_num=self.task_num)
                monitor.write_monitor_logger(percent=percent)

        return mining_result
71+
72+
def main():
    """Run entropy mining over the candidate dataset and write the scores.

    Returns:
        0, used as the process exit code.
    """
    miner = MiningEntropy(get_merged_config())
    rw.write_mining_result(mining_result=miner.mining())
    return 0


if __name__ == "__main__":
    sys.exit(main())
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
"""use fake DDP to infer
2+
1. split data with `images_rank = images[RANK::WORLD_SIZE]`
3+
2. infer on the origin dataset
4+
3. infer on the augmentation dataset
5+
4. save the per-rank mining result with `torch.save(mining_results, f'/out/mining_results_{RANK}.pt')`
6+
5. merge mining result
7+
"""
8+
import os
9+
import sys
10+
from functools import partial
11+
12+
import numpy as np
13+
import torch
14+
import torch.distributed as dist
15+
import torch.utils.data as td
16+
from easydict import EasyDict as edict
17+
from tqdm import tqdm
18+
from ymir_exc import result_writer as rw
19+
from ymir_exc.util import YmirStage, get_merged_config
20+
21+
from mining.util import (YmirDataset, collate_fn_with_fake_ann, load_image_file, load_image_file_with_ann,
22+
update_consistency)
23+
from utils.general import scale_coords
24+
from utils.ymir_yolov5 import YmirYolov5
25+
26+
# Launch settings read from the environment, see https://pytorch.org/docs/stable/elastic/run.html
# LOCAL_RANK and RANK fall back to -1 and WORLD_SIZE to 1 when the variable is not set.
LOCAL_RANK = int(os.environ.get('LOCAL_RANK', -1))
RANK = int(os.environ.get('RANK', -1))
WORLD_SIZE = int(os.environ.get('WORLD_SIZE', 1))
29+
30+
31+
def run(ymir_cfg: edict, ymir_yolov5: YmirYolov5):
    """Score this process's share of the candidate images and save the scores to disk.

    The candidate list is read from ``ymir_cfg.ymir.input.candidate_index_file``
    and every ``WORLD_SIZE``-th image, starting at this process's rank, is scored
    with ``-sum(conf * log2(conf))`` over the detection confidences (``-10`` when
    there are no detections). The resulting ``{image_file: score}`` dict is written
    to ``/out/mining_results_<rank>.pt``.

    Args:
        ymir_cfg: merged ymir configuration.
        ymir_yolov5: model wrapper used for the forward pass and progress logging.
    """
    # eg: gpu_id = 1,3,5,7 for LOCAL_RANK = 2, will use gpu 5.
    gpu = LOCAL_RANK if LOCAL_RANK >= 0 else 0
    device = torch.device('cuda', gpu)
    ymir_yolov5.to(device)

    load_fn = partial(load_image_file, img_size=ymir_yolov5.img_size, stride=ymir_yolov5.stride)
    batch_size_per_gpu: int = ymir_yolov5.batch_size_per_gpu
    gpu_count: int = ymir_yolov5.gpu_count
    cpu_count: int = os.cpu_count() or 1
    num_workers_per_gpu = min([
        cpu_count // max(gpu_count, 1), batch_size_per_gpu if batch_size_per_gpu > 1 else 0,
        ymir_yolov5.num_workers_per_gpu
    ])

    with open(ymir_cfg.ymir.input.candidate_index_file, 'r') as f:
        images = [line.strip() for line in f]

    max_barrier_times = (len(images) // max(1, WORLD_SIZE)) // batch_size_per_gpu

    # RANK is -1 when no distributed launcher is used. Slicing with -1 would keep only the
    # last image and the file name would not match what main() loads, so treat it as rank 0.
    rank = max(0, RANK)
    images_rank = images[rank::WORLD_SIZE]
    origin_dataset = YmirDataset(images_rank, load_fn=load_fn)
    origin_dataset_loader = td.DataLoader(origin_dataset,
                                          batch_size=batch_size_per_gpu,
                                          shuffle=False,
                                          sampler=None,
                                          num_workers=num_workers_per_gpu,
                                          pin_memory=ymir_yolov5.pin_memory,
                                          drop_last=False)

    mining_results = dict()
    dataset_size = len(images_rank)
    # Only the main process (rank 0, or -1 without a launcher) shows a progress bar.
    pbar = tqdm(origin_dataset_loader) if RANK in [-1, 0] else origin_dataset_loader
    for idx, batch in enumerate(pbar):
        # batch-level sync, avoid 30min time-out error
        if LOCAL_RANK != -1 and idx < max_barrier_times:
            dist.barrier()

        with torch.no_grad():
            pred = ymir_yolov5.forward(batch['image'].float().to(device), nms=False)

        if RANK in [-1, 0]:
            ymir_yolov5.write_monitor_logger(stage=YmirStage.TASK, p=idx * batch_size_per_gpu / dataset_size)

        for inner_idx, det in enumerate(pred):  # per image
            image_file = batch['image_file'][inner_idx]
            if len(det):
                conf = det[:, 4].data.cpu().numpy()
                # 0 * log2(0) is NaN in numpy; a zero confidence contributes nothing, so drop it.
                conf = conf[conf > 0]
                # Store a plain float so the saved dict holds only builtin types.
                mining_results[image_file] = float(-np.sum(conf * np.log2(conf)))
            else:
                # No detections for this image: fixed score.
                mining_results[image_file] = -10

    torch.save(mining_results, f'/out/mining_results_{rank}.pt')
88+
89+
90+
def main() -> int:
    """Run mining on every process and let the main process merge and write the result.

    Each process calls ``run`` and saves its own result file; the process with
    RANK 0 (or -1) then loads one file per rank in ``range(WORLD_SIZE)`` and
    passes the combined ``(image_file, score)`` list to ``rw.write_mining_result``.

    Returns:
        0, used as the process exit code.
    """
    ymir_cfg = get_merged_config()
    ymir_yolov5 = YmirYolov5(ymir_cfg, task='mining')
    distributed = LOCAL_RANK != -1

    if distributed:
        assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
        torch.cuda.set_device(LOCAL_RANK)
        backend = "nccl" if dist.is_nccl_available() else "gloo"
        dist.init_process_group(backend=backend)

    run(ymir_cfg, ymir_yolov5)

    if distributed:
        # Wait until every process has saved its mining result.
        dist.barrier()

    if RANK in [0, -1]:
        per_rank_results = [torch.load(f'/out/mining_results_{rank}.pt') for rank in range(WORLD_SIZE)]
        ymir_mining_result = [
            (img_file, score) for result in per_rank_results for img_file, score in result.items()
        ]
        rw.write_mining_result(mining_result=ymir_mining_result)

    if distributed:
        print(f'rank: {RANK}, start destroy process group')
        # NOTE: dist.destroy_process_group() is not called here.
    return 0


if __name__ == '__main__':
    sys.exit(main())
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""use fake DDP to infer
2+
1. split data with `images_rank = images[RANK::WORLD_SIZE]`
3+
2. infer on the origin dataset
4+
3. infer on the augmentation dataset
5+
4. save the per-rank mining result with `torch.save(mining_results, f'/out/mining_results_{RANK}.pt')`
6+
5. merge mining result
7+
"""
8+
import os
9+
import random
10+
import sys
11+
from functools import partial
12+
13+
import numpy as np
14+
import torch
15+
import torch.distributed as dist
16+
import torch.utils.data as td
17+
from easydict import EasyDict as edict
18+
from tqdm import tqdm
19+
from ymir_exc import result_writer as rw
20+
from ymir_exc.util import YmirStage, get_merged_config
21+
22+
from mining.util import (YmirDataset, collate_fn_with_fake_ann, load_image_file, load_image_file_with_ann,
23+
update_consistency)
24+
from utils.general import scale_coords
25+
from utils.ymir_yolov5 import YmirYolov5
26+
27+
# Launch settings read from the environment, see https://pytorch.org/docs/stable/elastic/run.html
# LOCAL_RANK and RANK fall back to -1 and WORLD_SIZE to 1 when the variable is not set.
LOCAL_RANK = int(os.environ.get('LOCAL_RANK', -1))
RANK = int(os.environ.get('RANK', -1))
WORLD_SIZE = int(os.environ.get('WORLD_SIZE', 1))
30+
31+
32+
def run(ymir_cfg: edict, ymir_yolov5: YmirYolov5):
    """Assign a random score to this process's share of the candidate images and save it.

    The candidate list is read from ``ymir_cfg.ymir.input.candidate_index_file``;
    every ``WORLD_SIZE``-th image, starting at this process's rank, gets a score
    from ``random.random()``. The ``{image_file: score}`` dict is written to
    ``/out/mining_results_<rank>.pt``.

    Args:
        ymir_cfg: merged ymir configuration.
        ymir_yolov5: model wrapper; it is moved to the selected GPU but not used for scoring.
    """
    # eg: gpu_id = 1,3,5,7 for LOCAL_RANK = 2, will use gpu 5.
    gpu = LOCAL_RANK if LOCAL_RANK >= 0 else 0
    device = torch.device('cuda', gpu)
    ymir_yolov5.to(device)

    with open(ymir_cfg.ymir.input.candidate_index_file, 'r') as f:
        images = [line.strip() for line in f]

    # RANK is -1 when no distributed launcher is used. Slicing with -1 would keep only the
    # last image and the file name would not match what main() loads, so treat it as rank 0.
    rank = max(0, RANK)
    mining_results = {image: random.random() for image in images[rank::WORLD_SIZE]}

    torch.save(mining_results, f'/out/mining_results_{rank}.pt')
47+
48+
49+
def main() -> int:
    """Run random mining on every process and let the main process merge and write the result.

    Each process calls ``run`` and saves its own result file; the process with
    RANK 0 (or -1) then loads one file per rank in ``range(WORLD_SIZE)`` and
    passes the combined ``(image_file, score)`` list to ``rw.write_mining_result``.

    Returns:
        0, used as the process exit code.
    """
    ymir_cfg = get_merged_config()
    ymir_yolov5 = YmirYolov5(ymir_cfg, task='mining')
    distributed = LOCAL_RANK != -1

    if distributed:
        assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
        torch.cuda.set_device(LOCAL_RANK)
        backend = "nccl" if dist.is_nccl_available() else "gloo"
        dist.init_process_group(backend=backend)

    run(ymir_cfg, ymir_yolov5)

    if distributed:
        # Wait until every process has saved its mining result.
        dist.barrier()

    if RANK in [0, -1]:
        per_rank_results = [torch.load(f'/out/mining_results_{rank}.pt') for rank in range(WORLD_SIZE)]
        ymir_mining_result = [
            (img_file, score) for result in per_rank_results for img_file, score in result.items()
        ]
        rw.write_mining_result(mining_result=ymir_mining_result)

    if distributed:
        print(f'rank: {RANK}, start destroy process group')
        # NOTE: dist.destroy_process_group() is not called here.
    return 0


if __name__ == '__main__':
    sys.exit(main())

0 commit comments

Comments
 (0)