Skip to content

Commit 4cd5f8b

Browse files
committed
adapt to test set
1 parent 22840bb commit 4cd5f8b

3 files changed

Lines changed: 47 additions & 16 deletions

File tree

examples/batfd/batfd/inference.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import os.path
22
from typing import Any, List, Optional
3+
import torch
34
from torch import Tensor
45
import pandas as pd
56
from pathlib import Path
67
from lightning.pytorch import LightningModule, Trainer, Callback
8+
from torch.utils.data import DataLoader
79

810
from avdeepfake1m.loader import Metadata
9-
from torch.utils.data import DataLoader
1011

1112

1213
def nullable_index(obj, index):
@@ -38,9 +39,12 @@ def on_predict_batch_end(
3839
batch_size = fusion_bm_map.shape[0]
3940

4041
for i in range(batch_size):
41-
temporal_size = batch[3][i]
42+
temporal_size = torch.tensor(100) # the first value of `Batfd.get_meta_attr`
4243
video_name = self.metadata[batch_idx * batch_size + i].file
4344
n_frames = self.metadata[batch_idx * batch_size + i].video_frames
45+
# if n_frames is not available, this sample should belong to the test set, and we can get it from the batch
46+
if n_frames == -1:
47+
n_frames = batch[-1][i].cpu().numpy().item()
4448

4549
assert isinstance(video_name, str)
4650
self.gen_df_for_batfd(fusion_bm_map[i], temporal_size, n_frames, os.path.join(
@@ -52,9 +56,12 @@ def on_predict_batch_end(
5256
batch_size = fusion_bm_map.shape[0]
5357

5458
for i in range(batch_size):
55-
temporal_size = batch[3][i]
59+
temporal_size = torch.tensor(100) # the first value of `BatfdPlus.get_meta_attr`
5660
video_name = self.metadata[batch_idx * batch_size + i].file
5761
n_frames = self.metadata[batch_idx * batch_size + i].video_frames
62+
# if n_frames is not available, this sample should belong to the test set, and we can get it from the batch
63+
if n_frames == -1:
64+
n_frames = batch[-1][i].cpu().numpy().item()
5865
assert isinstance(video_name, str)
5966

6067
self.gen_df_for_batfd_plus(fusion_bm_map[i], nullable_index(fusion_start, i),
@@ -112,11 +119,9 @@ def inference_model(model_name: str, model: LightningModule, dataloader: DataLoa
112119
metadata: List[Metadata],
113120
max_duration: int, model_type: str,
114121
gpus: int = 1,
115-
temp_dir: str = "output/",
116-
subset: str = "test"
122+
temp_dir: str = "output/"
117123
) -> List[Metadata]:
118124
Path(os.path.join(temp_dir, model_name)).mkdir(parents=True, exist_ok=True)
119-
assert subset in ["test", "val"]
120125

121126
model.eval()
122127

examples/batfd/infer.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def main():
2020
help="Root directory of the dataset.")
2121
parser.add_argument("--num_workers", type=int, default=8,
2222
help="Number of workers for data loading.")
23-
parser.add_argument("--subset", type=str, choices=["val", "test"],
23+
parser.add_argument("--subset", type=str, choices=["val", "test", "testA", "testB"],
2424
default="test", help="Dataset subset.")
2525
parser.add_argument("--gpus", type=int, default=1,
2626
help="Number of GPUs. Set to 0 for CPU.")
@@ -62,23 +62,41 @@ def main():
6262
num_workers=args.num_workers,
6363
get_meta_attr=model.get_meta_attr,
6464
return_file_name=True,
65-
is_plusplus=is_plusplus
65+
is_plusplus=is_plusplus,
66+
test_subset=args.subset if args.subset in ("test", "testA", "testB") else None
6667
)
6768
dm.setup()
6869

6970
Path(output_file).parent.mkdir(parents=True, exist_ok=True)
7071
Path(temp_dir).mkdir(parents=True, exist_ok=True)
7172

72-
if args.subset == "test":
73+
if args.subset in ("test", "testA", "testB"):
7374
dataloader = dm.test_dataloader()
74-
metadata_path = os.path.join(dm.root, "test_metadata.json")
75+
metadata_path = os.path.join(dm.root, f"{args.subset}_metadata.json")
7576
elif args.subset == "val":
7677
dataloader = dm.val_dataloader()
7778
metadata_path = os.path.join(dm.root, "val_metadata.json")
7879
else:
7980
raise ValueError("Invalid subset")
8081

81-
metadata = [Metadata(**each, fps=25) for each in read_json(metadata_path)]
82+
if os.path.exists(metadata_path):
83+
metadata = [Metadata(**each, fps=25) for each in read_json(metadata_path)]
84+
else:
85+
metadata = [
86+
Metadata(file=file_name,
87+
original=None,
88+
split=args.subset,
89+
fake_segments=[],
90+
fps=25,
91+
visual_fake_segments=[],
92+
audio_fake_segments=[],
93+
audio_model="",
94+
modify_type="",
95+
# handled by the predictor in `inference_model`
96+
video_frames=-1,
97+
audio_frames=-1)
98+
for file_name in dataloader.dataset.file_list
99+
]
82100

83101
inference_model(
84102
model_name=config["name"],
@@ -88,8 +106,7 @@ def main():
88106
max_duration=config["max_duration"],
89107
model_type=config["model_type"],
90108
gpus=args.gpus,
91-
temp_dir=temp_dir,
92-
subset=args.subset
109+
temp_dir=temp_dir
93110
)
94111

95112
post_process(

python/avdeepfake1m/loader.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def __getitem__(self, index: int) -> List[Union[Tensor, str, int]]:
105105
file = self.file_list[index]
106106

107107
video, audio, _ = read_video(os.path.join(self.root, self.subset, file))
108+
n_frames = video.shape[0]
108109
video = F.interpolate(video.float().permute(1, 0, 2, 3)[None], size=(self.temporal_size, 96, 96))[0]
109110
audio = F.interpolate(audio.float().permute(1, 0)[None], size=self.audio_temporal_size, mode="linear")[0].permute(1, 0)
110111
video = self.video_transform(video)
@@ -113,7 +114,7 @@ def __getitem__(self, index: int) -> List[Union[Tensor, str, int]]:
113114

114115
outputs = [video, audio]
115116

116-
if self.subset != "test":
117+
if self.subset not in ("test", "testA", "testB"):
117118
if self.is_plusplus:
118119
subset_folder = self.subset
119120
else:
@@ -131,6 +132,9 @@ def __getitem__(self, index: int) -> List[Union[Tensor, str, int]]:
131132
if self.return_file_name:
132133
outputs.append(meta.file)
133134

135+
else:
136+
outputs = outputs + [n_frames]
137+
134138
return outputs
135139

136140
def get_label(self, file: str, meta: Metadata) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
@@ -244,6 +248,7 @@ def __init__(self, root: str = "data", temporal_size: int = 100,
244248
get_meta_attr: Callable[[Metadata, Tensor, Tensor, Tensor], List[Any]] = _default_get_meta_attr,
245249
return_file_name: bool = False,
246250
is_plusplus: bool = False,
251+
test_subset: Optional[str] = None
247252
):
248253
super().__init__()
249254
self.root = root
@@ -260,11 +265,15 @@ def __init__(self, root: str = "data", temporal_size: int = 100,
260265
self.return_file_name = return_file_name
261266
self.is_plusplus = is_plusplus
262267
self.Dataset = AVDeepfake1m
268+
if test_subset is None:
269+
self.test_subset = "test" if not self.is_plusplus else "testA"
270+
else:
271+
self.test_subset = test_subset
263272

264273
def setup(self, stage: Optional[str] = None) -> None:
265274
train_file_list = [meta["file"] for meta in read_json(os.path.join(self.root, "train_metadata.json"))]
266275
val_file_list = [meta["file"] for meta in read_json(os.path.join(self.root, "val_metadata.json"))]
267-
with open(os.path.join(self.root, "test_files.txt"), "r") as f:
276+
with open(os.path.join(self.root, f"{self.test_subset}_files.txt"), "r") as f:
268277
test_file_list = list(filter(lambda x: x != "", f.read().split("\n")))
269278

270279
if self.take_val is not None:
@@ -285,7 +294,7 @@ def setup(self, stage: Optional[str] = None) -> None:
285294
return_file_name=self.return_file_name,
286295
is_plusplus=self.is_plusplus
287296
)
288-
self.test_dataset = self.Dataset("test", self.root, self.temporal_size, self.max_duration, self.fps,
297+
self.test_dataset = self.Dataset(self.test_subset, self.root, self.temporal_size, self.max_duration, self.fps,
289298
file_list=test_file_list, get_meta_attr=self.get_meta_attr,
290299
require_match_scores=self.require_match_scores,
291300
return_file_name=self.return_file_name,

0 commit comments

Comments
 (0)