Skip to content

Commit 5adee64

Browse files
update target absent code and clean
1 parent 4b7d982 commit 5adee64

2 files changed

Lines changed: 3 additions & 20 deletions

File tree

src/data/components/cocosearch18/cocosearch18_ta.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,8 @@ def __init__(self, name: str, root_path: str, task: str, split: str, num_subject
1818
if split == 'valid':
1919
split = 'validation'
2020

21-
if split in ['train', 'validation']:
22-
with open(Path(self.root_path, f'coco_search18_fixations_TA_trainval.json'), 'rb') as f: # here each scanpath starts from the second fixation. The first is discarded
23-
self.samples = json.load(f)
24-
25-
# only consider the right split
26-
if split == 'train':
27-
self.samples = [s for s in self.samples if s['split'] == 'train']
28-
else:
29-
self.samples = [s for s in self.samples if s['split'] == 'validation']
30-
31-
else: # test split
32-
with open(Path(self.root_path, f'coco_search18_fixations_TA_test.json'), 'rb') as f: # here each scanpath starts from the second fixation. The first is discarded
33-
self.samples = json.load(f)
21+
with open(Path(self.root_path, f'coco_search18_fixations_TA_{split}.json'), 'rb') as f:
22+
self.samples = json.load(f)
3423

3524
self.task_embeddings = np.load(
3625
open(
@@ -95,12 +84,8 @@ def __getitem__(self, index: int):
9584
scanpath[:,0] /= 1680
9685
scanpath[:,1] /= 1050
9786

98-
#durations = np.hstack((sample['arrival_times'], sample['t_end']))[1:] - sample['arrival_times']
99-
#durations = np.reshape(durations, (-1, 1))
100-
#scanpath = np.hstack((coords, durations))
10187
scanpath = torch.from_numpy(scanpath).float() # gt duration is in seconds
10288

103-
10489
if not self.time_in_ms: # by default time is in milliseconds in the annotations
10590
scanpath[:,2] /= 1000.0
10691

src/data/components/cocosearch18/cocosearch18_tp.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,7 @@ def __getitem__(self, index: int):
8585
scanpath[:,0] /= 512
8686
scanpath[:,1] /= 320
8787

88-
#durations = np.hstack((sample['arrival_times'], sample['t_end']))[1:] - sample['arrival_times']
89-
#durations = np.reshape(durations, (-1, 1))
90-
#scanpath = np.hstack((coords, durations))
88+
9189
scanpath = torch.from_numpy(scanpath).float() # gt duration is in seconds
9290

9391
if not self.time_in_ms: # by default time is in milliseconds in the annotations

0 commit comments

Comments
 (0)