Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -804,3 +804,14 @@ assert loss.item() >= 0
}
```

```bibtex
@misc{chang2025scalabletrainingvectorquantizednetworks,
title = {Scalable Training for Vector-Quantized Networks with 100% Codebook Utilization},
author = {Yifan Chang and Jie Qin and Limeng Qiao and Xiaofeng Wang and Zheng Zhu and Lin Ma and Xingang Wang},
year = {2025},
eprint = {2509.10140},
archivePrefix = {arXiv},
primaryClass = {cs.CV},
url = {https://arxiv.org/abs/2509.10140},
}
```
124 changes: 79 additions & 45 deletions examples/autoencoder.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,87 @@
# FashionMnist VQ experiment with various settings.
# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py
#!/usr/bin/env uv run
# /// script
# dependencies = [
# "torch",
# "torchvision",
# "tqdm",
# "fire",
# "einops",
# "einx",
# ]
# ///

from tqdm.auto import trange

import fire
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.optim import AdamW

from vector_quantize_pytorch import VectorQuantize, Sequential

lr = 3e-4
train_iter = 1000
num_codes = 256
seed = 1234
rotation_trick = True
device = "cuda" if torch.cuda.is_available() else "cpu"
# helpers

def SimpleVQAutoEncoder(**vq_kwargs):
def exists(val):
    """Return True when *val* is a concrete value (i.e. not None)."""
    return val is not None

def default(val, d):
    """Return *val* if it is set, otherwise fall back to *d*."""
    return d if val is None else val

# classes

def SimpleVQAutoEncoder(dim = 32, **vq_kwargs):
    """Build a small conv autoencoder for 28x28 single-channel images
    (FashionMNIST) with a vector-quantization bottleneck.

    Encoder downsamples 28 -> 14 -> 7, quantizes the 7x7 feature map,
    then the decoder upsamples 7 -> 14 -> 28 back to one channel.

    Args:
        dim: channel width of the quantized feature map (must match the
            encoder's final conv output channels, hence the 32 -> dim conv).
        **vq_kwargs: forwarded to `VectorQuantize` (codebook_size, rotation
            tricks, etc.).

    Returns:
        A `Sequential` module; its forward returns whatever `VectorQuantize`
        propagates (reconstruction, indices, commit loss).
    """
    # NOTE: the source diff interleaved old and new layer stacks; this is the
    # reconstructed post-merge version (every layer appears exactly once).
    return Sequential(
        nn.Conv2d(1, 16, kernel_size = 3, stride = 1, padding = 1),
        nn.MaxPool2d(kernel_size = 2, stride = 2),
        nn.GELU(),
        nn.Conv2d(16, 32, kernel_size = 3, stride = 1, padding = 1),
        nn.MaxPool2d(kernel_size = 2, stride = 2),
        # bottleneck: quantize the (b, dim, 7, 7) feature map
        VectorQuantize(dim = dim, accept_image_fmap = True, **vq_kwargs),
        nn.Upsample(scale_factor = 2, mode = "nearest"),
        nn.Conv2d(32, 16, kernel_size = 3, stride = 1, padding = 1),
        nn.GELU(),
        nn.Upsample(scale_factor = 2, mode = "nearest"),
        nn.Conv2d(16, 1, kernel_size = 3, stride = 1, padding = 1),
    )

def train(
train_iter = 1000,
lr = 3e-4,
dim = 32,
num_codes = 256,
seed = 1234,
rotation_trick = True,
straight_through = False,
directional_reparam = False,
alpha = 10,
batch_size = 256
):
torch.random.manual_seed(seed)
device = "cuda" if torch.cuda.is_available() else "cpu"

model = SimpleVQAutoEncoder(
dim = dim,
codebook_size = num_codes,
rotation_trick = rotation_trick,
straight_through = straight_through,
directional_reparam = directional_reparam
).to(device)

opt = AdamW(model.parameters(), lr = lr)

transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])

train_dataset = DataLoader(
datasets.FashionMNIST(root = "~/data/fashion_mnist", train = True, download = True, transform = transform),
batch_size = batch_size,
shuffle = True,
)

def train(model, train_loader, train_iterations=1000, alpha=10):
def iterate_dataset(data_loader):
data_iter = iter(data_loader)
while True:
Expand All @@ -43,9 +92,13 @@ def iterate_dataset(data_loader):
x, y = next(data_iter)
yield x.to(device), y.to(device)

for _ in (pbar := trange(train_iterations)):
dl_iter = iterate_dataset(train_dataset)

pbar = trange(train_iter)

for _ in pbar:
opt.zero_grad()
x, _ = next(iterate_dataset(train_loader))
x, _ = next(dl_iter)

out, indices, cmt_loss = model(x)
out = out.clamp(-1., 1.)
Expand All @@ -54,31 +107,12 @@ def iterate_dataset(data_loader):
(rec_loss + alpha * cmt_loss).backward()

opt.step()

pbar.set_description(
f"rec loss: {rec_loss.item():.3f} | "
+ f"cmt loss: {cmt_loss.item():.3f} | "
+ f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
f"cmt loss: {cmt_loss.item():.3f} | "
f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
)

transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
train_dataset = DataLoader(
datasets.FashionMNIST(
root="~/data/fashion_mnist", train=True, download=True, transform=transform
),
batch_size=256,
shuffle=True,
)

torch.random.manual_seed(seed)

model = SimpleVQAutoEncoder(
codebook_size = num_codes,
rotation_trick = False,
straight_through = False,
directional_reparam = True
).to(device)

opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)
if __name__ == "__main__":
fire.Fire(train)
105 changes: 67 additions & 38 deletions examples/autoencoder_fsq.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,81 @@
# FashionMnist VQ experiment with various settings, using FSQ.
# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py
#!/usr/bin/env uv run
# /// script
# dependencies = [
# "torch",
# "torchvision",
# "tqdm",
# "fire",
# "einops",
# "einx",
# ]
# ///

from tqdm.auto import trange

import math
import fire
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.optim import AdamW

from vector_quantize_pytorch import FSQ, Sequential

# helpers

lr = 3e-4
train_iter = 1000
levels = [8, 6, 5] # target size 2^8, actual size 240
num_codes = math.prod(levels)
seed = 1234
device = "cuda" if torch.cuda.is_available() else "cpu"
def exists(val):
    """Return True when *val* is a concrete value (i.e. not None)."""
    return val is not None

def default(val, d):
    """Return *val* if it is set, otherwise fall back to *d*."""
    return d if val is None else val

# classes

def SimpleFSQAutoEncoder(levels: list[int]):
    """Build a small conv autoencoder for 28x28 single-channel images
    (FashionMNIST) with a finite-scalar-quantization (FSQ) bottleneck.

    Encoder downsamples 28 -> 14 -> 7, projects to `len(levels)` channels
    (one scalar per quantization level spec), applies FSQ, then projects
    back up and decodes 7 -> 14 -> 28.

    Args:
        levels: per-dimension quantization levels, e.g. [8, 6, 5];
            the implied codebook size is math.prod(levels).

    Returns:
        A `Sequential` module; its forward returns whatever `FSQ`
        propagates (reconstruction and code indices).
    """
    # NOTE: the source diff interleaved old and new layer stacks; this is the
    # reconstructed post-merge version (every layer appears exactly once).
    return Sequential(
        nn.Conv2d(1, 16, kernel_size = 3, stride = 1, padding = 1),
        nn.MaxPool2d(kernel_size = 2, stride = 2),
        nn.GELU(),
        nn.Conv2d(16, 32, kernel_size = 3, stride = 1, padding = 1),
        nn.MaxPool2d(kernel_size = 2, stride = 2),
        # 1x1 conv down to len(levels) scalar channels for FSQ
        nn.Conv2d(32, len(levels), kernel_size = 1),
        FSQ(levels),
        nn.Conv2d(len(levels), 32, kernel_size = 3, stride = 1, padding = 1),
        nn.Upsample(scale_factor = 2, mode = "nearest"),
        nn.Conv2d(32, 16, kernel_size = 3, stride = 1, padding = 1),
        nn.GELU(),
        nn.Upsample(scale_factor = 2, mode = "nearest"),
        nn.Conv2d(16, 1, kernel_size = 3, stride = 1, padding = 1),
    )

def train(
train_iter = 1000,
lr = 3e-4,
levels = [8, 6, 5],
seed = 1234,
batch_size = 256
):
torch.random.manual_seed(seed)
device = "cuda" if torch.cuda.is_available() else "cpu"

num_codes = math.prod(levels)

model = SimpleFSQAutoEncoder(levels).to(device)

opt = AdamW(model.parameters(), lr = lr)

transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])

train_dataset = DataLoader(
datasets.FashionMNIST(root = "~/data/fashion_mnist", train = True, download = True, transform = transform),
batch_size = batch_size,
shuffle = True,
)

def train(model, train_loader, train_iterations=1000):
def iterate_dataset(data_loader):
data_iter = iter(data_loader)
while True:
Expand All @@ -49,34 +86,26 @@ def iterate_dataset(data_loader):
x, y = next(data_iter)
yield x.to(device), y.to(device)

for _ in (pbar := trange(train_iterations)):
dl_iter = iterate_dataset(train_dataset)

pbar = trange(train_iter)

for _ in pbar:
opt.zero_grad()
x, _ = next(iterate_dataset(train_loader))
x, _ = next(dl_iter)

out, indices = model(x)
out = out.clamp(-1., 1.)

rec_loss = (out - x).abs().mean()
rec_loss.backward()

opt.step()

pbar.set_description(
f"rec loss: {rec_loss.item():.3f} | "
+ f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
)


transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
train_dataset = DataLoader(
datasets.FashionMNIST(
root="~/data/fashion_mnist", train=True, download=True, transform=transform
),
batch_size=256,
shuffle=True,
)

torch.random.manual_seed(seed)
model = SimpleFSQAutoEncoder(levels).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)
if __name__ == "__main__":
fire.Fire(train)
Loading
Loading