@@ -562,13 +562,34 @@ def __init__(self, subset: str, data_root: str = "data",
562562 image_size : int = 96 ,
563563 take_num : Optional [int ] = None ,
564564 metadata : Optional [List [Metadata ]] = None ,
565+ pred_mode : bool = False
565566 ):
566567 self .subset = subset
567568 self .data_root = data_root
568569 self .image_size = image_size
570+ self .pred_mode = pred_mode
571+
569572 if metadata is None :
570- metadata_json = read_json (os .path .join (self .data_root , f"{ subset } _metadata.json" ))
571- self .metadata = [Metadata (** meta , fps = 25 ) for meta in metadata_json ]
573+ if self .pred_mode :
574+ with open (os .path .join (self .data_root , f"{ subset } _files.txt" ), "r" ) as f :
575+ files = [line .strip () for line in f .readlines () if line .strip () != "" ]
576+ self .metadata = [ # dummy metadata for prediction
577+ Metadata (file = file_name ,
578+ original = None ,
579+ split = subset ,
580+ fake_segments = [],
581+ fps = 25 ,
582+ visual_fake_segments = [],
583+ audio_fake_segments = [],
584+ audio_model = "" ,
585+ modify_type = "" ,
586+ video_frames = - 1 ,
587+ audio_frames = - 1 )
588+ for file_name in files
589+ ]
590+ else :
591+ metadata_json = read_json (os .path .join (self .data_root , f"{ subset } _metadata.json" ))
592+ self .metadata = [Metadata (** meta , fps = 25 ) for meta in metadata_json ]
572593 else :
573594 self .metadata = metadata
574595
@@ -584,5 +605,5 @@ def __getitem__(self, index):
584605 video , audio , _ = read_video (os .path .join (self .data_root , self .subset , meta .file ))
585606 if self .image_size != 224 :
586607 video = resize_video (video , (self .image_size , self .image_size ))
587- label = len (meta .fake_periods ) > 0
588- return video , audio , label
608+ label = len (meta .fake_periods ) > 0 if not self . pred_mode else False
609+ return video , audio , label
0 commit comments