Skip to content

Commit a72db1e

Browse files
nathalie000henrytsui000
authored andcommitted
✨ [AAAI|Add] meanBoxCoverScore validation metric
1 parent 899b78c commit a72db1e

2 files changed

Lines changed: 43 additions & 2 deletions

File tree

yolo/aaai.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from yolo import Config, PostProcess, create_converter, create_model, draw_bboxes
2020
from yolo.tools.data_loader import AAAIDataset
21-
from yolo.tools.loss_functions import AAAILoss, NT_Xent
21+
from yolo.tools.loss_functions import AAAILoss, NT_Xent, meanBoxCoverScore
2222
from yolo.utils.logging_utils import YOLORichModelSummary, setup
2323
from yolo.utils.model_utils import create_optimizer, create_scheduler
2424

@@ -30,6 +30,7 @@ def __init__(self, cfg: Config, model):
3030
self.construct_loss = nn.MSELoss()
3131
self.contrastive_loss = NT_Xent
3232
self.cfg = cfg
33+
self.metric = meanBoxCoverScore()
3334

3435
def set_task(self, task):
3536
self.task = task
@@ -140,7 +141,21 @@ def validation_step(self, batch, batch_idx):
140141

141142
origin_outputs = self(images, dict(target=picked_vector.permute(0, 2, 1)))
142143
H, W = images.shape[2:]
143-
predicts = self.post_process(origin_outputs, image_size=[W, H], target=picked_vector)
144+
145+
unpacked_outputs = []
146+
num_target = self.cfg.task.data.num_target
147+
repeat_num = [(num_target, 1, 1, 1), (num_target, 1, 1, 1, 1)]
148+
for batch_idx in range(batch_size):
149+
batch_sample = [
150+
[x[[batch_idx]].repeat(*repeat_num[idx % 2]) for idx, x in enumerate(res_list)]
151+
for res_list in origin_outputs["Main"]
152+
]
153+
unpacked_outputs.append(batch_sample)
154+
155+
for outputs, tar_bbox, tar_vec in zip(unpacked_outputs, bbox, picked_vector):
156+
outputs = {"Main": outputs}
157+
predicts = self.post_process(outputs, image_size=[W, H], target=tar_vec[:, None])
158+
self.metric(predicts, tar_bbox)
144159

145160
return predicts
146161

@@ -165,6 +180,10 @@ def on_validation_epoch_end(self):
165180
logger = self.logger.experiment
166181
logger.add_histogram("Validation Error Distribution", Tensor(self.FIDs), self.current_epoch)
167182

183+
epoch_metric = self.metric.compute()
184+
self.log("detection_metric", epoch_metric, prog_bar=True, sync_dist=True, rank_zero_only=True)
185+
self.metric.reset()
186+
168187
def configure_optimizers(self):
169188
optimizer = create_optimizer(self.model, self.cfg.task.optimizer)
170189
scheduler = create_scheduler(optimizer, self.cfg.task.scheduler)

yolo/tools/loss_functions.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from torch import Tensor, nn
77
from torch.nn import BCEWithLogitsLoss, MSELoss
88
from torch.nn.functional import cosine_similarity, cross_entropy
9+
from torchmetrics import MeanMetric
910

1011
from yolo.config.config import Config, LossConfig, MatcherConfig
1112
from yolo.utils.bounding_box_utils import BoxMatcher, Vec2Box, calculate_iou
@@ -354,3 +355,24 @@ def create_loss_function(cfg: Config, vec2box) -> DualLoss:
354355
loss_function = DualLoss(cfg, vec2box)
355356
logger.info(":white_check_mark: Success load loss function")
356357
return loss_function
358+
359+
360+
class meanBoxCoverScore(MeanMetric):
361+
def __init__(self, hit_iou=0.7, **kwargs):
362+
super().__init__(**kwargs)
363+
self.hit_iou = hit_iou
364+
365+
def update(self, predicts: List[Tensor], bboxes: Tensor) -> None:
366+
for idx, predict in enumerate(predicts):
367+
if predict.size(0) == 0:
368+
super().update(0)
369+
continue
370+
predict = predict[predict[:, 5].argsort(descending=True)]
371+
iou_matrix = calculate_iou(predict[:, 1:5], bboxes[idx])
372+
best_hit = (iou_matrix > self.hit_iou).float()
373+
rank_score = best_hit.argmax().where(best_hit.any(), float("inf"))
374+
score = (1 / (rank_score + 1)).mean()
375+
super().update(score)
376+
377+
def compute(self) -> Tensor:
378+
return super().compute()

0 commit comments

Comments
 (0)