Skip to content

Commit 0db73ad

Browse files
committed
polish up autoencoder fvq with the best config found by @mahdip72
1 parent 721e1e4 commit 0db73ad

8 files changed

Lines changed: 513 additions & 396 deletions

File tree

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,3 +804,14 @@ assert loss.item() >= 0
804804
}
805805
```
806806

807+
```bibtex
@misc{chang2025scalabletrainingvectorquantizednetworks,
    title         = {Scalable Training for Vector-Quantized Networks with 100% Codebook Utilization},
    author        = {Yifan Chang and Jie Qin and Limeng Qiao and Xiaofeng Wang and Zheng Zhu and Lin Ma and Xingang Wang},
    year          = {2025},
    eprint        = {2509.10140},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CV},
    url           = {https://arxiv.org/abs/2509.10140},
}
```

autoencoder_fvq.py

Lines changed: 0 additions & 206 deletions
This file was deleted.

examples/autoencoder.py

Lines changed: 79 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,87 @@
1-
# FashionMnist VQ experiment with various settings.
2-
# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py
1+
#!/usr/bin/env uv run
2+
# /// script
3+
# dependencies = [
4+
# "torch",
5+
# "torchvision",
6+
# "tqdm",
7+
# "fire",
8+
# "einops",
9+
# "einx",
10+
# ]
11+
# ///
312

413
from tqdm.auto import trange
514

15+
import fire
616
import torch
717
import torch.nn as nn
818
from torchvision import datasets, transforms
919
from torch.utils.data import DataLoader
20+
from torch.optim import AdamW
1021

1122
from vector_quantize_pytorch import VectorQuantize, Sequential
1223

13-
lr = 3e-4
14-
train_iter = 1000
15-
num_codes = 256
16-
seed = 1234
17-
rotation_trick = True
18-
device = "cuda" if torch.cuda.is_available() else "cpu"
24+
# helpers
1925

20-
def SimpleVQAutoEncoder(**vq_kwargs):
def exists(val):
    """Return True when *val* carries a value, i.e. is not None."""
    return val is not None


def default(val, d):
    """Return *val* when it exists, otherwise fall back to *d*."""
    if exists(val):
        return val
    return d
# classes

def SimpleVQAutoEncoder(dim = 32, **vq_kwargs):
    """Build a small convolutional autoencoder over 1-channel images
    (Fashion-MNIST) with a vector-quantization bottleneck.

    Args:
        dim: channel width of the bottleneck and the VQ feature dimension.
             Fix: previously only ``VectorQuantize`` received ``dim`` while
             the surrounding convs were hard-coded to 32 channels, so any
             ``dim != 32`` produced a channel / feature-dim mismatch at the
             quantizer. The convs now use ``dim`` too; the default of 32
             reproduces the original architecture exactly.
        **vq_kwargs: forwarded to ``VectorQuantize`` (codebook_size,
             rotation_trick, straight_through, directional_reparam, ...).
    """
    return Sequential(
        nn.Conv2d(1, 16, kernel_size = 3, stride = 1, padding = 1),
        nn.MaxPool2d(kernel_size = 2, stride = 2),
        nn.GELU(),
        nn.Conv2d(16, dim, kernel_size = 3, stride = 1, padding = 1),
        nn.MaxPool2d(kernel_size = 2, stride = 2),
        # quantize on the (dim)-channel feature map directly
        VectorQuantize(dim = dim, accept_image_fmap = True, **vq_kwargs),
        nn.Upsample(scale_factor = 2, mode = "nearest"),
        nn.Conv2d(dim, 16, kernel_size = 3, stride = 1, padding = 1),
        nn.GELU(),
        nn.Upsample(scale_factor = 2, mode = "nearest"),
        nn.Conv2d(16, 1, kernel_size = 3, stride = 1, padding = 1),
    )
49+
def train(
50+
train_iter = 1000,
51+
lr = 3e-4,
52+
dim = 32,
53+
num_codes = 256,
54+
seed = 1234,
55+
rotation_trick = True,
56+
straight_through = False,
57+
directional_reparam = False,
58+
alpha = 10,
59+
batch_size = 256
60+
):
61+
torch.random.manual_seed(seed)
62+
device = "cuda" if torch.cuda.is_available() else "cpu"
63+
64+
model = SimpleVQAutoEncoder(
65+
dim = dim,
66+
codebook_size = num_codes,
67+
rotation_trick = rotation_trick,
68+
straight_through = straight_through,
69+
directional_reparam = directional_reparam
70+
).to(device)
71+
72+
opt = AdamW(model.parameters(), lr = lr)
73+
74+
transform = transforms.Compose([
75+
transforms.ToTensor(),
76+
transforms.Normalize((0.5,), (0.5,))
77+
])
78+
79+
train_dataset = DataLoader(
80+
datasets.FashionMNIST(root = "~/data/fashion_mnist", train = True, download = True, transform = transform),
81+
batch_size = batch_size,
82+
shuffle = True,
3383
)
3484

35-
def train(model, train_loader, train_iterations=1000, alpha=10):
3685
def iterate_dataset(data_loader):
3786
data_iter = iter(data_loader)
3887
while True:
@@ -43,9 +92,13 @@ def iterate_dataset(data_loader):
4392
x, y = next(data_iter)
4493
yield x.to(device), y.to(device)
4594

46-
for _ in (pbar := trange(train_iterations)):
95+
dl_iter = iterate_dataset(train_dataset)
96+
97+
pbar = trange(train_iter)
98+
99+
for _ in pbar:
47100
opt.zero_grad()
48-
x, _ = next(iterate_dataset(train_loader))
101+
x, _ = next(dl_iter)
49102

50103
out, indices, cmt_loss = model(x)
51104
out = out.clamp(-1., 1.)
@@ -54,31 +107,12 @@ def iterate_dataset(data_loader):
54107
(rec_loss + alpha * cmt_loss).backward()
55108

56109
opt.step()
110+
57111
pbar.set_description(
58112
f"rec loss: {rec_loss.item():.3f} | "
59-
+ f"cmt loss: {cmt_loss.item():.3f} | "
60-
+ f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
113+
f"cmt loss: {cmt_loss.item():.3f} | "
114+
f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
61115
)
62116

63-
transform = transforms.Compose(
64-
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
65-
)
66-
train_dataset = DataLoader(
67-
datasets.FashionMNIST(
68-
root="~/data/fashion_mnist", train=True, download=True, transform=transform
69-
),
70-
batch_size=256,
71-
shuffle=True,
72-
)
73-
74-
torch.random.manual_seed(seed)
75-
76-
model = SimpleVQAutoEncoder(
77-
codebook_size = num_codes,
78-
rotation_trick = False,
79-
straight_through = False,
80-
directional_reparam = True
81-
).to(device)
82-
83-
opt = torch.optim.AdamW(model.parameters(), lr=lr)
84-
train(model, train_dataset, train_iterations=train_iter)
if __name__ == "__main__":
    # Hand train() to python-fire so each of its keyword arguments becomes
    # a CLI flag (e.g. --train-iter, --lr, --rotation-trick).
    fire.Fire(train)

0 commit comments

Comments
 (0)