diff --git a/README.md b/README.md index d916d3e..a131bef 100644 --- a/README.md +++ b/README.md @@ -804,3 +804,14 @@ assert loss.item() >= 0 } ``` +```bibtex +@misc{chang2025scalabletrainingvectorquantizednetworks, + title = {Scalable Training for Vector-Quantized Networks with 100% Codebook Utilization}, + author = {Yifan Chang and Jie Qin and Limeng Qiao and Xiaofeng Wang and Zheng Zhu and Lin Ma and Xingang Wang}, + year = {2025}, + eprint = {2509.10140}, + archivePrefix = {arXiv}, + primaryClass = {cs.CV}, + url = {https://arxiv.org/abs/2509.10140}, +} +``` diff --git a/examples/autoencoder.py b/examples/autoencoder.py index 09b8332..3e486e3 100644 --- a/examples/autoencoder.py +++ b/examples/autoencoder.py @@ -1,38 +1,87 @@ -# FashionMnist VQ experiment with various settings. -# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py +#!/usr/bin/env uv run +# /// script +# dependencies = [ +# "torch", +# "torchvision", +# "tqdm", +# "fire", +# "einops", +# "einx", +# ] +# /// from tqdm.auto import trange +import fire import torch import torch.nn as nn from torchvision import datasets, transforms from torch.utils.data import DataLoader +from torch.optim import AdamW from vector_quantize_pytorch import VectorQuantize, Sequential -lr = 3e-4 -train_iter = 1000 -num_codes = 256 -seed = 1234 -rotation_trick = True -device = "cuda" if torch.cuda.is_available() else "cpu" +# helpers -def SimpleVQAutoEncoder(**vq_kwargs): +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +# classes + +def SimpleVQAutoEncoder(dim = 32, **vq_kwargs): return Sequential( - nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(1, 16, kernel_size = 3, stride = 1, padding = 1), + nn.MaxPool2d(kernel_size = 2, stride = 2), nn.GELU(), - nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), - VectorQuantize(dim=32, accept_image_fmap = True, **vq_kwargs), - nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1), + nn.Conv2d(16, 32, kernel_size = 3, stride = 1, padding = 1), + nn.MaxPool2d(kernel_size = 2, stride = 2), + VectorQuantize(dim = dim, accept_image_fmap = True, **vq_kwargs), + nn.Upsample(scale_factor = 2, mode = "nearest"), + nn.Conv2d(32, 16, kernel_size = 3, stride = 1, padding = 1), nn.GELU(), - nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1), + nn.Upsample(scale_factor = 2, mode = "nearest"), + nn.Conv2d(16, 1, kernel_size = 3, stride = 1, padding = 1), + ) + +def train( + train_iter = 1000, + lr = 3e-4, + dim = 32, + num_codes = 256, + seed = 1234, + rotation_trick = True, + straight_through = False, + directional_reparam = False, + alpha = 10, + batch_size = 256 +): + torch.random.manual_seed(seed) + device = "cuda" if torch.cuda.is_available() else "cpu" + + model = SimpleVQAutoEncoder( + dim = dim, + codebook_size = num_codes, + rotation_trick = rotation_trick, + straight_through = straight_through, + directional_reparam = directional_reparam + ).to(device) + + opt = AdamW(model.parameters(), lr = lr) + + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) + ]) + + train_dataset = DataLoader( + datasets.FashionMNIST(root = "~/data/fashion_mnist", train = True, download = True, transform = transform), + batch_size = batch_size, + shuffle = True, ) -def train(model, train_loader, 
train_iterations=1000, alpha=10): def iterate_dataset(data_loader): data_iter = iter(data_loader) while True: @@ -43,9 +92,13 @@ def iterate_dataset(data_loader): x, y = next(data_iter) yield x.to(device), y.to(device) - for _ in (pbar := trange(train_iterations)): + dl_iter = iterate_dataset(train_dataset) + + pbar = trange(train_iter) + + for _ in pbar: opt.zero_grad() - x, _ = next(iterate_dataset(train_loader)) + x, _ = next(dl_iter) out, indices, cmt_loss = model(x) out = out.clamp(-1., 1.) @@ -54,31 +107,12 @@ def iterate_dataset(data_loader): (rec_loss + alpha * cmt_loss).backward() opt.step() + pbar.set_description( f"rec loss: {rec_loss.item():.3f} | " - + f"cmt loss: {cmt_loss.item():.3f} | " - + f"active %: {indices.unique().numel() / num_codes * 100:.3f}" + f"cmt loss: {cmt_loss.item():.3f} | " + f"active %: {indices.unique().numel() / num_codes * 100:.3f}" ) -transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] -) -train_dataset = DataLoader( - datasets.FashionMNIST( - root="~/data/fashion_mnist", train=True, download=True, transform=transform - ), - batch_size=256, - shuffle=True, -) - -torch.random.manual_seed(seed) - -model = SimpleVQAutoEncoder( - codebook_size = num_codes, - rotation_trick = False, - straight_through = False, - directional_reparam = True -).to(device) - -opt = torch.optim.AdamW(model.parameters(), lr=lr) -train(model, train_dataset, train_iterations=train_iter) +if __name__ == "__main__": + fire.Fire(train) diff --git a/examples/autoencoder_fsq.py b/examples/autoencoder_fsq.py index 56e3c6e..e418814 100644 --- a/examples/autoencoder_fsq.py +++ b/examples/autoencoder_fsq.py @@ -1,44 +1,81 @@ -# FashionMnist VQ experiment with various settings, using FSQ. -# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py +#!/usr/bin/env uv run +# /// script +# dependencies = [ +# "torch", +# "torchvision", +# "tqdm", +# "fire", +# "einops", +# "einx", +# ] +# /// from tqdm.auto import trange import math +import fire import torch import torch.nn as nn from torchvision import datasets, transforms from torch.utils.data import DataLoader +from torch.optim import AdamW from vector_quantize_pytorch import FSQ, Sequential +# helpers -lr = 3e-4 -train_iter = 1000 -levels = [8, 6, 5] # target size 2^8, actual size 240 -num_codes = math.prod(levels) -seed = 1234 -device = "cuda" if torch.cuda.is_available() else "cpu" +def exists(val): + return val is not None +def default(val, d): + return val if exists(val) else d + +# classes def SimpleFSQAutoEncoder(levels: list[int]): return Sequential( - nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(1, 16, kernel_size = 3, stride = 1, padding = 1), + nn.MaxPool2d(kernel_size = 2, stride = 2), nn.GELU(), - nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(32, len(levels), kernel_size=1), + nn.Conv2d(16, 32, kernel_size = 3, stride = 1, padding = 1), + nn.MaxPool2d(kernel_size = 2, stride = 2), + nn.Conv2d(32, len(levels), kernel_size = 1), FSQ(levels), - nn.Conv2d(len(levels), 32, kernel_size=3, stride=1, padding=1), - nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1), + nn.Conv2d(len(levels), 32, kernel_size = 3, stride = 1, padding = 1), + nn.Upsample(scale_factor = 2, mode = "nearest"), + nn.Conv2d(32, 16, kernel_size = 3, stride = 1, padding = 1), nn.GELU(), - nn.Upsample(scale_factor=2, mode="nearest"), 
- nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1), + nn.Upsample(scale_factor = 2, mode = "nearest"), + nn.Conv2d(16, 1, kernel_size = 3, stride = 1, padding = 1), ) +def train( + train_iter = 1000, + lr = 3e-4, + levels = [8, 6, 5], + seed = 1234, + batch_size = 256 +): + torch.random.manual_seed(seed) + device = "cuda" if torch.cuda.is_available() else "cpu" + + num_codes = math.prod(levels) + + model = SimpleFSQAutoEncoder(levels).to(device) + + opt = AdamW(model.parameters(), lr = lr) + + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) + ]) + + train_dataset = DataLoader( + datasets.FashionMNIST(root = "~/data/fashion_mnist", train = True, download = True, transform = transform), + batch_size = batch_size, + shuffle = True, + ) -def train(model, train_loader, train_iterations=1000): def iterate_dataset(data_loader): data_iter = iter(data_loader) while True: @@ -49,9 +86,14 @@ def iterate_dataset(data_loader): x, y = next(data_iter) yield x.to(device), y.to(device) - for _ in (pbar := trange(train_iterations)): + dl_iter = iterate_dataset(train_dataset) + + pbar = trange(train_iter) + + for _ in pbar: opt.zero_grad() - x, _ = next(iterate_dataset(train_loader)) + x, _ = next(dl_iter) + out, indices = model(x) out = out.clamp(-1., 1.) @@ -59,24 +101,11 @@ def iterate_dataset(data_loader): rec_loss.backward() opt.step() + pbar.set_description( f"rec loss: {rec_loss.item():.3f} | " - + f"active %: {indices.unique().numel() / num_codes * 100:.3f}" + f"active %: {indices.unique().numel() / num_codes * 100:.3f}" ) - -transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] -) -train_dataset = DataLoader( - datasets.FashionMNIST( - root="~/data/fashion_mnist", train=True, download=True, transform=transform - ), - batch_size=256, - shuffle=True, -) - -torch.random.manual_seed(seed) -model = SimpleFSQAutoEncoder(levels).to(device) -opt = torch.optim.AdamW(model.parameters(), lr=lr) -train(model, train_dataset, train_iterations=train_iter) +if __name__ == "__main__": + fire.Fire(train) diff --git a/examples/autoencoder_fvq.py b/examples/autoencoder_fvq.py new file mode 100644 index 0000000..c0c4f7c --- /dev/null +++ b/examples/autoencoder_fvq.py @@ -0,0 +1,197 @@ +#!/usr/bin/env uv run +# /// script +# dependencies = [ +# "torch", +# "torchvision", +# "tqdm", +# "fire", +# "einops", +# "einx", +# "x-transformers==2.16.1", +# ] +# /// + +from tqdm.auto import trange + +import fire + +import torch +import torch.nn as nn +from torch.nn import Module +from torchvision import datasets, transforms +from torch.utils.data import DataLoader +from torch.optim import SGD, AdamW + +from einops import rearrange +from einops.layers.torch import Rearrange + +from x_transformers import ContinuousTransformerWrapper, Encoder +from vector_quantize_pytorch import VectorQuantize, Sequential + +# helpers + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +# classes + +def VQBridgeViT( + dim, + depth, + input_dim = None, + patch_size = 1, + dim_head = 64, + heads = 4, + num_registers = 0 +): + # Credit goes to Mahdi (@mahdip72) for his experiments that found the best + # set of hyperparameters for the ViT used in FVQ, which is patch_size 1 (becomes a regular transformer encoder) and critically, having register tokens (we will do 2 here) + # see experiments at https://github.com/lucidrains/vector-quantize-pytorch/issues/239#issuecomment-3888240360 + + input_dim = 
default(input_dim, dim)
+
+    project_in_out_kwargs = dict()
+
+    inner_dim = input_dim * patch_size
+
+    if patch_size > 1 or inner_dim != dim:
+        project_in_out_kwargs.update(
+            project_in = nn.Sequential(
+                Rearrange('b (n p) c -> b n (p c)', p = patch_size),
+                nn.Linear(inner_dim, dim, bias = False)
+            ),
+            project_out = nn.Sequential(
+                nn.Linear(dim, inner_dim, bias = False),
+                Rearrange('b n (p c) -> b (n p) c', p = patch_size)
+            )
+        )
+
+    return ContinuousTransformerWrapper(
+        num_memory_tokens = num_registers,
+        attn_layers = Encoder(
+            dim = dim,
+            depth = depth,
+            heads = heads,
+            attn_dim_head = dim_head,
+            pre_norm_has_final_norm = False
+        ),
+        **project_in_out_kwargs
+    )
+
+def SimpleVQAutoEncoder(
+    dim = 32,
+    vq_bridge: Module | None = None,
+    rotation_trick = True,
+    **vq_kwargs
+):
+    return Sequential(
+        nn.Conv2d(1, 16, kernel_size = 3, stride = 1, padding = 1),
+        nn.MaxPool2d(kernel_size = 2, stride = 2),
+        nn.GELU(),
+        nn.Conv2d(16, 32, kernel_size = 3, stride = 1, padding = 1),
+        nn.MaxPool2d(kernel_size = 2, stride = 2),
+        VectorQuantize(
+            dim = dim,
+            accept_image_fmap = True,
+            rotation_trick = rotation_trick,
+            vq_bridge = vq_bridge,
+            **vq_kwargs
+        ),
+        nn.Upsample(scale_factor = 2, mode = "nearest"),
+        nn.Conv2d(32, 16, kernel_size = 3, stride = 1, padding = 1),
+        nn.GELU(),
+        nn.Upsample(scale_factor = 2, mode = "nearest"),
+        nn.Conv2d(16, 1, kernel_size = 3, stride = 1, padding = 1),
+    )
+
+def train(
+    train_iter = 1000,
+    lr = 3e-4,
+    dim = 32,
+    num_codes = 256,
+    seed = 1234,
+    patch_size = 1,
+    no_bridge = False,
+    rotation_trick = False,
+    num_registers = 2,
+    heads = 4,
+    vq_dim = 256,
+    vq_depth = 1,
+    alpha = 10,
+    batch_size = 256
+):
+    torch.random.manual_seed(seed)
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    vq_bridge = None
+
+    if not no_bridge:
+        vq_bridge = VQBridgeViT(
+            dim = vq_dim,
+            depth = vq_depth,
+            input_dim = dim,
+            patch_size = patch_size,
+            heads = heads,
+            num_registers = num_registers
+        )
+
+    model = SimpleVQAutoEncoder(
+        dim = dim,
+        vq_bridge = vq_bridge,
+        rotation_trick = rotation_trick,
+        codebook_size = num_codes,
+        learnable_codebook = True,
+        in_place_codebook_optimizer = lambda *args, **kwargs: SGD(*args, **kwargs, lr = 1e-3),
+        ema_update = False,
+    ).to(device)
+
+    opt = AdamW(model.parameters(), lr = lr)
+
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.5,), (0.5,))
+    ])
+
+    train_dataset = DataLoader(
+        datasets.FashionMNIST(root = "~/data/fashion_mnist", train = True, download = True, transform = transform),
+        batch_size = batch_size,
+        shuffle = True,
+    )
+
+    def iterate_dataset(data_loader):
+        data_iter = iter(data_loader)
+        while True:
+            try:
+                x, y = next(data_iter)
+            except StopIteration:
+                data_iter = iter(data_loader)
+                x, y = next(data_iter)
+            yield x.to(device), y.to(device)
+
+    dl_iter = iterate_dataset(train_dataset)
+
+    pbar = trange(train_iter)
+
+    for _ in pbar:
+        opt.zero_grad()
+        x, _ = next(dl_iter)
+
+        out, indices, cmt_loss = model(x)
+        out = out.clamp(-1., 1.)
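+
+        # train on L1 reconstruction loss, with the VQ commitment loss weighted by alpha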
+ + rec_loss = (out - x).abs().mean() + (rec_loss + alpha * cmt_loss).backward() + + opt.step() + + pbar.set_description( + f"rec loss: {rec_loss.item():.3f} | " + f"cmt loss: {cmt_loss.item():.3f} | " + f"active %: {indices.unique().numel() / num_codes * 100:.3f}" + ) + +if __name__ == "__main__": + fire.Fire(train) diff --git a/examples/autoencoder_lfq.py b/examples/autoencoder_lfq.py index 7ffdb7c..4c187fe 100644 --- a/examples/autoencoder_lfq.py +++ b/examples/autoencoder_lfq.py @@ -1,26 +1,37 @@ -# FashionMnist VQ experiment with various settings. -# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py +#!/usr/bin/env uv run +# /// script +# dependencies = [ +# "torch", +# "torchvision", +# "tqdm", +# "fire", +# "einops", +# "einx", +# ] +# /// from tqdm.auto import trange from math import log2 +import fire import torch import torch.nn as nn import torch.nn.functional as F from torchvision import datasets, transforms from torch.utils.data import DataLoader +from torch.optim import AdamW from vector_quantize_pytorch import LFQ, Sequential -lr = 3e-4 -train_iter = 1000 -seed = 1234 -codebook_size = 2 ** 8 -entropy_loss_weight = 0.02 -diversity_gamma = 1. -spherical = True +# helpers -device = "cuda" if torch.cuda.is_available() else "cpu" +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +# classes def LFQAutoEncoder( codebook_size, @@ -30,25 +41,55 @@ def LFQAutoEncoder( quantize_dim = int(log2(codebook_size)) return Sequential( - nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(1, 16, kernel_size = 3, stride = 1, padding = 1), + nn.MaxPool2d(kernel_size = 2, stride = 2), nn.GELU(), - nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), - # In general norm layers are commonly used in Resnet-based encoder/decoders - # explicitly add one here with affine=False to avoid introducing new parameters - nn.GroupNorm(4, 32, affine=False), - nn.Conv2d(32, quantize_dim, kernel_size=1), - LFQ(dim=quantize_dim, **vq_kwargs), - nn.Conv2d(quantize_dim, 32, kernel_size=3, stride=1, padding=1), - nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1), + nn.Conv2d(16, 32, kernel_size = 3, stride = 1, padding = 1), + nn.MaxPool2d(kernel_size = 2, stride = 2), + nn.GroupNorm(4, 32, affine = False), + nn.Conv2d(32, quantize_dim, kernel_size = 1), + LFQ(dim = quantize_dim, **vq_kwargs), + nn.Conv2d(quantize_dim, 32, kernel_size = 3, stride = 1, padding = 1), + nn.Upsample(scale_factor = 2, mode = "nearest"), + nn.Conv2d(32, 16, kernel_size = 3, stride = 1, padding = 1), nn.GELU(), - nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1), + nn.Upsample(scale_factor = 2, mode = "nearest"), + nn.Conv2d(16, 1, kernel_size = 3, stride = 1, padding = 1), + ) + +def train( + train_iter = 1000, + lr = 3e-4, + codebook_size = 256, + seed = 1234, + entropy_loss_weight = 0.02, + diversity_gamma = 1., + spherical = True, + batch_size = 256 +): + torch.random.manual_seed(seed) + device = "cuda" if torch.cuda.is_available() else "cpu" + + model = LFQAutoEncoder( + codebook_size = codebook_size, + entropy_loss_weight = entropy_loss_weight, + diversity_gamma = diversity_gamma, + spherical = spherical + ).to(device) + + opt = AdamW(model.parameters(), lr = lr) + + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) 
+ ]) + + train_dataset = DataLoader( + datasets.FashionMNIST(root = "~/data/fashion_mnist", train = True, download = True, transform = transform), + batch_size = batch_size, + shuffle = True, ) -def train(model, train_loader, train_iterations=1000): def iterate_dataset(data_loader): data_iter = iter(data_loader) while True: @@ -59,9 +100,14 @@ def iterate_dataset(data_loader): x, y = next(data_iter) yield x.to(device), y.to(device) - for _ in (pbar := trange(train_iterations)): + dl_iter = iterate_dataset(train_dataset) + + pbar = trange(train_iter) + + for _ in pbar: opt.zero_grad() - x, _ = next(iterate_dataset(train_loader)) + x, _ = next(dl_iter) + out, indices, entropy_aux_loss = model(x) out = out.clamp(-1., 1.) @@ -69,33 +115,12 @@ def iterate_dataset(data_loader): (rec_loss + entropy_aux_loss).backward() opt.step() + pbar.set_description( - f"rec loss: {rec_loss.item():.3f} | " - + f"entropy aux loss: {entropy_aux_loss.item():.3f} | " - + f"active %: {indices.unique().numel() / codebook_size * 100:.3f}" + f"rec loss: {rec_loss.item():.3f} | " + f"entropy aux loss: {entropy_aux_loss.item():.3f} | " + f"active %: {indices.unique().numel() / codebook_size * 100:.3f}" ) -transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] -) - -train_dataset = DataLoader( - datasets.FashionMNIST( - root="~/data/fashion_mnist", train=True, download=True, transform=transform - ), - batch_size=256, - shuffle=True, -) - -torch.random.manual_seed(seed) - -model = LFQAutoEncoder( - codebook_size = codebook_size, - entropy_loss_weight = entropy_loss_weight, - diversity_gamma = diversity_gamma, - spherical = spherical -).to(device) - -opt = torch.optim.AdamW(model.parameters(), lr=lr) - -train(model, train_dataset, train_iterations=train_iter) +if __name__ == "__main__": + fire.Fire(train) diff --git a/examples/autoencoder_sim_vq.py b/examples/autoencoder_sim_vq.py index 535b96d..4c82a5d 100644 --- a/examples/autoencoder_sim_vq.py +++ b/examples/autoencoder_sim_vq.py @@ -1,5 +1,14 @@ -# FashionMnist VQ experiment with various settings. 
-# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py +#!/usr/bin/env uv run +# /// script +# dependencies = [ +# "torch", +# "torchvision", +# "tqdm", +# "fire", +# "einops", +# "einx", +# ] +# /// from tqdm.auto import trange @@ -7,35 +16,73 @@ import torch.nn as nn from torchvision import datasets, transforms from torch.utils.data import DataLoader +from torch.optim import AdamW +import fire from vector_quantize_pytorch import SimVQ, Sequential -lr = 3e-4 -train_iter = 10000 -num_codes = 256 -seed = 1234 +# helpers -rotation_trick = True # rotation trick instead ot straight-through -use_mlp = True # use a one layer mlp with relu instead of linear +def exists(val): + return val is not None -device = "cuda" if torch.cuda.is_available() else "cpu" +def default(val, d): + return val if exists(val) else d -def SimVQAutoEncoder(**vq_kwargs): +# classes + +def SimpleSimVQAutoEncoder(dim = 32, **vq_kwargs): return Sequential( - nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(1, 16, kernel_size = 3, stride = 1, padding = 1), + nn.MaxPool2d(kernel_size = 2, stride = 2), nn.GELU(), - nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), - SimVQ(dim=32, channel_first = True, **vq_kwargs), - nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1), + nn.Conv2d(16, 32, kernel_size = 3, stride = 1, padding = 1), + nn.MaxPool2d(kernel_size = 2, stride = 2), + SimVQ(dim = dim, channel_first = True, **vq_kwargs), + nn.Upsample(scale_factor = 2, mode = "nearest"), + nn.Conv2d(32, 16, kernel_size = 3, stride = 1, padding = 1), nn.GELU(), - nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1), + nn.Upsample(scale_factor = 2, mode = "nearest"), + nn.Conv2d(16, 1, kernel_size = 3, stride = 1, padding = 1), + ) + +def train( + train_iter = 1000, + lr = 3e-4, + dim = 32, + num_codes = 256, + seed = 1234, + rotation_trick = True, + use_mlp = True, + batch_size = 256 +): + torch.random.manual_seed(seed) + device = "cuda" if torch.cuda.is_available() else "cpu" + + model = SimpleSimVQAutoEncoder( + dim = dim, + codebook_size = num_codes, + rotation_trick = rotation_trick, + codebook_transform = nn.Sequential( + nn.Linear(dim, dim * 4), + nn.ReLU(), + nn.Linear(dim * 4, dim), + ) if use_mlp else None + ).to(device) + + opt = AdamW(model.parameters(), lr = lr) + + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) + ]) + + train_dataset = DataLoader( + datasets.FashionMNIST(root = "~/data/fashion_mnist", train = True, download = True, transform = transform), + batch_size = batch_size, + shuffle = True, ) -def train(model, train_loader, train_iterations=1000, alpha=10): def iterate_dataset(data_loader): data_iter = iter(data_loader) while True: @@ -46,47 +93,27 @@ def iterate_dataset(data_loader): x, y = next(data_iter) yield x.to(device), y.to(device) - for _ in (pbar := trange(train_iterations)): + dl_iter = iterate_dataset(train_dataset) + + pbar = trange(train_iter) + + for _ in pbar: opt.zero_grad() - x, _ = next(iterate_dataset(train_loader)) + x, _ = next(dl_iter) - out, indices, cmt_loss = model(x) + out, indices, sim_loss = model(x) out = out.clamp(-1., 1.) 
rec_loss = (out - x).abs().mean() - (rec_loss + alpha * cmt_loss).backward() + (rec_loss + sim_loss).backward() opt.step() pbar.set_description( f"rec loss: {rec_loss.item():.3f} | " - + f"cmt loss: {cmt_loss.item():.3f} | " - + f"active %: {indices.unique().numel() / num_codes * 100:.3f}" + f"sim loss: {sim_loss.item():.3f} | " + f"active %: {indices.unique().numel() / num_codes * 100:.3f}" ) -transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] -) - -train_dataset = DataLoader( - datasets.FashionMNIST( - root="~/data/fashion_mnist", train=True, download=True, transform=transform - ), - batch_size=256, - shuffle=True, -) - -torch.random.manual_seed(seed) - -model = SimVQAutoEncoder( - codebook_size = num_codes, - rotation_trick = rotation_trick, - codebook_transform = nn.Sequential( - nn.Linear(32, 128), - nn.ReLU(), - nn.Linear(128, 32), - ) if use_mlp else None -).to(device) - -opt = torch.optim.AdamW(model.parameters(), lr=lr) -train(model, train_dataset, train_iterations=train_iter) +if __name__ == "__main__": + fire.Fire(train) diff --git a/pyproject.toml b/pyproject.toml index df17d34..b8939d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "vector-quantize-pytorch" -version = "1.27.20" +version = "1.27.21" description = "Vector Quantization - Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } @@ -41,15 +41,17 @@ build-backend = "hatchling.build" [tool.rye] managed = true + dev-dependencies = [ "ruff>=0.4.2", "pytest>=8.2.0", "pytest-cov>=5.0.0", + "x-transformers>=2.16.1", ] [tool.pytest.ini_options] pythonpath = [ - "." + "." ] [tool.hatch.metadata] diff --git a/tests/test_lfq.py b/tests/test_lfq.py index 0eb5903..32ccfc4 100644 --- a/tests/test_lfq.py +++ b/tests/test_lfq.py @@ -1,36 +1,36 @@ import torch import pytest -from vector_quantize_pytorch import LFQ import math -""" -testing_strategy: -subdivisions: using masks, using frac_per_sample_entropy < 1 -""" +from vector_quantize_pytorch import LFQ + +# helpers -torch.manual_seed(0) +def exists(val): + return val is not None + +# tests @pytest.mark.parametrize('frac_per_sample_entropy', (1., 0.5)) -@pytest.mark.parametrize('mask', (torch.tensor([False, False]), - torch.tensor([True, False]), - torch.tensor([True, True]))) +@pytest.mark.parametrize('mask', ( + torch.tensor([False, False]), + torch.tensor([True, False]), + torch.tensor([True, True]) +)) def test_masked_lfq( frac_per_sample_entropy, mask ): - # you can specify either dim or codebook_size - # if both specified, will be validated against each other - quantizer = LFQ( - codebook_size = 65536, # codebook size, must be a power of 2 - dim = 16, # this is the input feature dimension, defaults to log2(codebook_size) if not defined - entropy_loss_weight = 0.1, # how much weight to place on entropy loss - diversity_gamma = 1., # within entropy loss, how much weight to give to diversity + codebook_size = 65536, + dim = 16, + entropy_loss_weight = 0.1, + diversity_gamma = 1., frac_per_sample_entropy = frac_per_sample_entropy ) image_feats = torch.randn(2, 16, 32, 32) - ret, loss_breakdown = quantizer(image_feats, inv_temperature=100., return_loss_breakdown=True, mask=mask) # you may want to experiment with temperature + ret, _ = quantizer(image_feats, inv_temperature = 100., return_loss_breakdown = True, mask = mask) quantized, indices, _ = ret assert (quantized == quantizer.indices_to_codes(indices)).all() @@ -38,40 +38,40 @@ def test_masked_lfq( 
@pytest.mark.parametrize('frac_per_sample_entropy', (0.1,)) @pytest.mark.parametrize('iters', (10,)) @pytest.mark.parametrize('mask', (None, torch.tensor([True, False]))) -def test_lfq_bruteforce_frac_per_sample_entropy(frac_per_sample_entropy, iters, mask): +def test_lfq_bruteforce_frac_per_sample_entropy( + frac_per_sample_entropy, + iters, + mask +): image_feats = torch.randn(2, 16, 32, 32) full_per_sample_entropy_quantizer = LFQ( - codebook_size = 65536, # codebook size, must be a power of 2 - dim = 16, # this is the input feature dimension, defaults to log2(codebook_size) if not defined - entropy_loss_weight = 0.1, # how much weight to place on entropy loss - diversity_gamma = 1., # within entropy loss, how much weight to give to diversity + codebook_size = 65536, + dim = 16, + entropy_loss_weight = 0.1, + diversity_gamma = 1., frac_per_sample_entropy = 1 ) partial_per_sample_entropy_quantizer = LFQ( - codebook_size = 65536, # codebook size, must be a power of 2 - dim = 16, # this is the input feature dimension, defaults to log2(codebook_size) if not defined - entropy_loss_weight = 0.1, # how much weight to place on entropy loss - diversity_gamma = 1., # within entropy loss, how much weight to give to diversity + codebook_size = 65536, + dim = 16, + entropy_loss_weight = 0.1, + diversity_gamma = 1., frac_per_sample_entropy = frac_per_sample_entropy ) - ret, loss_breakdown = full_per_sample_entropy_quantizer( - image_feats, inv_temperature=100., return_loss_breakdown=True, mask=mask) + ret, loss_breakdown = full_per_sample_entropy_quantizer(image_feats, inv_temperature = 100., return_loss_breakdown = True, mask = mask) true_per_sample_entropy = loss_breakdown.per_sample_entropy per_sample_losses = torch.zeros(iters) - for iter in range(iters): - ret, loss_breakdown = partial_per_sample_entropy_quantizer( - image_feats, inv_temperature=100., return_loss_breakdown=True, mask=mask) # you may want to experiment with temperature + + for i in range(iters): + ret, loss_breakdown = partial_per_sample_entropy_quantizer(image_feats, inv_temperature = 100., return_loss_breakdown = True, mask = mask) quantized, indices, _ = ret assert (quantized == partial_per_sample_entropy_quantizer.indices_to_codes(indices)).all() - per_sample_losses[iter] = loss_breakdown.per_sample_entropy - # 95% confidence interval - assert abs(per_sample_losses.mean() - true_per_sample_entropy) \ - < (1.96*(per_sample_losses.std() / math.sqrt(iters))) + per_sample_losses[i] = loss_breakdown.per_sample_entropy - print("difference: ", abs(per_sample_losses.mean() - true_per_sample_entropy)) - print("std error:", (1.96*(per_sample_losses.std() / math.sqrt(iters)))) \ No newline at end of file + # 95% confidence interval + assert abs(per_sample_losses.mean() - true_per_sample_entropy) < (1.96 * (per_sample_losses.std() / math.sqrt(iters))) \ No newline at end of file diff --git a/tests/test_readme.py b/tests/test_readme.py index f861efe..4d5fe1f 100644 --- a/tests/test_readme.py +++ b/tests/test_readme.py @@ -506,3 +506,32 @@ def test_vq_3d(): assert x.shape == quantized.shape assert indices.shape == (1, 16, 16, 16) assert torch.allclose(quantized, quantizer.get_output_from_indices(indices)) + +def test_fvq(): + from vector_quantize_pytorch import VectorQuantize + from x_transformers import ContinuousTransformerWrapper, Encoder + + vq_bridge = ContinuousTransformerWrapper( + num_memory_tokens = 2, + attn_layers = Encoder( + dim = 256, + depth = 1, + heads = 4, + pre_norm_has_final_norm = False + ) + ) + + vq = VectorQuantize( 
+ dim = 256, + codebook_size = 512, + vq_bridge = vq_bridge, + learnable_codebook = True, + ema_update = False, + in_place_codebook_optimizer = lambda params: torch.optim.SGD(params, lr = 1e-3) + ) + + x = torch.randn(1, 1024, 256) + quantized, indices, commit_loss = vq(x) + + assert quantized.shape == x.shape + assert indices.shape == (1, 1024) diff --git a/vector_quantize_pytorch/vector_quantize_pytorch.py b/vector_quantize_pytorch/vector_quantize_pytorch.py index 98aeff2..dddcdb4 100644 --- a/vector_quantize_pytorch/vector_quantize_pytorch.py +++ b/vector_quantize_pytorch/vector_quantize_pytorch.py @@ -366,7 +366,8 @@ def __init__( sync_affine_param = False, affine_param_batch_decay = 0.99, affine_param_codebook_decay = 0.9, - use_cosine_sim = False + use_cosine_sim = False, + vq_bridge: Module | None = None ): super().__init__() self.transform_input = identity if not use_cosine_sim else l2norm @@ -395,6 +396,7 @@ def __init__( self.gumbel_sample = gumbel_sample self.sample_codebook_temp = sample_codebook_temp + self.use_ddp = use_ddp assert not (use_ddp and num_codebooks > 1 and kmeans_init), 'kmeans init is not compatible with multiple codebooks in distributed environment for now' self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors @@ -416,6 +418,10 @@ def __init__( self.use_cosine_sim = use_cosine_sim + # fvq + + self.vq_bridge = vq_bridge + # affine related params self.affine_param = affine_param @@ -668,6 +674,11 @@ def forward( embed = embed.to(dtype) + # maybe vq bridge + + if exists(self.vq_bridge): + embed = self.vq_bridge(embed) + # affine params if self.affine_param: @@ -787,6 +798,7 @@ def __init__( sync_codebook = None, sync_affine_param = False, ema_update = None, + vq_bridge: Module | None = None, manual_ema_update = False, learnable_codebook = None, in_place_codebook_optimizer: Callable[..., Optimizer] = None, # Optimizer used to update the codebook embedding if using learnable_codebook @@ -801,8 +813,8 @@ def __init__( # defaults - ema_update = default(ema_update, not directional_reparam) - learnable_codebook = default(learnable_codebook, directional_reparam) + ema_update = default(ema_update, not directional_reparam and not exists(vq_bridge)) + learnable_codebook = default(learnable_codebook, directional_reparam or exists(vq_bridge)) rotation_trick = default(rotation_trick, not directional_reparam and dim > 1) # only use rotation trick if feature dimension greater than 1 # basic variables @@ -853,6 +865,8 @@ def __init__( assert not (straight_through and learnable_codebook), 'gumbel straight through not allowed when learning the codebook' assert not (ema_update and learnable_codebook), 'learnable codebook not compatible with EMA update' + assert not (exists(vq_bridge) and not learnable_codebook), 'learnable_codebook must be set to True if vq_bridge is passed in' + assert not (exists(vq_bridge) and ema_update), 'ema_update must be False if vq_bridge is passed in' assert 0 <= sync_update_v <= 1. assert not (sync_update_v > 0. and not learnable_codebook), 'learnable codebook must be turned on' @@ -886,7 +900,8 @@ def __init__( gumbel_sample = gumbel_sample_fn, ema_update = ema_update, manual_ema_update = manual_ema_update, - use_cosine_sim = use_cosine_sim + use_cosine_sim = use_cosine_sim, + vq_bridge = vq_bridge ) if affine_param:
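
For reference, a minimal sketch of the `vq_bridge` ("FVQ") path this diff introduces, copied from the `test_fvq` added above. The transformer settings mirror the test rather than recommended hyperparameters; per the new asserts, `learnable_codebook = True` and `ema_update = False` are required whenever a bridge is passed:

```python
import torch
from torch.optim import SGD

from x_transformers import ContinuousTransformerWrapper, Encoder
from vector_quantize_pytorch import VectorQuantize

# the bridge is a small transformer the codebook embeddings are passed
# through before quantization (see the added VQBridgeViT example)
vq_bridge = ContinuousTransformerWrapper(
    num_memory_tokens = 2,          # register tokens, per the note in examples/autoencoder_fvq.py
    attn_layers = Encoder(
        dim = 256,
        depth = 1,
        heads = 4,
        pre_norm_has_final_norm = False
    )
)

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    vq_bridge = vq_bridge,
    learnable_codebook = True,      # required when vq_bridge is set
    ema_update = False,             # required when vq_bridge is set
    in_place_codebook_optimizer = lambda params: SGD(params, lr = 1e-3)
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)

assert quantized.shape == x.shape
assert indices.shape == (1, 1024)
```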