Skip to content

Commit 20bfdf1

Browse files
authored
Merge pull request #242 from lucidrains/test-fvq
Test fvq
2 parents 7a4e13c + 8c3360e commit 20bfdf1

10 files changed

Lines changed: 601 additions & 232 deletions

File tree

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,3 +804,14 @@ assert loss.item() >= 0
804804
}
805805
```
806806

807+
```bibtex
808+
@misc{chang2025scalabletrainingvectorquantizednetworks,
809+
title = {Scalable Training for Vector-Quantized Networks with 100% Codebook Utilization},
810+
author = {Yifan Chang and Jie Qin and Limeng Qiao and Xiaofeng Wang and Zheng Zhu and Lin Ma and Xingang Wang},
811+
year = {2025},
812+
eprint = {2509.10140},
813+
archivePrefix = {arXiv},
814+
primaryClass = {cs.CV},
815+
url = {https://arxiv.org/abs/2509.10140},
816+
}
817+
```

examples/autoencoder.py

Lines changed: 79 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,87 @@
1-
# FashionMnist VQ experiment with various settings.
2-
# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py
1+
#!/usr/bin/env uv run
2+
# /// script
3+
# dependencies = [
4+
# "torch",
5+
# "torchvision",
6+
# "tqdm",
7+
# "fire",
8+
# "einops",
9+
# "einx",
10+
# ]
11+
# ///
312

413
from tqdm.auto import trange
514

15+
import fire
616
import torch
717
import torch.nn as nn
818
from torchvision import datasets, transforms
919
from torch.utils.data import DataLoader
20+
from torch.optim import AdamW
1021

1122
from vector_quantize_pytorch import VectorQuantize, Sequential
1223

13-
lr = 3e-4
14-
train_iter = 1000
15-
num_codes = 256
16-
seed = 1234
17-
rotation_trick = True
18-
device = "cuda" if torch.cuda.is_available() else "cpu"
24+
# helpers
1925

20-
def SimpleVQAutoEncoder(**vq_kwargs):
26+
def exists(val):
    # A value "exists" when it was explicitly provided, i.e. it is not None.
    return not (val is None)
28+
29+
def default(val, d):
    # Return `val` when it was provided, otherwise fall back to `d`.
    if val is None:
        return d
    return val
31+
32+
# classes
33+
34+
def SimpleVQAutoEncoder(dim = 32, **vq_kwargs):
    """Small convolutional autoencoder for FashionMNIST with a VQ bottleneck.

    Args:
        dim: channel width of the latent feature map fed to the quantizer.
             Fix over the previous version: `dim` now also sets the conv
             channel widths around the bottleneck — before, they were
             hard-coded to 32, so any `dim != 32` raised a channel-mismatch
             error inside VectorQuantize. Default 32 reproduces the old
             architecture exactly.
        **vq_kwargs: forwarded to VectorQuantize (e.g. codebook_size,
             rotation_trick, straight_through).

    Returns:
        A Sequential model mapping a (B, 1, 28, 28) image to a reconstruction
        of the same shape, quantizing the 7x7 latent feature map.
    """
    return Sequential(
        nn.Conv2d(1, 16, kernel_size = 3, stride = 1, padding = 1),
        nn.MaxPool2d(kernel_size = 2, stride = 2),
        nn.GELU(),
        # encode down to a `dim`-channel feature map for the quantizer
        nn.Conv2d(16, dim, kernel_size = 3, stride = 1, padding = 1),
        nn.MaxPool2d(kernel_size = 2, stride = 2),
        VectorQuantize(dim = dim, accept_image_fmap = True, **vq_kwargs),
        nn.Upsample(scale_factor = 2, mode = "nearest"),
        nn.Conv2d(dim, 16, kernel_size = 3, stride = 1, padding = 1),
        nn.GELU(),
        nn.Upsample(scale_factor = 2, mode = "nearest"),
        nn.Conv2d(16, 1, kernel_size = 3, stride = 1, padding = 1),
    )
48+
49+
def train(
    train_iter = 1000,
    lr = 3e-4,
    dim = 32,
    num_codes = 256,
    seed = 1234,
    rotation_trick = True,
    straight_through = False,
    directional_reparam = False,
    alpha = 10,
    batch_size = 256
):
    """Train a SimpleVQAutoEncoder on FashionMNIST for `train_iter` steps.

    Reports reconstruction loss, commitment loss, and the percentage of
    active codebook entries per batch. `rotation_trick`, `straight_through`
    and `directional_reparam` are forwarded to VectorQuantize — presumably
    alternative gradient estimators for the quantizer (see library docs).
    """
    torch.random.manual_seed(seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = SimpleVQAutoEncoder(
        dim = dim,
        codebook_size = num_codes,
        rotation_trick = rotation_trick,
        straight_through = straight_through,
        directional_reparam = directional_reparam
    ).to(device)

    opt = AdamW(model.parameters(), lr = lr)

    # images are scaled to [-1, 1] to match the decoder's clamped output range
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    loader = DataLoader(
        datasets.FashionMNIST(root = "~/data/fashion_mnist", train = True, download = True, transform = transform),
        batch_size = batch_size,
        shuffle = True,
    )

    def endless_batches(data_loader):
        # cycle through the dataloader forever, moving each batch to device
        while True:
            for images, labels in data_loader:
                yield images.to(device), labels.to(device)

    batches = endless_batches(loader)
    pbar = trange(train_iter)

    for _ in pbar:
        opt.zero_grad()
        x, _ = next(batches)

        out, indices, cmt_loss = model(x)
        out = out.clamp(-1., 1.)

        rec_loss = (out - x).abs().mean()
        # commitment loss is weighted by alpha relative to reconstruction
        (rec_loss + alpha * cmt_loss).backward()

        opt.step()

        pbar.set_description(
            f"rec loss: {rec_loss.item():.3f} | "
            f"cmt loss: {cmt_loss.item():.3f} | "
            f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
        )

if __name__ == "__main__":
    fire.Fire(train)

examples/autoencoder_fsq.py

Lines changed: 67 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,81 @@
1-
# FashionMnist VQ experiment with various settings, using FSQ.
2-
# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py
1+
#!/usr/bin/env uv run
2+
# /// script
3+
# dependencies = [
4+
# "torch",
5+
# "torchvision",
6+
# "tqdm",
7+
# "fire",
8+
# "einops",
9+
# "einx",
10+
# ]
11+
# ///
312

413
from tqdm.auto import trange
514

615
import math
16+
import fire
717
import torch
818
import torch.nn as nn
919
from torchvision import datasets, transforms
1020
from torch.utils.data import DataLoader
21+
from torch.optim import AdamW
1122

1223
from vector_quantize_pytorch import FSQ, Sequential
1324

25+
# helpers
1426

15-
lr = 3e-4
16-
train_iter = 1000
17-
levels = [8, 6, 5] # target size 2^8, actual size 240
18-
num_codes = math.prod(levels)
19-
seed = 1234
20-
device = "cuda" if torch.cuda.is_available() else "cpu"
27+
def exists(val):
    # A provided (non-None) value "exists".
    if val is None:
        return False
    return True
2129

30+
def default(val, d):
    # Use the provided value when present, else the fallback `d`.
    return d if val is None else val
32+
33+
# classes
2234

2335
def SimpleFSQAutoEncoder(levels: list[int]):
    """Convolutional FashionMNIST autoencoder with an FSQ bottleneck.

    The latent feature map has one channel per quantization level in
    `levels`; FSQ quantizes each channel independently.
    """
    quant_dim = len(levels)

    encoder = (
        nn.Conv2d(1, 16, kernel_size = 3, stride = 1, padding = 1),
        nn.MaxPool2d(kernel_size = 2, stride = 2),
        nn.GELU(),
        nn.Conv2d(16, 32, kernel_size = 3, stride = 1, padding = 1),
        nn.MaxPool2d(kernel_size = 2, stride = 2),
        # project down to one channel per FSQ level
        nn.Conv2d(32, quant_dim, kernel_size = 1),
    )

    decoder = (
        nn.Conv2d(quant_dim, 32, kernel_size = 3, stride = 1, padding = 1),
        nn.Upsample(scale_factor = 2, mode = "nearest"),
        nn.Conv2d(32, 16, kernel_size = 3, stride = 1, padding = 1),
        nn.GELU(),
        nn.Upsample(scale_factor = 2, mode = "nearest"),
        nn.Conv2d(16, 1, kernel_size = 3, stride = 1, padding = 1),
    )

    return Sequential(*encoder, FSQ(levels), *decoder)
3951

52+
def train(
    train_iter = 1000,
    lr = 3e-4,
    levels = None,
    seed = 1234,
    batch_size = 256
):
    """Train a SimpleFSQAutoEncoder on FashionMNIST for `train_iter` steps.

    Args:
        train_iter: number of optimization steps.
        lr: AdamW learning rate.
        levels: FSQ quantization levels per latent channel. Defaults to
            [8, 6, 5] (target codebook size 2^8, actual size 240). Fix: the
            previous signature used the mutable default `levels = [8, 6, 5]`,
            which is shared across calls — a caller mutating it would poison
            every later default invocation. Resolved via the module's
            `default` helper (previously unused).
        seed: RNG seed for reproducibility.
        batch_size: dataloader batch size.
    """
    levels = default(levels, [8, 6, 5])

    torch.random.manual_seed(seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # total codebook size is the product of the per-channel levels
    num_codes = math.prod(levels)

    model = SimpleFSQAutoEncoder(levels).to(device)

    opt = AdamW(model.parameters(), lr = lr)

    # scale images to [-1, 1] to match the clamped decoder output
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    train_dataset = DataLoader(
        datasets.FashionMNIST(root = "~/data/fashion_mnist", train = True, download = True, transform = transform),
        batch_size = batch_size,
        shuffle = True,
    )

    def iterate_dataset(data_loader):
        # cycle through the dataloader indefinitely, batches moved to device
        while True:
            for x, y in data_loader:
                yield x.to(device), y.to(device)

    dl_iter = iterate_dataset(train_dataset)

    pbar = trange(train_iter)

    for _ in pbar:
        opt.zero_grad()
        x, _ = next(dl_iter)

        out, indices = model(x)
        out = out.clamp(-1., 1.)

        rec_loss = (out - x).abs().mean()
        rec_loss.backward()

        opt.step()

        pbar.set_description(
            f"rec loss: {rec_loss.item():.3f} | "
            f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
        )

if __name__ == "__main__":
    fire.Fire(train)

0 commit comments

Comments
 (0)