@@ -31,8 +31,11 @@ def __init__(self, cfg: Config, model):
3131 self .contrastive_loss = NT_Xent
3232 self .cfg = cfg
3333 self .metric = meanBoxCoverScore ()
34- self .target_source = torch .load (
35- f"{ cfg .dataset .path } /target/from_{ cfg .task .data .target_source } .pt" , weights_only = False
34+ self .target_source_train = torch .load (
35+ f"{ cfg .dataset .path } /target/from_{ cfg .task .data .target_source } _train.pt" , weights_only = False
36+ )
37+ self .target_source_val = torch .load (
38+ f"{ cfg .dataset .path } /target/from_{ cfg .task .data .target_source } _val.pt" , weights_only = False
3639 )
3740
3841 def set_task (self , task ):
@@ -45,7 +48,8 @@ def setup(self, stage):
4548 )
4649 self .loss_fn = AAAILoss (self .cfg .task .loss , self .vec2box )
4750 self .post_process = PostProcess (self .vec2box , self .cfg .task .validation .nms , aaai = True )
48- self .target_source = self .target_source .to (self .device )
51+ self .target_source_train = self .target_source_train .to (self .device )
52+ self .target_source_val = self .target_source_val .to (self .device )
4953
5054 def forward (self , x , external = None , shortcut = None ):
5155 return self .model (x , external , shortcut )
@@ -67,7 +71,7 @@ def training_step(self, batch, batch_idx):
6771
6872 if self .task == "detect" :
6973 image_idx , pick_idx = idx_batch
70- picked_vector = self .target_source [image_idx [:, None ], pick_idx ]
74+ picked_vector = self .target_source_train [image_idx [:, None ], pick_idx ]
7175
7276 origin_outputs = self (images , dict (target = picked_vector .permute (0 , 2 , 1 )))
7377 detections = self .vec2box (origin_outputs ["Main" ])
@@ -129,7 +133,7 @@ def validation_step(self, batch, batch_idx):
129133
130134 puzzle_images , origin_idx , puzzle_idx , puzzles = puzzles_batch
131135 image_idx , pick_idx = idx_batch
132- picked_vector = self .target_source [image_idx [:, None ], pick_idx ]
136+ picked_vector = self .target_source_val [image_idx [:, None ], pick_idx ]
133137
134138 origin_outputs = self (images , dict (target = picked_vector .permute (0 , 2 , 1 )))
135139 H , W = images .shape [2 :]
0 commit comments