1818
1919from yolo import Config , PostProcess , create_converter , create_model , draw_bboxes
2020from yolo .tools .data_loader import AAAIDataset
21- from yolo .tools .loss_functions import AAAILoss , NT_Xent
21+ from yolo .tools .loss_functions import AAAILoss , NT_Xent , meanBoxCoverScore
2222from yolo .utils .logging_utils import YOLORichModelSummary , setup
2323from yolo .utils .model_utils import create_optimizer , create_scheduler
2424
@@ -30,6 +30,7 @@ def __init__(self, cfg: Config, model):
3030 self .construct_loss = nn .MSELoss ()
3131 self .contrastive_loss = NT_Xent
3232 self .cfg = cfg
33+ self .metric = meanBoxCoverScore ()
3334
3435 def set_task (self , task ):
3536 self .task = task
@@ -140,7 +141,21 @@ def validation_step(self, batch, batch_idx):
140141
141142 origin_outputs = self (images , dict (target = picked_vector .permute (0 , 2 , 1 )))
142143 H , W = images .shape [2 :]
143- predicts = self .post_process (origin_outputs , image_size = [W , H ], target = picked_vector )
144+
145+ unpacked_outputs = []
146+ num_target = self .cfg .task .data .num_target
147+ repeat_num = [(num_target , 1 , 1 , 1 ), (num_target , 1 , 1 , 1 , 1 )]
148+ for batch_idx in range (batch_size ):
149+ batch_sample = [
150+ [x [[batch_idx ]].repeat (* repeat_num [idx % 2 ]) for idx , x in enumerate (res_list )]
151+ for res_list in origin_outputs ["Main" ]
152+ ]
153+ unpacked_outputs .append (batch_sample )
154+
155+ for outputs , tar_bbox , tar_vec in zip (unpacked_outputs , bbox , picked_vector ):
156+ outputs = {"Main" : outputs }
157+ predicts = self .post_process (outputs , image_size = [W , H ], target = tar_vec [:, None ])
158+ self .metric (predicts , tar_bbox )
144159
145160 return predicts
146161
@@ -165,6 +180,10 @@ def on_validation_epoch_end(self):
165180 logger = self .logger .experiment
166181 logger .add_histogram ("Validation Error Distribution" , Tensor (self .FIDs ), self .current_epoch )
167182
183+ epoch_metric = self .metric .compute ()
184+ self .log ("detection_metric" , epoch_metric , prog_bar = True , sync_dist = True , rank_zero_only = True )
185+ self .metric .reset ()
186+
168187 def configure_optimizers (self ):
169188 optimizer = create_optimizer (self .model , self .cfg .task .optimizer )
170189 scheduler = create_scheduler (optimizer , self .cfg .task .scheduler )
0 commit comments