Skip to content

Commit 1f5ea93

Browse files
committed
🚑️ [AAAI|Fix] Validation preprocess emb source
1 parent 0bb3080 commit 1f5ea93

1 file changed

Lines changed: 9 additions & 5 deletions

File tree

yolo/aaai.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,11 @@ def __init__(self, cfg: Config, model):
3131
self.contrastive_loss = NT_Xent
3232
self.cfg = cfg
3333
self.metric = meanBoxCoverScore()
34-
self.target_source = torch.load(
35-
f"{cfg.dataset.path}/target/from_{cfg.task.data.target_source}.pt", weights_only=False
34+
self.target_source_train = torch.load(
35+
f"{cfg.dataset.path}/target/from_{cfg.task.data.target_source}_train.pt", weights_only=False
36+
)
37+
self.target_source_val = torch.load(
38+
f"{cfg.dataset.path}/target/from_{cfg.task.data.target_source}_val.pt", weights_only=False
3639
)
3740

3841
def set_task(self, task):
@@ -45,7 +48,8 @@ def setup(self, stage):
4548
)
4649
self.loss_fn = AAAILoss(self.cfg.task.loss, self.vec2box)
4750
self.post_process = PostProcess(self.vec2box, self.cfg.task.validation.nms, aaai=True)
48-
self.target_source = self.target_source.to(self.device)
51+
self.target_source_train = self.target_source_train.to(self.device)
52+
self.target_source_val = self.target_source_val.to(self.device)
4953

5054
def forward(self, x, external=None, shortcut=None):
5155
return self.model(x, external, shortcut)
@@ -67,7 +71,7 @@ def training_step(self, batch, batch_idx):
6771

6872
if self.task == "detect":
6973
image_idx, pick_idx = idx_batch
70-
picked_vector = self.target_source[image_idx[:, None], pick_idx]
74+
picked_vector = self.target_source_train[image_idx[:, None], pick_idx]
7175

7276
origin_outputs = self(images, dict(target=picked_vector.permute(0, 2, 1)))
7377
detections = self.vec2box(origin_outputs["Main"])
@@ -129,7 +133,7 @@ def validation_step(self, batch, batch_idx):
129133

130134
puzzle_images, origin_idx, puzzle_idx, puzzles = puzzles_batch
131135
image_idx, pick_idx = idx_batch
132-
picked_vector = self.target_source[image_idx[:, None], pick_idx]
136+
picked_vector = self.target_source_val[image_idx[:, None], pick_idx]
133137

134138
origin_outputs = self(images, dict(target=picked_vector.permute(0, 2, 1)))
135139
H, W = images.shape[2:]

0 commit comments

Comments
 (0)