1- # FashionMnist VQ experiment with various settings.
2- # From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py
1+ #!/usr/bin/env uv run
2+ # /// script
3+ # dependencies = [
4+ # "torch",
5+ # "torchvision",
6+ # "tqdm",
7+ # "fire",
8+ # "einops",
9+ # "einx",
10+ # ]
11+ # ///
312
413from tqdm .auto import trange
514
15+ import fire
616import torch
717import torch .nn as nn
818from torchvision import datasets , transforms
919from torch .utils .data import DataLoader
20+ from torch .optim import AdamW
1021
1122from vector_quantize_pytorch import VectorQuantize , Sequential
1223
13- lr = 3e-4
14- train_iter = 1000
15- num_codes = 256
16- seed = 1234
17- rotation_trick = True
18- device = "cuda" if torch .cuda .is_available () else "cpu"
24+ # helpers
1925
20- def SimpleVQAutoEncoder (** vq_kwargs ):
def exists(val):
    # A value "exists" when the caller actually supplied one (i.e. not None).
    return val is not None

def default(val, d):
    # Return `val` when it was supplied, otherwise fall back to `d`.
    return d if val is None else val
31+
32+ # classes
33+
def SimpleVQAutoEncoder(dim = 32, **vq_kwargs):
    """Build a small conv autoencoder with a VectorQuantize bottleneck.

    Encoder: two conv + maxpool stages (1 -> 16 -> `dim` channels, 28x28 -> 7x7
    for FashionMNIST). Decoder mirrors it with nearest-neighbor upsampling.

    Args:
        dim: channel width of the latent feature map AND the codebook dimension.
            The original code hard-coded the second conv to 32 channels, so any
            `dim != 32` crashed with a channel mismatch inside VectorQuantize;
            `dim` is now threaded through the adjacent convs as well.
        **vq_kwargs: forwarded to VectorQuantize (codebook_size, rotation_trick, ...).

    Returns:
        A Sequential model whose forward returns (out, indices, commit_loss),
        as produced by the VectorQuantize layer in the middle.
    """
    return Sequential(
        nn.Conv2d(1, 16, kernel_size = 3, stride = 1, padding = 1),
        nn.MaxPool2d(kernel_size = 2, stride = 2),
        nn.GELU(),
        nn.Conv2d(16, dim, kernel_size = 3, stride = 1, padding = 1),
        nn.MaxPool2d(kernel_size = 2, stride = 2),
        VectorQuantize(dim = dim, accept_image_fmap = True, **vq_kwargs),
        nn.Upsample(scale_factor = 2, mode = "nearest"),
        nn.Conv2d(dim, 16, kernel_size = 3, stride = 1, padding = 1),
        nn.GELU(),
        nn.Upsample(scale_factor = 2, mode = "nearest"),
        nn.Conv2d(16, 1, kernel_size = 3, stride = 1, padding = 1),
    )
48+
def train(
    train_iter = 1000,
    lr = 3e-4,
    dim = 32,
    num_codes = 256,
    seed = 1234,
    rotation_trick = True,
    straight_through = False,
    directional_reparam = False,
    alpha = 10,
    batch_size = 256
):
    """Train the VQ autoencoder on FashionMNIST and log losses to a tqdm bar.

    Args:
        train_iter: number of optimizer steps (batches), not epochs.
        lr: AdamW learning rate.
        dim: latent / codebook dimension passed to SimpleVQAutoEncoder.
        num_codes: codebook size of the vector quantizer.
        seed: torch RNG seed for reproducibility.
        rotation_trick, straight_through, directional_reparam: gradient
            estimator flags forwarded to VectorQuantize.
        alpha: weight of the commitment loss relative to reconstruction loss.
        batch_size: DataLoader batch size.
    """
    torch.random.manual_seed(seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = SimpleVQAutoEncoder(
        dim = dim,
        codebook_size = num_codes,
        rotation_trick = rotation_trick,
        straight_through = straight_through,
        directional_reparam = directional_reparam
    ).to(device)

    opt = AdamW(model.parameters(), lr = lr)

    # Rescale pixels to [-1, 1] so they match the decoder output's clamp range.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    train_dataset = DataLoader(
        datasets.FashionMNIST(root = "~/data/fashion_mnist", train = True, download = True, transform = transform),
        batch_size = batch_size,
        shuffle = True,
    )

    def iterate_dataset(data_loader):
        # Yield device-resident batches forever, restarting the loader when an
        # epoch is exhausted, so the training loop can be a flat step counter.
        data_iter = iter(data_loader)
        while True:
            # NOTE(review): the StopIteration-reset below was elided by a diff
            # hunk header in the scraped source; this is the standard upstream
            # form of the helper — confirm against the original file.
            try:
                x, y = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x, y = next(data_iter)
            yield x.to(device), y.to(device)

    dl_iter = iterate_dataset(train_dataset)

    pbar = trange(train_iter)

    for _ in pbar:
        opt.zero_grad()
        x, _ = next(dl_iter)

        out, indices, cmt_loss = model(x)
        out = out.clamp(-1., 1.)

        # NOTE(review): the rec_loss definition was elided by a diff hunk header
        # in the scraped source; L1 reconstruction loss matches the surrounding
        # upstream example — confirm against the original file.
        rec_loss = (out - x).abs().mean()
        (rec_loss + alpha * cmt_loss).backward()

        opt.step()

        pbar.set_description(
            f"rec loss: {rec_loss.item():.3f} | "
            f"cmt loss: {cmt_loss.item():.3f} | "
            f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
        )
62116
# NOTE(review): the lines below are the removed ("-") old-side half of the
# scraped diff — the pre-refactor module-level script that the train() function
# above replaces. They are kept verbatim only as diff residue, not live code.
63- transform = transforms .Compose (
64- [transforms .ToTensor (), transforms .Normalize ((0.5 ,), (0.5 ,))]
65- )
66- train_dataset = DataLoader (
67- datasets .FashionMNIST (
68- root = "~/data/fashion_mnist" , train = True , download = True , transform = transform
69- ),
70- batch_size = 256 ,
71- shuffle = True ,
72- )
73-
74- torch .random .manual_seed (seed )
75-
76- model = SimpleVQAutoEncoder (
77- codebook_size = num_codes ,
78- rotation_trick = False ,
79- straight_through = False ,
80- directional_reparam = True
81- ).to (device )
82-
83- opt = torch .optim .AdamW (model .parameters (), lr = lr )
84- train (model , train_dataset , train_iterations = train_iter )
117+ if __name__ == "__main__" :
118+ fire .Fire (train )
0 commit comments