Commit b4ce1b1 (1 parent: d9d34de)

add xception example

5 files changed · 263 additions, 0 deletions

examples/xception/README.md

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
# Xception

This example trains an Xception model on the AVDeepfake1M/AVDeepfake1M++ dataset for classification with video-level labels.

## Requirements

- Python
- PyTorch
- PyTorch Lightning
- TIMM
- AVDeepfake1M SDK

## Training

```bash
python train.py --data_root /path/to/avdeepfake1m --model xception
```

## Output

* **Checkpoints:** Model checkpoints are saved under `./ckpt1/xception/`. The most recent checkpoint is additionally saved as `last.ckpt`.
* **Logs:** Training logs (including metrics such as `train_loss`, `val_loss`, and learning rates) are written by PyTorch Lightning, typically to `./lightning_logs/`. You can view them with TensorBoard (`tensorboard --logdir ./lightning_logs`).
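
To resume an interrupted run, `train.py` forwards its `--resume` flag to `trainer.fit` as `ckpt_path`; a usage sketch, assuming the default checkpoint directory produced by the checkpoint callback in this commit:

```bash
python train.py --data_root /path/to/avdeepfake1m --model xception \
    --resume ./ckpt1/xception/last.ckpt
```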

examples/xception/train.py

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
```python
import argparse

from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from avdeepfake1m.loader import AVDeepfake1mPlusPlusImages

from xception import Xception
from utils import LrLogger, EarlyStoppingLR


parser = argparse.ArgumentParser(description="Classification model training")
parser.add_argument("--data_root", type=str)
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--model", type=str, choices=["xception", "meso4", "meso_inception4"])
parser.add_argument("--gpus", type=int, default=1)
parser.add_argument("--precision", default=32)
parser.add_argument("--num_train", type=int, default=None)
parser.add_argument("--num_val", type=int, default=2000)
parser.add_argument("--max_epochs", type=int, default=500)
parser.add_argument("--resume", type=str, default=None)
args = parser.parse_args()


if __name__ == "__main__":

    # Fix the random seeds (requires `import torch` and `import random`)
    # if you want reproducible subsets each epoch:
    # torch.manual_seed(42)
    # random.seed(42)

    # Linear LR scaling: base LR of 1e-4 per 4 samples of effective batch size.
    learning_rate = 1e-4
    gpus = args.gpus
    total_batch_size = args.batch_size * gpus
    learning_rate = learning_rate * total_batch_size / 4

    # Set up the model.
    if args.model == "xception":
        model = Xception(learning_rate, distributed=gpus > 1)
    else:
        raise ValueError(f"Unknown model: {args.model}")

    train_dataset = AVDeepfake1mPlusPlusImages(
        subset="train",
        data_root=args.data_root,
        take_num=args.num_train,
        use_video_label=True  # set True for video-level labels
    )

    # Validation uses the same dataset class, also with video-level labels.
    val_dataset = AVDeepfake1mPlusPlusImages(
        subset="val",
        data_root=args.data_root,
        take_num=args.num_val,
        use_video_label=True
    )

    # Parse precision: it may be an int (e.g. 32, 16) or a string (e.g. "bf16-mixed").
    try:
        precision = int(args.precision)
    except ValueError:
        precision = args.precision

    monitor = "val_loss"

    trainer = Trainer(
        log_every_n_steps=50,
        precision=precision,
        max_epochs=args.max_epochs,
        callbacks=[
            ModelCheckpoint(
                dirpath=f"./ckpt1/{args.model}",
                save_last=True,
                filename=args.model + "-{epoch}-{val_loss:.3f}",
                monitor=monitor,
                mode="min"
            ),
            LrLogger(),
            EarlyStoppingLR(lr_threshold=1e-7)
        ],
        enable_checkpointing=True,
        benchmark=True,
        accelerator="gpu",
        devices=args.gpus,
        # Older Lightning versions require strategy="ddp" explicitly for
        # multi-GPU runs; newer versions also accept "auto".
        strategy="ddp" if args.gpus > 1 else "auto",
    )

    trainer.fit(
        model,
        train_dataloaders=DataLoader(train_dataset, batch_size=args.batch_size, num_workers=0),
        val_dataloaders=DataLoader(val_dataset, batch_size=args.batch_size, num_workers=0),
        ckpt_path=args.resume,
    )
```
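
With the script's defaults, the scaling rule works out as follows; a minimal check, assuming the base batch size of 4 implied by the division in `train.py`:

```python
# Linear LR scaling as done in train.py: base LR of 1e-4 per 4 samples.
base_lr = 1e-4
batch_size, gpus = 128, 1                         # script defaults
effective_lr = base_lr * (batch_size * gpus) / 4
print(effective_lr)                               # 0.0032
```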

examples/xception/utils.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
```python
import re

from pytorch_lightning import Callback, Trainer, LightningModule


class LrLogger(Callback):
    """Log the learning rate at the start of each epoch."""

    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        for i, optimizer in enumerate(trainer.optimizers):
            for j, params in enumerate(optimizer.param_groups):
                key = f"opt{i}_lr{j}"
                value = params["lr"]
                pl_module.logger.log_metrics({key: value}, step=trainer.global_step)
                pl_module.log(key, value, logger=False, sync_dist=pl_module.distributed)


class EarlyStoppingLR(Callback):
    """Stop training early once the learning rate falls below a threshold."""

    def __init__(self, lr_threshold: float, mode="all"):
        self.lr_threshold = lr_threshold

        if mode in ("any", "all"):
            self.mode = mode
        else:
            raise ValueError("mode must be one of ('any', 'all')")

    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self._run_early_stop_checking(trainer)

    def _run_early_stop_checking(self, trainer: Trainer) -> None:
        metrics = trainer.callback_metrics  # public accessor for logged metrics
        if len(metrics) == 0:
            return
        # Collect the per-optimizer learning rates logged by LrLogger.
        all_lr = []
        for key, value in metrics.items():
            if re.match(r"opt\d+_lr\d+", key):
                all_lr.append(value)

        if len(all_lr) == 0:
            return

        if self.mode == "all":
            if all(lr <= self.lr_threshold for lr in all_lr):
                trainer.should_stop = True
        elif self.mode == "any":
            if any(lr <= self.lr_threshold for lr in all_lr):
                trainer.should_stop = True
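```

Note that `EarlyStoppingLR` only ever fires if something decays the learning rate below the threshold, and the `Xception` module in this commit returns a bare Adam optimizer with no scheduler, so the callback stays dormant as configured. A minimal sketch of a `configure_optimizers` that would drive it, assuming a `ReduceLROnPlateau` schedule on `val_loss` (the subclass and scheduler choice are illustrative, not part of this commit):

```python
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

from xception import Xception


class XceptionWithSchedule(Xception):  # hypothetical subclass, not in this commit
    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.lr)
        scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=5)
        # Lightning steps the scheduler on the monitored metric each epoch;
        # once the LR decays below EarlyStoppingLR's threshold, training stops.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
        }
```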

examples/xception/xception.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
```python
import timm

from pytorch_lightning import LightningModule
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam


class Xception(LightningModule):
    def __init__(self, lr, distributed=False):
        super().__init__()
        self.lr = lr
        # Pretrained Xception backbone with a single-logit binary head.
        self.model = timm.create_model('xception', pretrained=True, num_classes=1)
        self.loss_fn = BCEWithLogitsLoss()  # expects raw logits, not probabilities
        self.distributed = distributed

    def forward(self, x):
        x = self.model(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y.unsqueeze(1))
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y.unsqueeze(1))
        # Sync across ranks under DDP so the monitored val_loss is consistent.
        self.log('val_loss', loss, sync_dist=self.distributed)
        return loss

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.lr)
        return [optimizer]
```
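
A sketch of running a trained checkpoint on a batch of frames; the `last.ckpt` path and the 96×96 input size match the training setup above, while the random tensor stands in for preprocessed frames (real inputs come from `read_video_fast`/`resize_video`):

```python
import torch

from xception import Xception

# The module does not call save_hyperparameters(), so constructor
# arguments are passed manually rather than via load_from_checkpoint.
model = Xception(lr=1e-4)
# Lightning checkpoints store more than tensors, hence weights_only=False.
state = torch.load("./ckpt1/xception/last.ckpt", map_location="cpu", weights_only=False)
model.load_state_dict(state["state_dict"])
model.eval()

# Dummy batch standing in for preprocessed frames: (batch, channels, H, W).
frames = torch.randn(8, 3, 96, 96)
with torch.no_grad():
    probs = torch.sigmoid(model(frames))  # BCEWithLogitsLoss => sigmoid at inference
print(probs.shape)  # (8, 1), per-frame fake probability
```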

python/avdeepfake1m/loader.py

Lines changed: 61 additions & 0 deletions
@@ -538,3 +538,64 @@ def sample_indexes(total_frames: int, n_frames: int, temporal_sample_rate: int):
```python
        print(f"total_frames: {total_frames}, n_frames: {n_frames}, temporal_sample_rate: {temporal_sample_rate}")
        raise e
    return torch.arange(n_frames) * temporal_sample_rate + start_ind


class AVDeepfake1mPlusPlusImages(IterableDataset):

    def __init__(self, subset: str, data_root: str = "data",
                 image_size: int = 96,
                 use_video_label: bool = False,
                 use_seg_label: Optional[int] = None,
                 take_num: Optional[int] = None,
                 metadata: Optional[List[Metadata]] = None,
                 ):
        self.subset = subset
        self.data_root = data_root
        self.image_size = image_size
        self.use_video_label = use_video_label
        if self.use_video_label:
            assert use_seg_label is None, "use_video_label and use_seg_label are mutually exclusive"
        self.use_seg_label = use_seg_label
        if metadata is None:
            metadata_json = read_json(os.path.join(self.data_root, f"{subset}_metadata.json"))
            self.metadata = [Metadata(**meta, fps=25) for meta in metadata_json]
        else:
            self.metadata = metadata

        if take_num is not None:
            self.metadata = self.metadata[:take_num]

        self.total_frames = sum([each.video_frames for each in self.metadata])
        print("Loaded {} videos in {}.".format(len(self.metadata), subset))

    def __len__(self):
        return self.total_frames

    def __iter__(self):
        for meta in self.metadata:
            video = read_video_fast(os.path.join(self.data_root, self.subset, meta.file))
            if self.image_size != 224:
                # Resize to the requested size; skipped when it matches the
                # native 224x224 frames.
                video = resize_video(video, (self.image_size, self.image_size))
            if self.use_video_label:
                # One label per video: fake if it contains any fake period.
                label = float(len(meta.fake_periods) > 0)
                for frame in video:
                    yield frame, label
            elif self.use_seg_label is not None:
                # Per-segment labels: mark frames inside fake periods, then
                # pool each fixed-length segment to a single binary label.
                frame_label = torch.zeros(len(video))
                for begin, end in meta.fake_periods:
                    begin = int(begin * 25)  # seconds -> frame index at 25 fps
                    end = int(end * 25)
                    frame_label[begin: end] = 1
                seg_label = torch.split(frame_label, self.use_seg_label)
                seg_label = torch.nn.utils.rnn.pad_sequence(seg_label, batch_first=True)
                seg_label = (seg_label.sum(dim=1) > 0).float().repeat_interleave(self.use_seg_label)
                for i, frame in enumerate(video):
                    yield frame, seg_label[i]
            else:
                # Per-frame labels: 1 inside fake periods, 0 elsewhere.
                frame_label = torch.zeros(len(video))
                for begin, end in meta.fake_periods:
                    begin = int(begin * 25)  # seconds -> frame index at 25 fps
                    end = int(end * 25)
                    frame_label[begin: end] = 1
                for i, frame in enumerate(video):
                    yield frame, frame_label[i]
```
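
A usage sketch of the new dataset class; the frame tensor layout shown in the final comment is an assumption about what `read_video_fast` yields (per-frame CHW tensors), not something this diff pins down:

```python
from torch.utils.data import DataLoader

from avdeepfake1m.loader import AVDeepfake1mPlusPlusImages

# Video-level labels: every frame of a video shares one real/fake label.
dataset = AVDeepfake1mPlusPlusImages(
    subset="val",
    data_root="/path/to/avdeepfake1m",
    take_num=10,            # small subset for a quick smoke test
    use_video_label=True,
)

# IterableDataset: no shuffling, and num_workers > 0 would duplicate the
# stream unless worker sharding is added (train.py uses num_workers=0).
loader = DataLoader(dataset, batch_size=32, num_workers=0)
frames, labels = next(iter(loader))
print(frames.shape, labels.shape)  # e.g. (32, 3, 96, 96) and (32,)
```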
