We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 949b0ba commit 3bb00f5Copy full SHA for 3bb00f5
1 file changed
examples/autoencoder_sim_vq.py
@@ -14,7 +14,10 @@
14
train_iter = 10000
15
num_codes = 256
16
seed = 1234
17
-rotation_trick = True
+
18
+rotation_trick = True # rotation trick instead ot straight-through
19
+use_mlp = True # use a one layer mlp with relu instead of linear
20
21
device = "cuda" if torch.cuda.is_available() else "cpu"
22
23
def SimVQAutoEncoder(**vq_kwargs):
@@ -77,7 +80,12 @@ def iterate_dataset(data_loader):
77
80
78
81
model = SimVQAutoEncoder(
79
82
codebook_size = num_codes,
- rotation_trick = rotation_trick
83
+ rotation_trick = rotation_trick,
84
+ codebook_transform = nn.Sequential(
85
+ nn.Linear(32, 128),
86
+ nn.ReLU(),
87
+ nn.Linear(128, 32),
88
+ ) if use_mlp else None
89
).to(device)
90
91
opt = torch.optim.AdamW(model.parameters(), lr=lr)
0 commit comments