diff --git a/examples/autoencoder_fvq.py b/examples/autoencoder_fvq.py index c0c4f7c..6f53a0e 100644 --- a/examples/autoencoder_fvq.py +++ b/examples/autoencoder_fvq.py @@ -141,7 +141,7 @@ def train( model = SimpleVQAutoEncoder( dim = dim, vq_bridge = vq_bridge, - rotation_trick = rotation_trick + rotation_trick = rotation_trick, codebook_size = num_codes, learnable_codebook = True, in_place_codebook_optimizer = lambda *args, **kwargs: SGD(*args, **kwargs, lr = 1e-3),