# MARLIN: Masked Autoencoder for facial video Representation LearnINg

<div>
    <img src="assets/teaser.svg">
    <p></p>
</div>

<div align="center">
    <a href="https://github.com/ControlNet/MARLIN/network/members">
        <img src="https://img.shields.io/github/forks/ControlNet/MARLIN?style=flat-square">
    </a>
    <a href="https://github.com/ControlNet/MARLIN/stargazers">
        <img src="https://img.shields.io/github/stars/ControlNet/MARLIN?style=flat-square">
    </a>
    <a href="https://github.com/ControlNet/MARLIN/issues">
        <img src="https://img.shields.io/github/issues/ControlNet/MARLIN?style=flat-square">
    </a>
    <a href="https://github.com/ControlNet/MARLIN/blob/master/LICENSE">
        <img src="https://img.shields.io/badge/license-CC--BY--NC%204.0-97ca00?style=flat-square">
    </a>
    <a href="https://arxiv.org/abs/2211.06627">
        <img src="https://img.shields.io/badge/arXiv-2211.06627-b31b1b.svg?style=flat-square">
    </a>
</div>

<div align="center">
    <a href="https://pypi.org/project/marlin-pytorch/">
        <img src="https://img.shields.io/pypi/v/marlin-pytorch?style=flat-square">
    </a>
    <a href="https://pypi.org/project/marlin-pytorch/">
        <img src="https://img.shields.io/pypi/dm/marlin-pytorch?style=flat-square">
    </a>
    <a href="https://www.python.org/"><img src="https://img.shields.io/pypi/pyversions/marlin-pytorch?style=flat-square"></a>
    <a href="https://pytorch.org/"><img src="https://img.shields.io/badge/PyTorch-%3E%3D1.8.0-EE4C2C?style=flat-square&logo=pytorch"></a>
</div>

<div align="center">
    <a href="https://github.com/ControlNet/MARLIN/actions"><img src="https://img.shields.io/github/actions/workflow/status/ControlNet/MARLIN/unittest.yaml?branch=dev&label=unittest&style=flat-square"></a>
    <a href="https://github.com/ControlNet/MARLIN/actions"><img src="https://img.shields.io/github/actions/workflow/status/ControlNet/MARLIN/release.yaml?branch=master&label=release&style=flat-square"></a>
    <a href="https://coveralls.io/github/ControlNet/MARLIN"><img src="https://img.shields.io/coverallsCoverage/github/ControlNet/MARLIN?style=flat-square"></a>
</div>

This repo is the official PyTorch implementation of the paper
[MARLIN: Masked Autoencoder for facial video Representation LearnINg](https://openaccess.thecvf.com/content/CVPR2023/html/Cai_MARLIN_Masked_Autoencoder_for_Facial_Video_Representation_LearnINg_CVPR_2023_paper) (CVPR 2023).

## Repository Structure

The repository contains two parts:
- `marlin-pytorch`: The PyPI package for MARLIN, used for inference.
- The implementation of the paper, including training and evaluation scripts.

```
.
├── assets                # Images for README.md
├── LICENSE
├── README.md
├── MODEL_ZOO.md
├── CITATION.cff
├── .gitignore
├── .github

# below is for the PyPI package marlin-pytorch
├── src                   # Source code for marlin-pytorch
├── tests                 # Unittest
├── requirements.lib.txt
├── setup.py
├── init.py
├── version.txt

# below is for the paper implementation
├── configs               # Configs for experiment settings
├── model                 # MARLIN models
├── preprocess            # Preprocessing scripts
├── dataset               # Dataloaders
├── utils                 # Utility functions
├── train.py              # Training script
├── evaluate.py           # Evaluation script (TODO)
├── requirements.txt
```

## Use `marlin-pytorch` for Feature Extraction

Requirements:
- Python >= 3.6, < 3.11
- PyTorch >= 1.8
- ffmpeg

Install from PyPI:
```bash
pip install marlin-pytorch
```
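
Before loading any model, you can sanity-check the environment. A minimal sketch, assuming only that PyTorch and ffmpeg are reachable from Python (this helper is illustrative, not part of `marlin-pytorch`):
```python
import shutil
import torch

# Parse e.g. "1.11.0+cu113" -> (1, 11) and compare with the minimum version
major, minor = (int(p) for p in torch.__version__.split("+")[0].split(".")[:2])
assert (major, minor) >= (1, 8), "marlin-pytorch requires PyTorch >= 1.8"

# Video decoding relies on ffmpeg being available on PATH
assert shutil.which("ffmpeg") is not None, "ffmpeg not found on PATH"
```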

Load the MARLIN model from online:
```python
from marlin_pytorch import Marlin
# Load MARLIN model from GitHub Release
model = Marlin.from_online("marlin_vit_base_ytf")
```

Load the MARLIN model from a local file:
```python
from marlin_pytorch import Marlin
# Load MARLIN model from a local file
model = Marlin.from_file("marlin_vit_base_ytf", "path/to/marlin.pt")
# Load MARLIN model from a ckpt file trained by the scripts in this repo
model = Marlin.from_file("marlin_vit_base_ytf", "path/to/marlin.ckpt")
```

Current model name list:
- `marlin_vit_small_ytf`: ViT-small encoder trained on the YTF dataset. Embedding 384 dim.
- `marlin_vit_base_ytf`: ViT-base encoder trained on the YTF dataset. Embedding 768 dim.
- `marlin_vit_large_ytf`: ViT-large encoder trained on the YTF dataset. Embedding 1024 dim.

For more details, see [MODEL_ZOO.md](MODEL_ZOO.md).
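
Each model name maps to a fixed embedding dimension, so a quick shape check can confirm the right variant is loaded. A small sketch (the expected value 384 comes from the list above; the video path is a placeholder):
```python
from marlin_pytorch import Marlin

# The small variant produces 384-dim features (see the list above)
model = Marlin.from_online("marlin_vit_small_ytf")
features = model.extract_video("path/to/video.mp4")
assert features.shape[-1] == 384
```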

When a MARLIN model is retrieved from GitHub Release, it is cached in `.marlin`. You can remove the cache by:
```python
from marlin_pytorch import Marlin
Marlin.clean_cache()
```

Extract features from a cropped video file:
```python
# Extract features from a face-cropped video with size (224x224)
features = model.extract_video("path/to/video.mp4")
print(features.shape)  # torch.Size([T, 768]) where T is the number of windows

# You can keep the output of all elements in the sequence by setting keep_seq=True
features = model.extract_video("path/to/video.mp4", keep_seq=True)
print(features.shape)  # torch.Size([T, k, 768]) where k = T/t * H/h * W/w = 8 * 14 * 14 = 1568
```

Extract features from an in-the-wild video file:
```python
# Extract features from an in-the-wild video of arbitrary size
features = model.extract_video("path/to/video.mp4", crop_face=True)
print(features.shape)  # torch.Size([T, 768])
```

Extract features from a video clip tensor:
```python
# Extract features from a clip tensor with size (B, 3, 16, 224, 224)
x = ...  # video clip
features = model.extract_features(x)  # torch.Size([B, k, 768])
features = model.extract_features(x, keep_seq=False)  # torch.Size([B, 768])
```
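
Putting it together, a minimal end-to-end sketch (the random tensor stands in for a real 16-frame clip and is for illustration only):
```python
import torch
from marlin_pytorch import Marlin

model = Marlin.from_online("marlin_vit_base_ytf")

# Dummy batch of 2 clips: 3 channels, 16 frames, 224x224 (illustrative input)
x = torch.rand(2, 3, 16, 224, 224)
features = model.extract_features(x, keep_seq=False)
print(features.shape)  # expected: torch.Size([2, 768])
```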

## Paper Implementation

### Requirements
- Python >= 3.7, < 3.11
- PyTorch ~= 1.11
- Torchvision ~= 0.12

### Installation

First, make sure you have installed PyTorch and Torchvision, with or without CUDA.

Clone the repo and install the requirements:
```bash
git clone https://github.com/ControlNet/MARLIN.git
cd MARLIN
pip install -r requirements.txt
```

### MARLIN Pretraining

Download the [YoutubeFaces](https://www.cs.tau.ac.il/~wolf/ytfaces/) dataset (only `frame_images_DB` is required).

Download the face parsing model from [face_parsing.farl.lapa](https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.lapa.main_ema_136500_jit191.pt)
and put it in `utils/face_sdk/models/face_parsing/face_parsing_1.0`.

Download the VideoMAE pretrained [checkpoint](https://github.com/ControlNet/MARLIN/releases/misc)
for initializing the weights. (Note: the VideoMAE authors updated their models in this
[commit](https://github.com/MCG-NJU/VideoMAE/commit/2b56a75d166c619f71019e3d1bb1c4aedafe7a90), but we use the
old models, which are no longer shared by the authors, so we uploaded this model ourselves.)

Then run the script to preprocess the dataset:
```bash
python preprocess/ytf_preprocess.py --data_dir /path/to/youtube_faces --max_workers 8
```

After processing, the directory structure should look like this:
```
├── YoutubeFaces
│   ├── frame_images_DB
│   │   ├── Aaron_Eckhart
│   │   │   ├── 0
│   │   │   │   ├── 0.555.jpg
│   │   │   │   ├── ...
│   │   │   ├── ...
│   │   ├── ...
│   ├── crop_images_DB
│   │   ├── Aaron_Eckhart
│   │   │   ├── 0
│   │   │   │   ├── 0.555.jpg
│   │   │   │   ├── ...
│   │   │   ├── ...
│   │   ├── ...
│   ├── face_parsing_images_DB
│   │   ├── Aaron_Eckhart
│   │   │   ├── 0
│   │   │   │   ├── 0.555.npy
│   │   │   │   ├── ...
│   │   │   ├── ...
│   │   ├── ...
│   ├── train_set.csv
│   ├── val_set.csv
```
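
To verify the preprocessing output programmatically, a small sketch (the folder and CSV names follow the tree above; this check is illustrative, not part of the repo):
```python
from pathlib import Path

# Confirm that the preprocessing outputs exist under the dataset root
root = Path("/path/to/youtube_faces")
expected = [
    "frame_images_DB",
    "crop_images_DB",
    "face_parsing_images_DB",
    "train_set.csv",
    "val_set.csv",
]
for name in expected:
    assert (root / name).exists(), f"missing {name}"
```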

Then, run the training script:
```bash
python train.py \
    --config config/pretrain/marlin_vit_base.yaml \
    --data_dir /path/to/youtube_faces \
    --n_gpus 4 \
    --num_workers 8 \
    --batch_size 16 \
    --epochs 2000 \
    --official_pretrained /path/to/videomae/checkpoint.pth
```

After training, you can load the checkpoint for inference by:

```python
from marlin_pytorch import Marlin
from marlin_pytorch.config import register_model_from_yaml

register_model_from_yaml("my_marlin_model", "path/to/config.yaml")
model = Marlin.from_file("my_marlin_model", "path/to/marlin.ckpt")
```
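
Continuing from the snippet above, the registered model supports the same extraction API as the built-in ones (assuming the config YAML matches the trained architecture):
```python
# `model` comes from Marlin.from_file above; the output dimension is set by the config
features = model.extract_video("path/to/video.mp4", crop_face=True)
```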

## References
If you find this work useful for your research, please consider citing it.
```bibtex
@inproceedings{cai2022marlin,
    title = {MARLIN: Masked Autoencoder for facial video Representation LearnINg},
    author = {Cai, Zhixi and Ghosh, Shreya and Stefanov, Kalin and Dhall, Abhinav and Cai, Jianfei and Rezatofighi, Hamid and Haffari, Reza and Hayat, Munawar},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    year = {2023},
    month = {June},
    pages = {1493-1504},
    doi = {10.1109/CVPR52729.2023.00150},
    publisher = {IEEE},
}
```

## License

This project is under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for details.

## Acknowledgements

Some of the model code is based on [MCG-NJU/VideoMAE](https://github.com/MCG-NJU/VideoMAE). The preprocessing code
is borrowed from [JDAI-CV/FaceX-Zoo](https://github.com/JDAI-CV/FaceX-Zoo).
