Skip to content

Commit 76863ca

Browse files
committed
Implement Xception inference for AV-Deepfake1M++
1 parent 40d14f9 commit 76863ca

2 files changed

Lines changed: 30 additions & 8 deletions

File tree

examples/xception/infer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
parser.add_argument("--checkpoint", type=str)
1313
parser.add_argument("--model", type=str)
1414
parser.add_argument("--batch_size", type=int, default=128)
15-
parser.add_argument("--subset", type=str, choices=["train", "val", "test"])
15+
parser.add_argument("--subset", type=str, choices=["train", "val", "test", "testA", "testB"])
1616
parser.add_argument("--gpus", type=int, default=1)
1717
parser.add_argument("--take_num", type=int, default=None)
1818

@@ -27,14 +27,14 @@
2727
raise ValueError(f"Unknown model: {args.model}")
2828

2929
model.to(device)
30-
model.train()
31-
test_dataset = AVDeepfake1mPlusPlusVideo(args.subset, args.data_root, take_num=args.take_num)
30+
model.train() # not sure why but eval mode will generate nonsense output
31+
test_dataset = AVDeepfake1mPlusPlusVideo(args.subset, args.data_root, take_num=args.take_num, pred_mode=True)
3232

3333
save_path = f"output/{args.model}_{args.subset}.txt"
3434
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
3535
with open(save_path, "w") as f:
3636
with torch.inference_mode():
37-
for i, (video, _, label) in enumerate(tqdm(test_dataset)):
37+
for i, (video, _, _) in enumerate(tqdm(test_dataset)):
3838
# batch video as frames use batch_size
3939
preds_video = []
4040
for j in range(0, len(video), args.batch_size):
@@ -47,3 +47,4 @@
4747

4848
file_name = test_dataset.metadata[i].file
4949
f.write(f"{file_name};{pred}\n")
50+
f.flush()

python/avdeepfake1m/loader.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -562,13 +562,34 @@ def __init__(self, subset: str, data_root: str = "data",
562562
image_size: int = 96,
563563
take_num: Optional[int] = None,
564564
metadata: Optional[List[Metadata]] = None,
565+
pred_mode: bool = False
565566
):
566567
self.subset = subset
567568
self.data_root = data_root
568569
self.image_size = image_size
570+
self.pred_mode = pred_mode
571+
569572
if metadata is None:
570-
metadata_json = read_json(os.path.join(self.data_root, f"{subset}_metadata.json"))
571-
self.metadata = [Metadata(**meta, fps=25) for meta in metadata_json]
573+
if self.pred_mode:
574+
with open(os.path.join(self.data_root, f"{subset}_files.txt"), "r") as f:
575+
files = [line.strip() for line in f.readlines() if line.strip() != ""]
576+
self.metadata = [ # dummy metadata for prediction
577+
Metadata(file=file_name,
578+
original=None,
579+
split=subset,
580+
fake_segments=[],
581+
fps=25,
582+
visual_fake_segments=[],
583+
audio_fake_segments=[],
584+
audio_model="",
585+
modify_type="",
586+
video_frames=-1,
587+
audio_frames=-1)
588+
for file_name in files
589+
]
590+
else:
591+
metadata_json = read_json(os.path.join(self.data_root, f"{subset}_metadata.json"))
592+
self.metadata = [Metadata(**meta, fps=25) for meta in metadata_json]
572593
else:
573594
self.metadata = metadata
574595

@@ -584,5 +605,5 @@ def __getitem__(self, index):
584605
video, audio, _ = read_video(os.path.join(self.data_root, self.subset, meta.file))
585606
if self.image_size != 224:
586607
video = resize_video(video, (self.image_size, self.image_size))
587-
label = len(meta.fake_periods) > 0
588-
return video, audio, label
608+
label = len(meta.fake_periods) > 0 if not self.pred_mode else False
609+
return video, audio, label

0 commit comments

Comments (0)