Skip to content

Commit 36aa498

Browse files
committed
update the example code of BA-TFD and BA-TFD+
1 parent 5d11d51 commit 36aa498

26 files changed

Lines changed: 2298 additions & 85 deletions

.gitignore

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,4 +441,9 @@ docs/_build/
441441
# Pyenv
442442
.python-version
443443

444-
lightning_logs
444+
/examples/batfd/lightning_logs
445+
/examples/batfd/ckpt
446+
/examples/batfd/output
447+
/examples/xception/lightning_logs
448+
/examples/xception/ckpt
449+
/examples/xception/output

examples/batfd/README.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# BA-TFD
2+
3+
This example trains the BA-TFD and BA-TFD+ models on the AV-Deepfake1M/AV-Deepfake1M++ dataset for temporal forgery localization.
4+
## Requirements
5+
6+
Ensure you have the necessary environment setup. You can create a Conda environment using the following commands:
7+
8+
```bash
9+
# prepare the environment
10+
conda create -n batfd python=3.10 -y
11+
conda activate batfd
12+
conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=11.8 -c pytorch -c nvidia -y
13+
pip install avdeepfake1m toml tensorboard pytorch-lightning pandas
14+
```
15+
16+
## Training
17+
18+
Train the BATFD or BATFD+ model using a TOML configuration file (e.g., `batfd.toml` or `batfd_plus.toml`).
19+
20+
```bash
21+
python train.py --config ./batfd.toml --data_root /path/to/AV-Deepfake1M-PlusPlus
22+
```
23+
24+
### Output
25+
26+
* **Checkpoints:** Model checkpoints are saved under `./ckpt/batfd/` (or `./ckpt/batfd_plus/`, depending on the configured model name). The last checkpoint is saved as `last.ckpt`.
27+
* **Logs:** Training logs (including metrics like `train_loss`, `val_loss`, and learning rates) are saved by PyTorch Lightning, typically in a directory named `./lightning_logs/`. You can view these logs using TensorBoard (`tensorboard --logdir ./lightning_logs`).
28+
29+
## Inference
30+
31+
After training, generate predictions on a dataset subset (e.g., `val`, `test`) using `infer.py`. This script saves the predictions to a JSON file, which is required for evaluation.
32+
33+
```bash
34+
python infer.py --config ./batfd.toml --checkpoint /path/to/checkpoint --data_root /path/to/AV-Deepfake1M-PlusPlus --subset val
35+
```
36+
37+
## Evaluation
38+
39+
```bash
40+
python evaluate.py /path/to/prediction_file /path/to/metadata_file
41+
```
42+

examples/batfd/batfd.toml

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
name = "batfd"
2+
num_frames = 100 # T
3+
max_duration = 30 # D
4+
model_type = "batfd"
5+
dataset = "avdeepfake1m++"
6+
7+
[model.video_encoder]
8+
type = "c3d"
9+
hidden_dims = [64, 96, 128, 128]
10+
cla_feature_in = 256 # C_f
11+
12+
[model.audio_encoder]
13+
type = "cnn"
14+
hidden_dims = [32, 64, 64]
15+
cla_feature_in = 256 # C_f
16+
17+
[model.frame_classifier]
18+
type = "lr"
19+
20+
[model.boundary_module]
21+
hidden_dims = [512, 128]
22+
samples = 10 # N
23+
24+
[optimizer]
25+
learning_rate = 0.00001
26+
frame_loss_weight = 2.0
27+
modal_bm_loss_weight = 1.0
28+
contrastive_loss_weight = 0.1
29+
contrastive_loss_margin = 0.99
30+
weight_decay = 0.0001
31+
32+
[soft_nms]
33+
alpha = 0.7234
34+
t1 = 0.1968
35+
t2 = 0.4123

examples/batfd/batfd/__init__.py

Whitespace-only changes.

examples/batfd/batfd/inference.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import os.path
2+
from typing import Any, List, Optional
3+
from torch import Tensor
4+
import pandas as pd
5+
from pathlib import Path
6+
from lightning.pytorch import LightningModule, Trainer, Callback
7+
8+
from avdeepfake1m.loader import Metadata
9+
from torch.utils.data import DataLoader
10+
11+
12+
def nullable_index(obj, index):
    """Index into ``obj``, propagating ``None``: if ``obj`` is None, return None."""
    return None if obj is None else obj[index]
16+
17+
18+
class SaveToCsvCallback(Callback):
    """Lightning callback that converts per-batch boundary-map predictions into
    per-video CSV proposal files under ``<temp_dir>/<model_name>/``.

    Each CSV row is one temporal proposal: ``begin``, ``end`` (frame indices)
    and ``score``; only the top-100 proposals by score are kept per video.
    """

    def __init__(self, max_duration: int, metadata: List[Metadata], model_name: str, model_type: str, temp_dir: str):
        super().__init__()
        # NOTE(review): max_duration is stored but never read in this class —
        # confirm whether it is dead or used by a subclass/caller elsewhere.
        self.max_duration = max_duration
        # Metadata list assumed to be in the same order as the (unshuffled)
        # predict dataloader — indexing below relies on that alignment.
        self.metadata = metadata
        self.model_name = model_name
        # Either "batfd" or "batfd_plus"; selects how `outputs` is unpacked.
        self.model_type = model_type
        self.temp_dir = temp_dir

    def on_predict_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
    ) -> None:
        # Dispatch on model type: the two models return different tuples from
        # predict_step. The mapping of dataset index is
        # batch_idx * batch_size + i, which assumes a constant batch size and
        # no shuffling/distributed sharding in the predict dataloader.
        if self.model_type == "batfd":
            fusion_bm_map, v_bm_map, a_bm_map, v_frame_cla, a_frame_cla = outputs
            batch_size = fusion_bm_map.shape[0]

            for i in range(batch_size):
                # NOTE(review): batch[3] is assumed to be the per-sample
                # temporal size (valid length in boundary-map units) — confirm
                # against the dataset's collate output.
                temporal_size = batch[3][i]
                video_name = self.metadata[batch_idx * batch_size + i].file
                n_frames = self.metadata[batch_idx * batch_size + i].video_frames

                assert isinstance(video_name, str)
                # Flatten the nested video path into a single CSV filename.
                self.gen_df_for_batfd(fusion_bm_map[i], temporal_size, n_frames, os.path.join(
                    self.temp_dir, self.model_name, video_name.replace("/", "_").replace(".mp4", ".csv")
                ))

        elif self.model_type == "batfd_plus":
            # BATFD+ additionally predicts start/end boundary confidence
            # curves used to re-weight proposal scores.
            fusion_bm_map, fusion_start, fusion_end, v_bm_map, v_start, v_end, a_bm_map, a_start, a_end, v_frame_cla, a_frame_cla = outputs
            batch_size = fusion_bm_map.shape[0]

            for i in range(batch_size):
                temporal_size = batch[3][i]
                video_name = self.metadata[batch_idx * batch_size + i].file
                n_frames = self.metadata[batch_idx * batch_size + i].video_frames
                assert isinstance(video_name, str)

                # fusion_start/fusion_end may be None; nullable_index keeps
                # the None through to gen_df_for_batfd_plus.
                self.gen_df_for_batfd_plus(fusion_bm_map[i], nullable_index(fusion_start, i),
                    nullable_index(fusion_end, i), temporal_size,
                    n_frames, os.path.join(self.temp_dir, self.model_name,
                        video_name.replace("/", "_").replace(".mp4", ".csv")
                    ))

        else:
            raise ValueError("Invalid model type")

    def gen_df_for_batfd(self, bm_map: Tensor, temporal_size: Tensor, n_frames: int, output_file: str):
        """Convert one (duration x begin) boundary map into a proposal CSV.

        Rows are (begin, end, score) with begin/end rescaled from boundary-map
        units to frame indices; top-100 rows by score are written.
        """
        bm_map = bm_map.cpu().numpy()
        temporal_size = temporal_size.cpu().numpy().item()
        # for each boundary proposal in boundary map:
        # stack() turns the 2-D map into rows of (row-index, col-index, value),
        # i.e. (duration, begin, score) after renaming.
        df = pd.DataFrame(bm_map)
        df = df.stack().reset_index()
        df.columns = ["duration", "begin", "score"]
        df["end"] = df.duration + df.begin
        # Drop zero-length proposals and those running past the valid length.
        df = df[(df.duration > 0) & (df.end <= temporal_size)]
        df = df.sort_values(["begin", "end"])
        df = df.reset_index()[["begin", "end", "score"]]
        # Rescale from boundary-map units to frame indices (truncating).
        df["begin"] = (df["begin"] / temporal_size * n_frames).astype(int)
        df["end"] = (df["end"] / temporal_size * n_frames).astype(int)
        df = df.sort_values(["score"], ascending=False).iloc[:100]
        df.to_csv(output_file, index=False)

    def gen_df_for_batfd_plus(self, bm_map: Tensor, start: Optional[Tensor], end: Optional[Tensor],
            temporal_size: Tensor, n_frames: int, output_file: str
    ):
        """Same as gen_df_for_batfd, but when start/end confidence curves are
        given, each proposal score is multiplied by start[begin] * end[end]
        (indexed in boundary-map units, before rescaling to frames).
        """
        bm_map = bm_map.cpu().numpy()
        temporal_size = temporal_size.cpu().numpy().item()
        if start is not None and end is not None:
            start = start.cpu().numpy()
            end = end.cpu().numpy()

        # for each boundary proposal in boundary map
        df = pd.DataFrame(bm_map)
        df = df.stack().reset_index()
        df.columns = ["duration", "begin", "score"]
        df["end"] = df.duration + df.begin
        df = df[(df.duration > 0) & (df.end <= temporal_size)]
        df = df.sort_values(["begin", "end"])
        df = df.reset_index()[["begin", "end", "score"]]
        if start is not None and end is not None:
            # NOTE(review): df.begin / df.end here are still boundary-map
            # indices, so start/end must have length >= temporal_size —
            # confirm the predicted curve length matches.
            df["score"] = df["score"] * start[df.begin] * end[df.end]

        df["begin"] = (df["begin"] / temporal_size * n_frames).astype(int)
        df["end"] = (df["end"] / temporal_size * n_frames).astype(int)
        df = df.sort_values(["score"], ascending=False).iloc[:100]
        df.to_csv(output_file, index=False)
109+
110+
111+
def inference_model(model_name: str, model: LightningModule, dataloader: DataLoader,
        metadata: List[Metadata],
        max_duration: int, model_type: str,
        gpus: int = 1,
        temp_dir: str = "output/",
        subset: str = "test"
) -> None:
    """Run prediction over ``dataloader`` and write per-video proposal CSVs.

    Results are written by :class:`SaveToCsvCallback` into
    ``<temp_dir>/<model_name>/``; nothing is returned. (The previous
    ``List[Metadata]`` return annotation was wrong — the function never
    returned a value.)

    :param model_name: subdirectory name for the output CSV files.
    :param model: trained BATFD/BATFD+ LightningModule, set to eval mode here.
    :param dataloader: predict dataloader; must be unshuffled so batch order
        matches ``metadata``.
    :param metadata: per-video metadata aligned with the dataloader order.
    :param max_duration: forwarded to the callback (see its note on usage).
    :param model_type: "batfd" or "batfd_plus".
    :param gpus: 0 runs on CPU; any positive value uses the auto accelerator.
    :param temp_dir: root output directory, created if missing.
    :param subset: which split is being predicted; must be "test" or "val".
    :raises ValueError: if ``subset`` is not "test" or "val".
    """
    Path(os.path.join(temp_dir, model_name)).mkdir(parents=True, exist_ok=True)
    # Validate with an exception rather than `assert` (asserts vanish under -O).
    if subset not in ("test", "val"):
        raise ValueError(f"Invalid subset: {subset}")

    model.eval()

    trainer = Trainer(logger=False,
        enable_checkpointing=False,
        # NOTE(review): multi-GPU requests are collapsed to a single device
        # (prediction order must match `metadata`, which sharding would break)
        # — confirm this is the intended reason.
        devices=1 if gpus > 1 else "auto",
        accelerator="auto" if gpus > 0 else "cpu",
        callbacks=[SaveToCsvCallback(max_duration, metadata, model_name, model_type, temp_dir)]
    )

    trainer.predict(model, dataloader)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .batfd import Batfd
2+
from .batfd_plus import BatfdPlus
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
from typing import Literal
2+
3+
from einops import rearrange
4+
from einops.layers.torch import Rearrange
5+
from torch import Tensor
6+
from torch.nn import Module, Sequential, LeakyReLU, MaxPool2d, Linear
7+
from torchvision.models.vision_transformer import Encoder as ViTEncoder
8+
9+
from ..utils import Conv2d
10+
11+
12+
class CNNAudioEncoder(Module):
    """
    Audio encoder (E_a): extracts temporal features from a log mel spectrogram.

    Input:
        A': (B, F_m, T_a)
    Output:
        E_a: (B, C_f, T)
    """

    def __init__(self, n_features=(32, 64, 64)):
        super().__init__()

        dim0, dim1, dim2 = n_features

        # Stage 0: lift to a channel dim, then halve both axes.
        # (B, 64, 2048) -> (B, 1, 64, 2048) -> (B, 32, 32, 1024)
        self.block0 = Sequential(
            Rearrange("b c t -> b 1 c t"),
            Conv2d(1, dim0, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU),
            MaxPool2d(2),
        )

        # Stage 1: two 3x3 convs, halve both axes again.
        # (B, 32, 32, 1024) -> (B, 64, 16, 512)
        self.block1 = Sequential(
            Conv2d(dim0, dim1, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU),
            Conv2d(dim1, dim1, kernel_size=3, stride=1, padding=1, build_activation=LeakyReLU),
            MaxPool2d(2),
        )

        # Stage 2: pool only the frequency axis, then fold channels x freq
        # into one feature dim. (B, 64, 16, 512) -> (B, 64, 4, 512) -> (B, 256, 512)
        self.block2 = Sequential(
            Conv2d(dim1, dim2, kernel_size=(2, 1), stride=1, padding=(1, 0), build_activation=LeakyReLU),
            MaxPool2d((2, 1)),
            Conv2d(dim2, dim2, kernel_size=(3, 1), stride=1, padding=(1, 0), build_activation=LeakyReLU),
            MaxPool2d((2, 1)),
            Rearrange("b f c t -> b (f c) t"),
        )

    def forward(self, audio: Tensor) -> Tensor:
        features = audio
        for stage in (self.block0, self.block1, self.block2):
            features = stage(features)
        return features
54+
55+
56+
class SelfAttentionAudioEncoder(Module):
    """ViT-based audio encoder: patchify the spectrogram, run a Transformer
    encoder over the temporal tokens, and project to the classifier width."""

    # (hidden_dim, num_heads) per variant; every variant uses 12 layers and
    # an MLP dim of 4x hidden. The ViT configurations are from:
    # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
    _VIT_CONFIGS = {
        "vit_t": (192, 3),
        "vit_s": (384, 6),
        "vit_b": (768, 12),
    }

    def __init__(self, block_type: Literal["vit_t", "vit_s", "vit_b"], a_cla_feature_in: int = 256, temporal_size: int = 512):
        super().__init__()
        if block_type not in self._VIT_CONFIGS:
            raise ValueError(f"Unknown block type: {block_type}")

        hidden_dim, num_heads = self._VIT_CONFIGS[block_type]
        self.n_features = hidden_dim
        self.block = ViTEncoder(
            seq_length=temporal_size,
            num_layers=12,
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            mlp_dim=hidden_dim * 4,
            dropout=0.,
            attention_dropout=0.
        )

        # Non-overlapping (64, 4) patches: collapses the full frequency axis
        # and groups 4 time steps into one token.
        self.input_proj = Conv2d(1, self.n_features, kernel_size=(64, 4), stride=(64, 4))
        self.output_proj = Linear(self.n_features, a_cla_feature_in)

    def forward(self, audio: Tensor) -> Tensor:
        patches = audio.unsqueeze(1)                      # (B, 64, 2048) -> (B, 1, 64, 2048)
        patches = self.input_proj(patches)                # -> (B, feat, 1, 512)
        tokens = rearrange(patches, "b f 1 t -> b t f")   # -> (B, 512, feat)
        tokens = self.block(tokens)
        tokens = self.output_proj(tokens)                 # -> (B, 512, a_cla_feature_in)
        return tokens.permute(0, 2, 1)                    # -> (B, a_cla_feature_in, 512)
109+
110+
111+
class AudioFeatureProjection(Module):
    """Linear projection for precomputed audio features (e.g. wav2vec2,
    trillsson3): maps (B, T, input_feature_dim) to (B, a_cla_feature_in, T)."""

    def __init__(self, input_feature_dim: int, a_cla_feature_in: int = 256):
        super().__init__()
        self.proj = Linear(input_feature_dim, a_cla_feature_in)

    def forward(self, x: Tensor) -> Tensor:
        # Project the feature dim, then move it to the channel position.
        return self.proj(x).permute(0, 2, 1)
120+
121+
122+
def get_audio_encoder(a_cla_feature_in, temporal_size, a_encoder, ae_features):
    """Factory for the audio encoder selected by ``a_encoder``.

    Supported values: "cnn", the ViT variants ("vit_t"/"vit_s"/"vit_b"),
    and the precomputed-feature projections "wav2vec2" / "trillsson3".
    Raises ValueError for anything else.
    """
    if a_encoder == "cnn":
        return CNNAudioEncoder(n_features=ae_features)
    if a_encoder in ("vit_t", "vit_s", "vit_b"):
        # The variant name doubles as the ViT block type.
        return SelfAttentionAudioEncoder(
            block_type=a_encoder, a_cla_feature_in=a_cla_feature_in, temporal_size=temporal_size
        )
    if a_encoder == "wav2vec2":
        return AudioFeatureProjection(input_feature_dim=1536, a_cla_feature_in=a_cla_feature_in)
    if a_encoder == "trillsson3":
        return AudioFeatureProjection(input_feature_dim=1280, a_cla_feature_in=a_cla_feature_in)
    raise ValueError(f"Invalid audio encoder: {a_encoder}")

0 commit comments

Comments
 (0)