From 721e1e41655f6413bf806f70b3caaaa907870a79 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Thu, 1 Jan 2026 07:27:49 -0800 Subject: [PATCH 1/4] fvq --- autoencoder_fvq.py | 206 ++++++++++++++++++ .../vector_quantize_pytorch.py | 17 +- 2 files changed, 221 insertions(+), 2 deletions(-) create mode 100644 autoencoder_fvq.py diff --git a/autoencoder_fvq.py b/autoencoder_fvq.py new file mode 100644 index 0000000..25174b8 --- /dev/null +++ b/autoencoder_fvq.py @@ -0,0 +1,206 @@ +# FashionMnist VQ experiment with various settings. +# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py + +from tqdm.auto import trange + +import torch +import torch.nn as nn +from torch.nn import Module +from torchvision import datasets, transforms +from torch.utils.data import DataLoader +from torch.optim import SGD + +from vector_quantize_pytorch import VectorQuantize, Sequential + +lr = 3e-4 +train_iter = 1000 +num_codes = 256 +seed = 1234 +rotation_trick = True +device = "cuda" if torch.cuda.is_available() else "cpu" + +import torch +from torch import nn + +from einops import rearrange, repeat, pack, unpack +from einops.layers.torch import Rearrange + +# classes + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.norm = nn.LayerNorm(dim) + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + x = self.norm(x) + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), + FeedForward(dim, mlp_dim, dropout = dropout) + ])) + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +class VQBridgeViT(Module): + def __init__( + self, + dim, + depth, + input_dim = 32, + patch_size = 16, + dim_head = 16, + heads = 4, + add_residual = False + ): + super().__init__() + self.add_residual = add_residual + + patch_dim = input_dim * patch_size + self.patch_to_tokens = nn.Sequential( + Rearrange('b (n p) c -> b n (p c)', p = patch_size), + nn.Linear(patch_dim, dim), + ) + + self.transformer = Transformer(dim = dim, dim_head = dim_head, heads = heads, depth = depth, mlp_dim = dim * 4) + + self.tokens_to_patch = nn.Sequential( + nn.Linear(dim, patch_dim), + Rearrange('b n (p c) -> b (n p) c', p = patch_size), + ) + + def forward(self, x): + residual = x + + x = self.patch_to_tokens(x) + + x = self.transformer(x) + + x = self.tokens_to_patch(x) + + if self.add_residual: + x = x + residual + + return x + +def SimpleVQAutoEncoder(**vq_kwargs): + return Sequential( + nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.GELU(), + nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), + nn.MaxPool2d(kernel_size=2, stride=2), + VectorQuantize(dim=32, accept_image_fmap = True, **vq_kwargs), + nn.Upsample(scale_factor=2, mode="nearest"), + nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1), + nn.GELU(), + nn.Upsample(scale_factor=2, mode="nearest"), + nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1), + ) + +def train(model, train_loader, train_iterations=1000, alpha=10): + def iterate_dataset(data_loader): + data_iter = iter(data_loader) + while True: + try: + x, y = next(data_iter) + except StopIteration: + data_iter = iter(data_loader) + x, y = next(data_iter) + yield x.to(device), y.to(device) + + for _ in (pbar := trange(train_iterations)): + opt.zero_grad() + x, _ = next(iterate_dataset(train_loader)) + + out, indices, cmt_loss = model(x) + out = out.clamp(-1., 1.) + + rec_loss = (out - x).abs().mean() + (rec_loss + alpha * cmt_loss).backward() + + opt.step() + pbar.set_description( + f"rec loss: {rec_loss.item():.3f} | " + + f"cmt loss: {cmt_loss.item():.3f} | " + + f"active %: {indices.unique().numel() / num_codes * 100:.3f}" + ) + +transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] +) +train_dataset = DataLoader( + datasets.FashionMNIST( + root="~/data/fashion_mnist", train=True, download=True, transform=transform + ), + batch_size=256, + shuffle=True, +) + +torch.random.manual_seed(seed) + +model = SimpleVQAutoEncoder( + codebook_size = num_codes, + learnable_codebook = True, + in_place_codebook_optimizer = lambda *args, **kwargs: SGD(*args, **kwargs, lr = 1e-3), + ema_update = False, + # vq_bridge = None, + vq_bridge = VQBridgeViT( + dim = 256, + input_dim = 32, + patch_size = 2, + depth = 1, + add_residual = False + ) +).to(device) + +opt = torch.optim.AdamW(model.parameters(), lr=lr) +train(model, train_dataset, train_iterations=train_iter) diff --git a/vector_quantize_pytorch/vector_quantize_pytorch.py b/vector_quantize_pytorch/vector_quantize_pytorch.py index 98aeff2..2baa83e 100644 --- a/vector_quantize_pytorch/vector_quantize_pytorch.py +++ b/vector_quantize_pytorch/vector_quantize_pytorch.py @@ -366,7 +366,8 @@ def __init__( sync_affine_param = False, affine_param_batch_decay = 0.99, affine_param_codebook_decay = 0.9, - use_cosine_sim = False + use_cosine_sim = False, + vq_bridge: Module | None = None ): super().__init__() self.transform_input = identity if not use_cosine_sim else l2norm @@ -395,6 +396,7 @@ def __init__( self.gumbel_sample = gumbel_sample self.sample_codebook_temp = sample_codebook_temp + self.use_ddp = use_ddp assert not (use_ddp and num_codebooks > 1 and kmeans_init), 'kmeans init is not compatible with multiple codebooks in distributed environment for now' self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors @@ -416,6 +418,10 @@ def __init__( self.use_cosine_sim = use_cosine_sim + # fvq + + self.vq_bridge = vq_bridge + # affine related params self.affine_param = affine_param @@ -668,6 +674,11 @@ def forward( embed = embed.to(dtype) + # maybe vq bridge + + if exists(self.vq_bridge): + embed = self.vq_bridge(embed) + # affine params if self.affine_param: @@ -787,6 +798,7 @@ def __init__( sync_codebook = None, sync_affine_param = False, ema_update = None, + vq_bridge: Module | None = None, manual_ema_update = False, learnable_codebook = None, in_place_codebook_optimizer: Callable[..., Optimizer] = None, # Optimizer used to update the codebook embedding if using learnable_codebook @@ -886,7 +898,8 @@ def __init__( gumbel_sample = gumbel_sample_fn, ema_update = ema_update, manual_ema_update = manual_ema_update, - use_cosine_sim = use_cosine_sim + use_cosine_sim = use_cosine_sim, + vq_bridge = vq_bridge ) if affine_param: From 0db73ad0b72248d64bc1fa4f22c3b996f11ff0df Mon Sep 17 00:00:00 2001 From: lucidrains Date: Thu, 12 Feb 2026 08:47:39 -0800 Subject: [PATCH 2/4] polish up autoencoder fvq with the best config found by @mahdip72 --- README.md | 11 ++ autoencoder_fvq.py | 206 --------------------------------- examples/autoencoder.py | 124 +++++++++++++------- examples/autoencoder_fsq.py | 105 +++++++++++------ examples/autoencoder_fvq.py | 197 +++++++++++++++++++++++++++++++ examples/autoencoder_lfq.py | 133 ++++++++++++--------- examples/autoencoder_sim_vq.py | 131 ++++++++++++--------- pyproject.toml | 2 +- 8 files changed, 513 insertions(+), 396 deletions(-) delete mode 100644 autoencoder_fvq.py create mode 100644 examples/autoencoder_fvq.py diff --git a/README.md b/README.md index d916d3e..a131bef 100644 --- a/README.md +++ b/README.md @@ -804,3 +804,14 @@ assert loss.item() >= 0 } ``` +```bibtex +@misc{chang2025scalabletrainingvectorquantizednetworks, + title = {Scalable Training for Vector-Quantized Networks with 100% Codebook Utilization}, + author = {Yifan Chang and Jie Qin and Limeng Qiao and Xiaofeng Wang and Zheng Zhu and Lin Ma and Xingang Wang}, + year = {2025}, + eprint = {2509.10140}, + archivePrefix = {arXiv}, + primaryClass = {cs.CV}, + url = {https://arxiv.org/abs/2509.10140}, +} +``` diff --git a/autoencoder_fvq.py b/autoencoder_fvq.py deleted file mode 100644 index 25174b8..0000000 --- a/autoencoder_fvq.py +++ /dev/null @@ -1,206 +0,0 @@ -# FashionMnist VQ experiment with various settings. -# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py - -from tqdm.auto import trange - -import torch -import torch.nn as nn -from torch.nn import Module -from torchvision import datasets, transforms -from torch.utils.data import DataLoader -from torch.optim import SGD - -from vector_quantize_pytorch import VectorQuantize, Sequential - -lr = 3e-4 -train_iter = 1000 -num_codes = 256 -seed = 1234 -rotation_trick = True -device = "cuda" if torch.cuda.is_available() else "cpu" - -import torch -from torch import nn - -from einops import rearrange, repeat, pack, unpack -from einops.layers.torch import Rearrange - -# classes - -class FeedForward(nn.Module): - def __init__(self, dim, hidden_dim, dropout = 0.): - super().__init__() - self.net = nn.Sequential( - nn.LayerNorm(dim), - nn.Linear(dim, hidden_dim), - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(hidden_dim, dim), - nn.Dropout(dropout) - ) - def forward(self, x): - return self.net(x) - -class Attention(nn.Module): - def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): - super().__init__() - inner_dim = dim_head * heads - project_out = not (heads == 1 and dim_head == dim) - - self.heads = heads - self.scale = dim_head ** -0.5 - - self.norm = nn.LayerNorm(dim) - self.attend = nn.Softmax(dim = -1) - self.dropout = nn.Dropout(dropout) - - self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) - - self.to_out = nn.Sequential( - nn.Linear(inner_dim, dim), - nn.Dropout(dropout) - ) if project_out else nn.Identity() - - def forward(self, x): - x = self.norm(x) - qkv = self.to_qkv(x).chunk(3, dim = -1) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) - - dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale - - attn = self.attend(dots) - attn = self.dropout(attn) - - out = torch.matmul(attn, v) - out = rearrange(out, 'b h n d -> b n (h d)') - return self.to_out(out) - -class Transformer(nn.Module): - def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): - super().__init__() - self.layers = nn.ModuleList([]) - for _ in range(depth): - self.layers.append(nn.ModuleList([ - Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), - FeedForward(dim, mlp_dim, dropout = dropout) - ])) - def forward(self, x): - for attn, ff in self.layers: - x = attn(x) + x - x = ff(x) + x - return x - -class VQBridgeViT(Module): - def __init__( - self, - dim, - depth, - input_dim = 32, - patch_size = 16, - dim_head = 16, - heads = 4, - add_residual = False - ): - super().__init__() - self.add_residual = add_residual - - patch_dim = input_dim * patch_size - self.patch_to_tokens = nn.Sequential( - Rearrange('b (n p) c -> b n (p c)', p = patch_size), - nn.Linear(patch_dim, dim), - ) - - self.transformer = Transformer(dim = dim, dim_head = dim_head, heads = heads, depth = depth, mlp_dim = dim * 4) - - self.tokens_to_patch = nn.Sequential( - nn.Linear(dim, patch_dim), - Rearrange('b n (p c) -> b (n p) c', p = patch_size), - ) - - def forward(self, x): - residual = x - - x = self.patch_to_tokens(x) - - x = self.transformer(x) - - x = self.tokens_to_patch(x) - - if self.add_residual: - x = x + residual - - return x - -def SimpleVQAutoEncoder(**vq_kwargs): - return Sequential( - nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.GELU(), - nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), - VectorQuantize(dim=32, accept_image_fmap = True, **vq_kwargs), - nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1), - nn.GELU(), - nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1), - ) - -def train(model, train_loader, train_iterations=1000, alpha=10): - def iterate_dataset(data_loader): - data_iter = iter(data_loader) - while True: - try: - x, y = next(data_iter) - except StopIteration: - data_iter = iter(data_loader) - x, y = next(data_iter) - yield x.to(device), y.to(device) - - for _ in (pbar := trange(train_iterations)): - opt.zero_grad() - x, _ = next(iterate_dataset(train_loader)) - - out, indices, cmt_loss = model(x) - out = out.clamp(-1., 1.) - - rec_loss = (out - x).abs().mean() - (rec_loss + alpha * cmt_loss).backward() - - opt.step() - pbar.set_description( - f"rec loss: {rec_loss.item():.3f} | " - + f"cmt loss: {cmt_loss.item():.3f} | " - + f"active %: {indices.unique().numel() / num_codes * 100:.3f}" - ) - -transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] -) -train_dataset = DataLoader( - datasets.FashionMNIST( - root="~/data/fashion_mnist", train=True, download=True, transform=transform - ), - batch_size=256, - shuffle=True, -) - -torch.random.manual_seed(seed) - -model = SimpleVQAutoEncoder( - codebook_size = num_codes, - learnable_codebook = True, - in_place_codebook_optimizer = lambda *args, **kwargs: SGD(*args, **kwargs, lr = 1e-3), - ema_update = False, - # vq_bridge = None, - vq_bridge = VQBridgeViT( - dim = 256, - input_dim = 32, - patch_size = 2, - depth = 1, - add_residual = False - ) -).to(device) - -opt = torch.optim.AdamW(model.parameters(), lr=lr) -train(model, train_dataset, train_iterations=train_iter) diff --git a/examples/autoencoder.py b/examples/autoencoder.py index 09b8332..3e486e3 100644 --- a/examples/autoencoder.py +++ b/examples/autoencoder.py @@ -1,38 +1,87 @@ -# FashionMnist VQ experiment with various settings. -# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py +#!/usr/bin/env uv run +# /// script +# dependencies = [ +# "torch", +# "torchvision", +# "tqdm", +# "fire", +# "einops", +# "einx", +# ] +# /// from tqdm.auto import trange +import fire import torch import torch.nn as nn from torchvision import datasets, transforms from torch.utils.data import DataLoader +from torch.optim import AdamW from vector_quantize_pytorch import VectorQuantize, Sequential -lr = 3e-4 -train_iter = 1000 -num_codes = 256 -seed = 1234 -rotation_trick = True -device = "cuda" if torch.cuda.is_available() else "cpu" +# helpers -def SimpleVQAutoEncoder(**vq_kwargs): +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +# classes + +def SimpleVQAutoEncoder(dim = 32, **vq_kwargs): return Sequential( - nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(1, 16, kernel_size = 3, stride = 1, padding = 1), + nn.MaxPool2d(kernel_size = 2, stride = 2), nn.GELU(), - nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), - VectorQuantize(dim=32, accept_image_fmap = True, **vq_kwargs), - nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1), + nn.Conv2d(16, 32, kernel_size = 3, stride = 1, padding = 1), + nn.MaxPool2d(kernel_size = 2, stride = 2), + VectorQuantize(dim = dim, accept_image_fmap = True, **vq_kwargs), + nn.Upsample(scale_factor = 2, mode = "nearest"), + nn.Conv2d(32, 16, kernel_size = 3, stride = 1, padding = 1), nn.GELU(), - nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1), + nn.Upsample(scale_factor = 2, mode = "nearest"), + nn.Conv2d(16, 1, kernel_size = 3, stride = 1, padding = 1), + ) + +def train( + train_iter = 1000, + lr = 3e-4, + dim = 32, + num_codes = 256, + seed = 1234, + rotation_trick = True, + straight_through = False, + directional_reparam = False, + alpha = 10, + batch_size = 256 +): + torch.random.manual_seed(seed) + device = "cuda" if torch.cuda.is_available() else "cpu" + + model = SimpleVQAutoEncoder( + dim = dim, + codebook_size = num_codes, + rotation_trick = rotation_trick, + straight_through = straight_through, + directional_reparam = directional_reparam + ).to(device) + + opt = AdamW(model.parameters(), lr = lr) + + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) + ]) + + train_dataset = DataLoader( + datasets.FashionMNIST(root = "~/data/fashion_mnist", train = True, download = True, transform = transform), + batch_size = batch_size, + shuffle = True, ) -def train(model, train_loader, train_iterations=1000, alpha=10): def iterate_dataset(data_loader): data_iter = iter(data_loader) while True: @@ -43,9 +92,13 @@ def iterate_dataset(data_loader): x, y = next(data_iter) yield x.to(device), y.to(device) - for _ in (pbar := trange(train_iterations)): + dl_iter = iterate_dataset(train_dataset) + + pbar = trange(train_iter) + + for _ in pbar: opt.zero_grad() - x, _ = next(iterate_dataset(train_loader)) + x, _ = next(dl_iter) out, indices, cmt_loss = model(x) out = out.clamp(-1., 1.) @@ -54,31 +107,12 @@ def iterate_dataset(data_loader): (rec_loss + alpha * cmt_loss).backward() opt.step() + pbar.set_description( f"rec loss: {rec_loss.item():.3f} | " - + f"cmt loss: {cmt_loss.item():.3f} | " - + f"active %: {indices.unique().numel() / num_codes * 100:.3f}" + f"cmt loss: {cmt_loss.item():.3f} | " + f"active %: {indices.unique().numel() / num_codes * 100:.3f}" ) -transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] -) -train_dataset = DataLoader( - datasets.FashionMNIST( - root="~/data/fashion_mnist", train=True, download=True, transform=transform - ), - batch_size=256, - shuffle=True, -) - -torch.random.manual_seed(seed) - -model = SimpleVQAutoEncoder( - codebook_size = num_codes, - rotation_trick = False, - straight_through = False, - directional_reparam = True -).to(device) - -opt = torch.optim.AdamW(model.parameters(), lr=lr) -train(model, train_dataset, train_iterations=train_iter) +if __name__ == "__main__": + fire.Fire(train) diff --git a/examples/autoencoder_fsq.py b/examples/autoencoder_fsq.py index 56e3c6e..e418814 100644 --- a/examples/autoencoder_fsq.py +++ b/examples/autoencoder_fsq.py @@ -1,44 +1,81 @@ -# FashionMnist VQ experiment with various settings, using FSQ. -# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py +#!/usr/bin/env uv run +# /// script +# dependencies = [ +# "torch", +# "torchvision", +# "tqdm", +# "fire", +# "einops", +# "einx", +# ] +# /// from tqdm.auto import trange import math +import fire import torch import torch.nn as nn from torchvision import datasets, transforms from torch.utils.data import DataLoader +from torch.optim import AdamW from vector_quantize_pytorch import FSQ, Sequential +# helpers -lr = 3e-4 -train_iter = 1000 -levels = [8, 6, 5] # target size 2^8, actual size 240 -num_codes = math.prod(levels) -seed = 1234 -device = "cuda" if torch.cuda.is_available() else "cpu" +def exists(val): + return val is not None +def default(val, d): + return val if exists(val) else d + +# classes def SimpleFSQAutoEncoder(levels: list[int]): return Sequential( - nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(1, 16, kernel_size = 3, stride = 1, padding = 1), + nn.MaxPool2d(kernel_size = 2, stride = 2), nn.GELU(), - nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(32, len(levels), kernel_size=1), + nn.Conv2d(16, 32, kernel_size = 3, stride = 1, padding = 1), + nn.MaxPool2d(kernel_size = 2, stride = 2), + nn.Conv2d(32, len(levels), kernel_size = 1), FSQ(levels), - nn.Conv2d(len(levels), 32, kernel_size=3, stride=1, padding=1), - nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1), + nn.Conv2d(len(levels), 32, kernel_size = 3, stride = 1, padding = 1), + nn.Upsample(scale_factor = 2, mode = "nearest"), + nn.Conv2d(32, 16, kernel_size = 3, stride = 1, padding = 1), nn.GELU(), - nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1), + nn.Upsample(scale_factor = 2, mode = "nearest"), + nn.Conv2d(16, 1, kernel_size = 3, stride = 1, padding = 1), ) +def train( + train_iter = 1000, + lr = 3e-4, + levels = [8, 6, 5], + seed = 1234, + batch_size = 256 +): + torch.random.manual_seed(seed) + device = "cuda" if torch.cuda.is_available() else "cpu" + + num_codes = math.prod(levels) + + model = SimpleFSQAutoEncoder(levels).to(device) + + opt = AdamW(model.parameters(), lr = lr) + + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) + ]) + + train_dataset = DataLoader( + datasets.FashionMNIST(root = "~/data/fashion_mnist", train = True, download = True, transform = transform), + batch_size = batch_size, + shuffle = True, + ) -def train(model, train_loader, train_iterations=1000): def iterate_dataset(data_loader): data_iter = iter(data_loader) while True: @@ -49,9 +86,14 @@ def iterate_dataset(data_loader): x, y = next(data_iter) yield x.to(device), y.to(device) - for _ in (pbar := trange(train_iterations)): + dl_iter = iterate_dataset(train_dataset) + + pbar = trange(train_iter) + + for _ in pbar: opt.zero_grad() - x, _ = next(iterate_dataset(train_loader)) + x, _ = next(dl_iter) + out, indices = model(x) out = out.clamp(-1., 1.) @@ -59,24 +101,11 @@ def iterate_dataset(data_loader): rec_loss.backward() opt.step() + pbar.set_description( f"rec loss: {rec_loss.item():.3f} | " - + f"active %: {indices.unique().numel() / num_codes * 100:.3f}" + f"active %: {indices.unique().numel() / num_codes * 100:.3f}" ) - -transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] -) -train_dataset = DataLoader( - datasets.FashionMNIST( - root="~/data/fashion_mnist", train=True, download=True, transform=transform - ), - batch_size=256, - shuffle=True, -) - -torch.random.manual_seed(seed) -model = SimpleFSQAutoEncoder(levels).to(device) -opt = torch.optim.AdamW(model.parameters(), lr=lr) -train(model, train_dataset, train_iterations=train_iter) +if __name__ == "__main__": + fire.Fire(train) diff --git a/examples/autoencoder_fvq.py b/examples/autoencoder_fvq.py new file mode 100644 index 0000000..c0c4f7c --- /dev/null +++ b/examples/autoencoder_fvq.py @@ -0,0 +1,197 @@ +#!/usr/bin/env uv run +# /// script +# dependencies = [ +# "torch", +# "torchvision", +# "tqdm", +# "fire", +# "einops", +# "einx", +# "x-transformers==2.16.1", +# ] +# /// + +from tqdm.auto import trange + +import fire + +import torch +import torch.nn as nn +from torch.nn import Module +from torchvision import datasets, transforms +from torch.utils.data import DataLoader +from torch.optim import SGD, AdamW + +from einops import rearrange +from einops.layers.torch import Rearrange + +from x_transformers import ContinuousTransformerWrapper, Encoder +from vector_quantize_pytorch import VectorQuantize, Sequential + +# helpers + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +# classes + +def VQBridgeViT( + dim, + depth, + input_dim = None, + patch_size = 1, + dim_head = 64, + heads = 4, + num_registers = 0 +): + # Credit goes to Mahdi (@mahdip72) for his experiments that found the best + # set of hyperparameters for the ViT used in FVQ, which is patch_size 1 (becomes a regular transformer encoder) and critically, having register tokens (we will do 2 here) + # see experiments at https://github.com/lucidrains/vector-quantize-pytorch/issues/239#issuecomment-3888240360 + + input_dim = default(input_dim, dim) + + project_in_out_kwargs = dict() + + inner_dim = input_dim * patch_size + + if patch_size > 1 or inner_dim != dim: + project_in_out_kwargs.update( + project_in = nn.Sequential( + Rearrange('b (n p) c -> b n (p c)', p = patch_size), + nn.Linear(inner_dim, dim, bias = False) + ), + project_out = nn.Sequential( + nn.Linear(dim, inner_dim, bias = False), + Rearrange('b n (p c) -> b (n p) c', p = patch_size) + ) + ) + + return ContinuousTransformerWrapper( + num_memory_tokens = num_registers, + attn_layers = Encoder( + dim = dim, + depth = depth, + heads = heads, + attn_dim_head = dim_head, + pre_norm_has_final_norm = False + ), + **project_in_out_kwargs + ) + +def SimpleVQAutoEncoder( + dim = 32, + vq_bridge: Module | None = None, + rotation_trick = True, + **vq_kwargs +): + return Sequential( + nn.Conv2d(1, 16, kernel_size = 3, stride = 1, padding = 1), + nn.MaxPool2d(kernel_size = 2, stride = 2), + nn.GELU(), + nn.Conv2d(16, 32, kernel_size = 3, stride = 1, padding = 1), + nn.MaxPool2d(kernel_size = 2, stride = 2), + VectorQuantize( + dim = dim, + accept_image_fmap = True, + rotation_trick = rotation_trick, + vq_bridge = vq_bridge, + **vq_kwargs + ), + nn.Upsample(scale_factor = 2, mode = "nearest"), + nn.Conv2d(32, 16, kernel_size = 3, stride = 1, padding = 1), + nn.GELU(), + nn.Upsample(scale_factor = 2, mode = "nearest"), + nn.Conv2d(16, 1, kernel_size = 3, stride = 1, padding = 1), + ) + +def train( + train_iter = 1000, + lr = 3e-4, + dim = 32, + num_codes = 256, + seed = 1234, + patch_size = 1, + no_bridge = False, + rotation_trick = False, + num_registers = 2, + heads = 4, + vq_dim = 256, + vq_depth = 1, + alpha = 10, + batch_size = 256 +): + torch.random.manual_seed(seed) + device = "cuda" if torch.cuda.is_available() else "cpu" + + vq_bridge = None + + if not no_bridge: + vq_bridge = VQBridgeViT( + dim = vq_dim, + depth = vq_depth, + input_dim = dim, + patch_size = patch_size, + heads = heads, + num_registers = num_registers + ) + + model = SimpleVQAutoEncoder( + dim = dim, + vq_bridge = vq_bridge, + rotation_trick = rotation_trick + codebook_size = num_codes, + learnable_codebook = True, + in_place_codebook_optimizer = lambda *args, **kwargs: SGD(*args, **kwargs, lr = 1e-3), + ema_update = False, + ).to(device) + + opt = AdamW(model.parameters(), lr = lr) + + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) + ]) + + train_dataset = DataLoader( + datasets.FashionMNIST(root = "~/data/fashion_mnist", train = True, download = True, transform = transform), + batch_size = batch_size, + shuffle = True, + ) + + def iterate_dataset(data_loader): + data_iter = iter(data_loader) + while True: + try: + x, y = next(data_iter) + except StopIteration: + data_iter = iter(data_loader) + x, y = next(data_iter) + yield x.to(device), y.to(device) + + dl_iter = iterate_dataset(train_dataset) + + pbar = trange(train_iter) + + for _ in pbar: + opt.zero_grad() + x, _ = next(dl_iter) + + out, indices, cmt_loss = model(x) + out = out.clamp(-1., 1.) + + rec_loss = (out - x).abs().mean() + (rec_loss + alpha * cmt_loss).backward() + + opt.step() + + pbar.set_description( + f"rec loss: {rec_loss.item():.3f} | " + f"cmt loss: {cmt_loss.item():.3f} | " + f"active %: {indices.unique().numel() / num_codes * 100:.3f}" + ) + +if __name__ == "__main__": + fire.Fire(train) diff --git a/examples/autoencoder_lfq.py b/examples/autoencoder_lfq.py index 7ffdb7c..4c187fe 100644 --- a/examples/autoencoder_lfq.py +++ b/examples/autoencoder_lfq.py @@ -1,26 +1,37 @@ -# FashionMnist VQ experiment with various settings. -# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py +#!/usr/bin/env uv run +# /// script +# dependencies = [ +# "torch", +# "torchvision", +# "tqdm", +# "fire", +# "einops", +# "einx", +# ] +# /// from tqdm.auto import trange from math import log2 +import fire import torch import torch.nn as nn import torch.nn.functional as F from torchvision import datasets, transforms from torch.utils.data import DataLoader +from torch.optim import AdamW from vector_quantize_pytorch import LFQ, Sequential -lr = 3e-4 -train_iter = 1000 -seed = 1234 -codebook_size = 2 ** 8 -entropy_loss_weight = 0.02 -diversity_gamma = 1. -spherical = True +# helpers -device = "cuda" if torch.cuda.is_available() else "cpu" +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +# classes def LFQAutoEncoder( codebook_size, @@ -30,25 +41,55 @@ def LFQAutoEncoder( quantize_dim = int(log2(codebook_size)) return Sequential( - nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(1, 16, kernel_size = 3, stride = 1, padding = 1), + nn.MaxPool2d(kernel_size = 2, stride = 2), nn.GELU(), - nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), - # In general norm layers are commonly used in Resnet-based encoder/decoders - # explicitly add one here with affine=False to avoid introducing new parameters - nn.GroupNorm(4, 32, affine=False), - nn.Conv2d(32, quantize_dim, kernel_size=1), - LFQ(dim=quantize_dim, **vq_kwargs), - nn.Conv2d(quantize_dim, 32, kernel_size=3, stride=1, padding=1), - nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1), + nn.Conv2d(16, 32, kernel_size = 3, stride = 1, padding = 1), + nn.MaxPool2d(kernel_size = 2, stride = 2), + nn.GroupNorm(4, 32, affine = False), + nn.Conv2d(32, quantize_dim, kernel_size = 1), + LFQ(dim = quantize_dim, **vq_kwargs), + nn.Conv2d(quantize_dim, 32, kernel_size = 3, stride = 1, padding = 1), + nn.Upsample(scale_factor = 2, mode = "nearest"), + nn.Conv2d(32, 16, kernel_size = 3, stride = 1, padding = 1), nn.GELU(), - nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1), + nn.Upsample(scale_factor = 2, mode = "nearest"), + nn.Conv2d(16, 1, kernel_size = 3, stride = 1, padding = 1), + ) + +def train( + train_iter = 1000, + lr = 3e-4, + codebook_size = 256, + seed = 1234, + entropy_loss_weight = 0.02, + diversity_gamma = 1., + spherical = True, + batch_size = 256 +): + torch.random.manual_seed(seed) + device = "cuda" if torch.cuda.is_available() else "cpu" + + model = LFQAutoEncoder( + codebook_size = codebook_size, + entropy_loss_weight = entropy_loss_weight, + diversity_gamma = diversity_gamma, + spherical = spherical + ).to(device) + + opt = AdamW(model.parameters(), lr = lr) + + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) + ]) + + train_dataset = DataLoader( + datasets.FashionMNIST(root = "~/data/fashion_mnist", train = True, download = True, transform = transform), + batch_size = batch_size, + shuffle = True, ) -def train(model, train_loader, train_iterations=1000): def iterate_dataset(data_loader): data_iter = iter(data_loader) while True: @@ -59,9 +100,14 @@ def iterate_dataset(data_loader): x, y = next(data_iter) yield x.to(device), y.to(device) - for _ in (pbar := trange(train_iterations)): + dl_iter = iterate_dataset(train_dataset) + + pbar = trange(train_iter) + + for _ in pbar: opt.zero_grad() - x, _ = next(iterate_dataset(train_loader)) + x, _ = next(dl_iter) + out, indices, entropy_aux_loss = model(x) out = out.clamp(-1., 1.) @@ -69,33 +115,12 @@ def iterate_dataset(data_loader): (rec_loss + entropy_aux_loss).backward() opt.step() + pbar.set_description( - f"rec loss: {rec_loss.item():.3f} | " - + f"entropy aux loss: {entropy_aux_loss.item():.3f} | " - + f"active %: {indices.unique().numel() / codebook_size * 100:.3f}" + f"rec loss: {rec_loss.item():.3f} | " + f"entropy aux loss: {entropy_aux_loss.item():.3f} | " + f"active %: {indices.unique().numel() / codebook_size * 100:.3f}" ) -transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] -) - -train_dataset = DataLoader( - datasets.FashionMNIST( - root="~/data/fashion_mnist", train=True, download=True, transform=transform - ), - batch_size=256, - shuffle=True, -) - -torch.random.manual_seed(seed) - -model = LFQAutoEncoder( - codebook_size = codebook_size, - entropy_loss_weight = entropy_loss_weight, - diversity_gamma = diversity_gamma, - spherical = spherical -).to(device) - -opt = torch.optim.AdamW(model.parameters(), lr=lr) - -train(model, train_dataset, train_iterations=train_iter) +if __name__ == "__main__": + fire.Fire(train) diff --git a/examples/autoencoder_sim_vq.py b/examples/autoencoder_sim_vq.py index 535b96d..4c82a5d 100644 --- a/examples/autoencoder_sim_vq.py +++ b/examples/autoencoder_sim_vq.py @@ -1,5 +1,14 @@ -# FashionMnist VQ experiment with various settings. -# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py +#!/usr/bin/env uv run +# /// script +# dependencies = [ +# "torch", +# "torchvision", +# "tqdm", +# "fire", +# "einops", +# "einx", +# ] +# /// from tqdm.auto import trange @@ -7,35 +16,73 @@ import torch.nn as nn from torchvision import datasets, transforms from torch.utils.data import DataLoader +from torch.optim import AdamW +import fire from vector_quantize_pytorch import SimVQ, Sequential -lr = 3e-4 -train_iter = 10000 -num_codes = 256 -seed = 1234 +# helpers -rotation_trick = True # rotation trick instead ot straight-through -use_mlp = True # use a one layer mlp with relu instead of linear +def exists(val): + return val is not None -device = "cuda" if torch.cuda.is_available() else "cpu" +def default(val, d): + return val if exists(val) else d -def SimVQAutoEncoder(**vq_kwargs): +# classes + +def SimpleSimVQAutoEncoder(dim = 32, **vq_kwargs): return Sequential( - nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(1, 16, kernel_size = 3, stride = 1, padding = 1), + nn.MaxPool2d(kernel_size = 2, stride = 2), nn.GELU(), - nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), - nn.MaxPool2d(kernel_size=2, stride=2), - SimVQ(dim=32, channel_first = True, **vq_kwargs), - nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1), + nn.Conv2d(16, 32, kernel_size = 3, stride = 1, padding = 1), + nn.MaxPool2d(kernel_size = 2, stride = 2), + SimVQ(dim = dim, channel_first = True, **vq_kwargs), + nn.Upsample(scale_factor = 2, mode = "nearest"), + nn.Conv2d(32, 16, kernel_size = 3, stride = 1, padding = 1), nn.GELU(), - nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1), + nn.Upsample(scale_factor = 2, mode = "nearest"), + nn.Conv2d(16, 1, kernel_size = 3, stride = 1, padding = 1), + ) + +def train( + train_iter = 1000, + lr = 3e-4, + dim = 32, + num_codes = 256, + seed = 1234, + rotation_trick = True, + use_mlp = True, + batch_size = 256 +): + torch.random.manual_seed(seed) + device = "cuda" if torch.cuda.is_available() else "cpu" + + model = SimpleSimVQAutoEncoder( + dim = dim, + codebook_size = num_codes, + rotation_trick = rotation_trick, + codebook_transform = nn.Sequential( + nn.Linear(dim, dim * 4), + nn.ReLU(), + nn.Linear(dim * 4, dim), + ) if use_mlp else None + ).to(device) + + opt = AdamW(model.parameters(), lr = lr) + + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) + ]) + + train_dataset = DataLoader( + datasets.FashionMNIST(root = "~/data/fashion_mnist", train = True, download = True, transform = transform), + batch_size = batch_size, + shuffle = True, ) -def train(model, train_loader, train_iterations=1000, alpha=10): def iterate_dataset(data_loader): data_iter = iter(data_loader) while True: @@ -46,47 +93,27 @@ def iterate_dataset(data_loader): x, y = next(data_iter) yield x.to(device), y.to(device) - for _ in (pbar := trange(train_iterations)): + dl_iter = iterate_dataset(train_dataset) + + pbar = trange(train_iter) + + for _ in pbar: opt.zero_grad() - x, _ = next(iterate_dataset(train_loader)) + x, _ = next(dl_iter) - out, indices, cmt_loss = model(x) + out, indices, sim_loss = model(x) out = out.clamp(-1., 1.) rec_loss = (out - x).abs().mean() - (rec_loss + alpha * cmt_loss).backward() + (rec_loss + sim_loss).backward() opt.step() pbar.set_description( f"rec loss: {rec_loss.item():.3f} | " - + f"cmt loss: {cmt_loss.item():.3f} | " - + f"active %: {indices.unique().numel() / num_codes * 100:.3f}" + f"sim loss: {sim_loss.item():.3f} | " + f"active %: {indices.unique().numel() / num_codes * 100:.3f}" ) -transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] -) - -train_dataset = DataLoader( - datasets.FashionMNIST( - root="~/data/fashion_mnist", train=True, download=True, transform=transform - ), - batch_size=256, - shuffle=True, -) - -torch.random.manual_seed(seed) - -model = SimVQAutoEncoder( - codebook_size = num_codes, - rotation_trick = rotation_trick, - codebook_transform = nn.Sequential( - nn.Linear(32, 128), - nn.ReLU(), - nn.Linear(128, 32), - ) if use_mlp else None -).to(device) - -opt = torch.optim.AdamW(model.parameters(), lr=lr) -train(model, train_dataset, train_iterations=train_iter) +if __name__ == "__main__": + fire.Fire(train) diff --git a/pyproject.toml b/pyproject.toml index 3d819b6..93ff58f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "vector-quantize-pytorch" -version = "1.27.15" +version = "1.27.16" description = "Vector Quantization - Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } From a4fa9990a280da437ac3605303513fe4a5a045a2 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Thu, 12 Feb 2026 08:58:43 -0800 Subject: [PATCH 3/4] make sure to assert out invalid config --- pyproject.toml | 4 +- tests/test_lfq.py | 74 +++++++++---------- tests/test_readme.py | 28 +++++++ .../vector_quantize_pytorch.py | 6 +- 4 files changed, 72 insertions(+), 40 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7e56df9..b8939d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,15 +41,17 @@ build-backend = "hatchling.build" [tool.rye] managed = true + dev-dependencies = [ "ruff>=0.4.2", "pytest>=8.2.0", "pytest-cov>=5.0.0", + "x-transformers>=2.16.1", ] [tool.pytest.ini_options] pythonpath = [ - "." + "." ] [tool.hatch.metadata] diff --git a/tests/test_lfq.py b/tests/test_lfq.py index 0eb5903..32ccfc4 100644 --- a/tests/test_lfq.py +++ b/tests/test_lfq.py @@ -1,36 +1,36 @@ import torch import pytest -from vector_quantize_pytorch import LFQ import math -""" -testing_strategy: -subdivisions: using masks, using frac_per_sample_entropy < 1 -""" +from vector_quantize_pytorch import LFQ + +# helpers -torch.manual_seed(0) +def exists(val): + return val is not None + +# tests @pytest.mark.parametrize('frac_per_sample_entropy', (1., 0.5)) -@pytest.mark.parametrize('mask', (torch.tensor([False, False]), - torch.tensor([True, False]), - torch.tensor([True, True]))) +@pytest.mark.parametrize('mask', ( + torch.tensor([False, False]), + torch.tensor([True, False]), + torch.tensor([True, True]) +)) def test_masked_lfq( frac_per_sample_entropy, mask ): - # you can specify either dim or codebook_size - # if both specified, will be validated against each other - quantizer = LFQ( - codebook_size = 65536, # codebook size, must be a power of 2 - dim = 16, # this is the input feature dimension, defaults to log2(codebook_size) if not defined - entropy_loss_weight = 0.1, # how much weight to place on entropy loss - diversity_gamma = 1., # within entropy loss, how much weight to give to diversity + codebook_size = 65536, + dim = 16, + entropy_loss_weight = 0.1, + diversity_gamma = 1., frac_per_sample_entropy = frac_per_sample_entropy ) image_feats = torch.randn(2, 16, 32, 32) - ret, loss_breakdown = quantizer(image_feats, inv_temperature=100., return_loss_breakdown=True, mask=mask) # you may want to experiment with temperature + ret, _ = quantizer(image_feats, inv_temperature = 100., return_loss_breakdown = True, mask = mask) quantized, indices, _ = ret assert (quantized == quantizer.indices_to_codes(indices)).all() @@ -38,40 +38,40 @@ def test_masked_lfq( @pytest.mark.parametrize('frac_per_sample_entropy', (0.1,)) @pytest.mark.parametrize('iters', (10,)) @pytest.mark.parametrize('mask', (None, torch.tensor([True, False]))) -def test_lfq_bruteforce_frac_per_sample_entropy(frac_per_sample_entropy, iters, mask): +def test_lfq_bruteforce_frac_per_sample_entropy( + frac_per_sample_entropy, + iters, + mask +): image_feats = torch.randn(2, 16, 32, 32) full_per_sample_entropy_quantizer = LFQ( - codebook_size = 65536, # codebook size, must be a power of 2 - dim = 16, # this is the input feature dimension, defaults to log2(codebook_size) if not defined - entropy_loss_weight = 0.1, # how much weight to place on entropy loss - diversity_gamma = 1., # within entropy loss, how much weight to give to diversity + codebook_size = 65536, + dim = 16, + entropy_loss_weight = 0.1, + diversity_gamma = 1., frac_per_sample_entropy = 1 ) partial_per_sample_entropy_quantizer = LFQ( - codebook_size = 65536, # codebook size, must be a power of 2 - dim = 16, # this is the input feature dimension, defaults to log2(codebook_size) if not defined - entropy_loss_weight = 0.1, # how much weight to place on entropy loss - diversity_gamma = 1., # within entropy loss, how much weight to give to diversity + codebook_size = 65536, + dim = 16, + entropy_loss_weight = 0.1, + diversity_gamma = 1., frac_per_sample_entropy = frac_per_sample_entropy ) - ret, loss_breakdown = full_per_sample_entropy_quantizer( - image_feats, inv_temperature=100., return_loss_breakdown=True, mask=mask) + ret, loss_breakdown = full_per_sample_entropy_quantizer(image_feats, inv_temperature = 100., return_loss_breakdown = True, mask = mask) true_per_sample_entropy = loss_breakdown.per_sample_entropy per_sample_losses = torch.zeros(iters) - for iter in range(iters): - ret, loss_breakdown = partial_per_sample_entropy_quantizer( - image_feats, inv_temperature=100., return_loss_breakdown=True, mask=mask) # you may want to experiment with temperature + + for i in range(iters): + ret, loss_breakdown = partial_per_sample_entropy_quantizer(image_feats, inv_temperature = 100., return_loss_breakdown = True, mask = mask) quantized, indices, _ = ret assert (quantized == partial_per_sample_entropy_quantizer.indices_to_codes(indices)).all() - per_sample_losses[iter] = loss_breakdown.per_sample_entropy - # 95% confidence interval - assert abs(per_sample_losses.mean() - true_per_sample_entropy) \ - < (1.96*(per_sample_losses.std() / math.sqrt(iters))) + per_sample_losses[i] = loss_breakdown.per_sample_entropy - print("difference: ", abs(per_sample_losses.mean() - true_per_sample_entropy)) - print("std error:", (1.96*(per_sample_losses.std() / math.sqrt(iters)))) \ No newline at end of file + # 95% confidence interval + assert abs(per_sample_losses.mean() - true_per_sample_entropy) < (1.96 * (per_sample_losses.std() / math.sqrt(iters))) \ No newline at end of file diff --git a/tests/test_readme.py b/tests/test_readme.py index f861efe..9ff53e4 100644 --- a/tests/test_readme.py +++ b/tests/test_readme.py @@ -506,3 +506,31 @@ def test_vq_3d(): assert x.shape == quantized.shape assert indices.shape == (1, 16, 16, 16) assert torch.allclose(quantized, quantizer.get_output_from_indices(indices)) + +def test_fvq(): + from vector_quantize_pytorch import VectorQuantize + from x_transformers import ContinuousTransformerWrapper, Encoder + + vq_bridge = ContinuousTransformerWrapper( + attn_layers = Encoder( + dim = 256, + depth = 1, + heads = 4, + pre_norm_has_final_norm = False + ) + ) + + vq = VectorQuantize( + dim = 256, + codebook_size = 512, + vq_bridge = vq_bridge, + learnable_codebook = True, + ema_update = False, + in_place_codebook_optimizer = lambda params: torch.optim.SGD(params, lr = 1e-3) + ) + + x = torch.randn(1, 1024, 256) + quantized, indices, commit_loss = vq(x) + + assert quantized.shape == x.shape + assert indices.shape == (1, 1024) diff --git a/vector_quantize_pytorch/vector_quantize_pytorch.py b/vector_quantize_pytorch/vector_quantize_pytorch.py index 2baa83e..dddcdb4 100644 --- a/vector_quantize_pytorch/vector_quantize_pytorch.py +++ b/vector_quantize_pytorch/vector_quantize_pytorch.py @@ -813,8 +813,8 @@ def __init__( # defaults - ema_update = default(ema_update, not directional_reparam) - learnable_codebook = default(learnable_codebook, directional_reparam) + ema_update = default(ema_update, not directional_reparam and not exists(vq_bridge)) + learnable_codebook = default(learnable_codebook, directional_reparam or exists(vq_bridge)) rotation_trick = default(rotation_trick, not directional_reparam and dim > 1) # only use rotation trick if feature dimension greater than 1 # basic variables @@ -865,6 +865,8 @@ def __init__( assert not (straight_through and learnable_codebook), 'gumbel straight through not allowed when learning the codebook' assert not (ema_update and learnable_codebook), 'learnable codebook not compatible with EMA update' + assert not (exists(vq_bridge) and not learnable_codebook), 'learnable_codebook must be set to True if vq_bridge is passed in' + assert not (exists(vq_bridge) and ema_update), 'ema_update must be False if vq_bridge is passed in' assert 0 <= sync_update_v <= 1. assert not (sync_update_v > 0. and not learnable_codebook), 'learnable codebook must be turned on' From 8c3360e1e81a65b2fe87c108449c946097b7d865 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Thu, 12 Feb 2026 09:01:11 -0800 Subject: [PATCH 4/4] make sure tests is the optimal vq bridge --- tests/test_readme.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_readme.py b/tests/test_readme.py index 9ff53e4..4d5fe1f 100644 --- a/tests/test_readme.py +++ b/tests/test_readme.py @@ -512,6 +512,7 @@ def test_fvq(): from x_transformers import ContinuousTransformerWrapper, Encoder vq_bridge = ContinuousTransformerWrapper( + num_memory_tokens = 2, attn_layers = Encoder( dim = 256, depth = 1,