@@ -105,6 +105,7 @@ def __getitem__(self, index: int) -> List[Union[Tensor, str, int]]:
105105 file = self .file_list [index ]
106106
107107 video , audio , _ = read_video (os .path .join (self .root , self .subset , file ))
108+ n_frames = video .shape [0 ]
108109 video = F .interpolate (video .float ().permute (1 , 0 , 2 , 3 )[None ], size = (self .temporal_size , 96 , 96 ))[0 ]
109110 audio = F .interpolate (audio .float ().permute (1 , 0 )[None ], size = self .audio_temporal_size , mode = "linear" )[0 ].permute (1 , 0 )
110111 video = self .video_transform (video )
@@ -113,7 +114,7 @@ def __getitem__(self, index: int) -> List[Union[Tensor, str, int]]:
113114
114115 outputs = [video , audio ]
115116
116- if self .subset != "test" :
117+ if self .subset not in ( "test" , "testA" , "testB" ) :
117118 if self .is_plusplus :
118119 subset_folder = self .subset
119120 else :
@@ -131,6 +132,9 @@ def __getitem__(self, index: int) -> List[Union[Tensor, str, int]]:
131132 if self .return_file_name :
132133 outputs .append (meta .file )
133134
135+ else :
136+ outputs = outputs + [n_frames ]
137+
134138 return outputs
135139
136140 def get_label (self , file : str , meta : Metadata ) -> tuple [Tensor , Optional [Tensor ], Optional [Tensor ]]:
@@ -244,6 +248,7 @@ def __init__(self, root: str = "data", temporal_size: int = 100,
244248 get_meta_attr : Callable [[Metadata , Tensor , Tensor , Tensor ], List [Any ]] = _default_get_meta_attr ,
245249 return_file_name : bool = False ,
246250 is_plusplus : bool = False ,
251+ test_subset : Optional [str ] = None
247252 ):
248253 super ().__init__ ()
249254 self .root = root
@@ -260,11 +265,15 @@ def __init__(self, root: str = "data", temporal_size: int = 100,
260265 self .return_file_name = return_file_name
261266 self .is_plusplus = is_plusplus
262267 self .Dataset = AVDeepfake1m
268+ if test_subset is None :
269+ self .test_subset = "test" if not self .is_plusplus else "testA"
270+ else :
271+ self .test_subset = test_subset
263272
264273 def setup (self , stage : Optional [str ] = None ) -> None :
265274 train_file_list = [meta ["file" ] for meta in read_json (os .path .join (self .root , "train_metadata.json" ))]
266275 val_file_list = [meta ["file" ] for meta in read_json (os .path .join (self .root , "val_metadata.json" ))]
267- with open (os .path .join (self .root , "test_files .txt" ), "r" ) as f :
276+ with open (os .path .join (self .root , f" { self . test_subset } _files .txt" ), "r" ) as f :
268277 test_file_list = list (filter (lambda x : x != "" , f .read ().split ("\n " )))
269278
270279 if self .take_val is not None :
@@ -285,7 +294,7 @@ def setup(self, stage: Optional[str] = None) -> None:
285294 return_file_name = self .return_file_name ,
286295 is_plusplus = self .is_plusplus
287296 )
288- self .test_dataset = self .Dataset ("test" , self .root , self .temporal_size , self .max_duration , self .fps ,
297+ self .test_dataset = self .Dataset (self . test_subset , self .root , self .temporal_size , self .max_duration , self .fps ,
289298 file_list = test_file_list , get_meta_attr = self .get_meta_attr ,
290299 require_match_scores = self .require_match_scores ,
291300 return_file_name = self .return_file_name ,
0 commit comments