-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathinfer.py
More file actions
128 lines (110 loc) · 4.33 KB
/
infer.py
File metadata and controls
128 lines (110 loc) · 4.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import argparse
import toml
import torch
import os
from pathlib import Path
from avdeepfake1m.loader import AVDeepfake1mDataModule, Metadata
from batfd.model import Batfd, BatfdPlus
from batfd.inference import inference_model
from batfd.post_process import post_process
from avdeepfake1m.utils import read_json
# Subsets that are served by the data module's test dataloader.
_TEST_SUBSETS = ("test", "testA", "testB")
# Frame rate passed to `Metadata` and `post_process` (fixed in this script).
_FPS = 25


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="BATFD/BATFD+ Inference")
    parser.add_argument("--config", type=str, required=True,
                        help="Path to the TOML configuration file.")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to the model checkpoint.")
    parser.add_argument("--data_root", type=str, required=True,
                        help="Root directory of the dataset.")
    parser.add_argument("--num_workers", type=int, default=8,
                        help="Number of workers for data loading.")
    parser.add_argument("--subset", type=str, choices=["val", *_TEST_SUBSETS],
                        default="test", help="Dataset subset.")
    parser.add_argument("--gpus", type=int, default=1,
                        help="Number of GPUs. Set to 0 for CPU.")
    parser.add_argument("--output_dir", type=str, default="output",
                        help="Directory for intermediate and final outputs.")
    return parser.parse_args()


def _load_model(model_type, checkpoint):
    """Load the model class selected by ``model_type`` from ``checkpoint`` and switch it to eval mode.

    Raises:
        ValueError: if ``model_type`` is neither ``"batfd_plus"`` nor ``"batfd"``.
    """
    if model_type == "batfd_plus":
        model = BatfdPlus.load_from_checkpoint(checkpoint)
    elif model_type == "batfd":
        model = Batfd.load_from_checkpoint(checkpoint)
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    model.eval()
    return model


def _load_metadata(metadata_path, subset, dataloader):
    """Read subset metadata from ``metadata_path``, or build placeholder records if the file is absent."""
    if os.path.exists(metadata_path):
        return [Metadata(**each, fps=_FPS) for each in read_json(metadata_path)]
    # No metadata file on disk: create one record per file in the dataloader's
    # dataset, with no fake segments. `file_list` is only touched on this path.
    return [
        Metadata(file=file_name,
                 original=None,
                 split=subset,
                 fake_segments=[],
                 fps=_FPS,
                 visual_fake_segments=[],
                 audio_fake_segments=[],
                 audio_model="",
                 modify_type="",
                 # frame counts are unknown here; handled by the predictor
                 # in `inference_model`
                 video_frames=-1,
                 audio_frames=-1)
        for file_name in dataloader.dataset.file_list
    ]


def main():
    """Run BATFD/BATFD+ inference on one dataset subset and post-process the results.

    Steps:
      1. Parse CLI arguments and load the TOML config (keys read: ``name``,
         ``model_type``, ``dataset``, ``num_frames``, ``max_duration``,
         ``soft_nms.alpha``/``t1``/``t2``).
      2. Load the model selected by ``config["model_type"]`` from ``--checkpoint``.
      3. Build an ``AVDeepfake1mDataModule`` and pick the val or test dataloader.
      4. Load ``<subset>_metadata.json`` from the dataset root, or build
         placeholder metadata when that file does not exist.
      5. Call ``inference_model`` and then ``post_process``. The final output is
         written to ``<output_dir>/<config name>_<subset>.json``
         (``--output_dir`` defaults to ``"output"``).

    Raises:
        ValueError: if ``config["model_type"]`` is not ``"batfd"`` or ``"batfd_plus"``.
    """
    args = _parse_args()

    # Use CUDA only when it was requested AND is available. Otherwise run on
    # CPU and pass gpus=0 downstream, so the device that is reported is also
    # the device that is requested from `inference_model`.
    use_cuda = args.gpus > 0 and torch.cuda.is_available()
    if use_cuda:
        device = f"cuda:{torch.cuda.current_device()}"
        gpus = args.gpus
    else:
        if args.gpus > 0:
            print("CUDA requested but not available; falling back to CPU.")
        device = "cpu"
        gpus = 0
    print(f"Using device: {device}")

    config = toml.load(args.config)
    model_name = config["name"]
    temp_dir = args.output_dir
    output_file = os.path.join(temp_dir, f"{model_name}_{args.subset}.json")

    model = _load_model(config["model_type"], args.checkpoint)

    # Setup DataModule
    dataset_name = config["dataset"]
    is_test_subset = args.subset in _TEST_SUBSETS
    dm = AVDeepfake1mDataModule(
        root=args.data_root,
        temporal_size=config["num_frames"],
        max_duration=config["max_duration"],
        require_match_scores=False,
        batch_size=1,  # due to the problem from lightning, only 1 is supported
        num_workers=args.num_workers,
        get_meta_attr=model.get_meta_attr,
        return_file_name=True,
        is_plusplus=dataset_name == "avdeepfake1m++",
        test_subset=args.subset if is_test_subset else None,
    )
    dm.setup()
    # `output_file` sits directly inside `temp_dir`, so one mkdir covers both.
    Path(temp_dir).mkdir(parents=True, exist_ok=True)

    if is_test_subset:
        dataloader = dm.test_dataloader()
    elif args.subset == "val":
        dataloader = dm.val_dataloader()
    else:  # unreachable while argparse `choices` restricts --subset
        raise ValueError("Invalid subset")
    metadata_path = os.path.join(dm.root, f"{args.subset}_metadata.json")
    metadata = _load_metadata(metadata_path, args.subset, dataloader)

    inference_model(
        model_name=model_name,
        model=model,
        dataloader=dataloader,
        metadata=metadata,
        max_duration=config["max_duration"],
        model_type=config["model_type"],
        gpus=gpus,
        temp_dir=temp_dir,
    )
    post_process(
        model_name=model_name,
        save_path=output_file,
        metadata=metadata,
        fps=_FPS,
        alpha=config["soft_nms"]["alpha"],
        t1=config["soft_nms"]["t1"],
        t2=config["soft_nms"]["t2"],
        dataset_name=dataset_name,
        output_dir=temp_dir,
    )
    print(f"Inference complete. Results saved to {output_file}")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()