
Commit cbf4454

Merge branch 'master' into test-fvq
2 parents: 0db73ad + 7a4e13c

3 files changed: 25 additions & 10 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.27.16"
+version = "1.27.21"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
```

vector_quantize_pytorch/binary_mapper.py

Lines changed: 11 additions & 3 deletions
```diff
@@ -77,6 +77,7 @@ def forward(
         calc_aux_loss = None,
         deterministic = None,
         return_indices = False,
+        reduce_aux_kl_loss = True
     ):
         deterministic = default(deterministic, self.deterministic_on_eval and not self.training)
@@ -112,7 +113,14 @@ def forward(
         # calculate negative entropy

         kl_div = self.bits * NAT - binary_entropy(logits)
-        aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold).mean()
+        aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold)
+
+        # able to return unreduced kl loss, for use in another project (metacontroller)
+
+        if reduce_aux_kl_loss:
+            aux_kl_loss = aux_kl_loss.mean()
+        else:
+            aux_kl_loss = inverse_pack_lead_dims(aux_kl_loss, '*')

         # maybe straight through
@@ -150,11 +158,11 @@

     logits = torch.randn(3, 4, 8)

-    sparse_one_hot, indices, aux_loss = binary_mapper(logits, return_indices = True)
+    sparse_one_hot, indices, aux_loss = binary_mapper(logits, return_indices = True, reduce_aux_kl_loss = False)

     assert sparse_one_hot.shape == (3, 4, 2 ** 8)
     assert indices.shape == (3, 4)
-    assert aux_loss.numel() == 1
+    assert aux_loss.shape == (3, 4)

     binary_mapper.eval()
     sparse_one_hot1, _ = binary_mapper(logits, deterministic = True)
```
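
The unreduced path keeps one loss value per position instead of collapsing to a scalar, which the in-diff comment attributes to an external project (metacontroller). A minimal sketch of the idea, assuming the `bits = 8` mapper from the test stanza above; `NAT`, `binary_entropy`, and the zero threshold default are reimplemented here for illustration and are not the library's exact internals:

```python
import torch
import torch.nn.functional as F

NAT = 0.6931471805599453  # ln 2: one bit carries ln 2 nats of entropy

def binary_entropy(logits):
    # Bernoulli entropy per bit via BCE of the logits against their own
    # probabilities, summed over the trailing bit dimension
    probs = logits.sigmoid()
    per_bit = F.binary_cross_entropy_with_logits(logits, probs, reduction = 'none')
    return per_bit.sum(dim = -1)

def aux_kl_loss(logits, bits, kl_loss_threshold = 0., reduce = True):
    # KL from the maximum-entropy code (all bits uniform), hinged at the threshold
    kl_div = bits * NAT - binary_entropy(logits)
    loss = F.relu(kl_div - kl_loss_threshold)
    return loss.mean() if reduce else loss

logits = torch.randn(3, 4, 8)

assert aux_kl_loss(logits, bits = 8).numel() == 1                      # reduced, prior behavior
assert aux_kl_loss(logits, bits = 8, reduce = False).shape == (3, 4)   # new unreduced path
```

With `reduce_aux_kl_loss = True` (the default) this collapses to the old scalar loss, so existing callers are unaffected.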

vector_quantize_pytorch/finite_scalar_quantization.py

Lines changed: 13 additions & 6 deletions
```diff
@@ -93,6 +93,7 @@ def __init__(

         self.scale = scale

+        assert not (noise_dropout > 0 and not preserve_symmetry)
         self.preserve_symmetry = preserve_symmetry
         self.noise_dropout = noise_dropout
@@ -158,22 +159,26 @@ def symmetry_preserving_bound(self, z, hard_clamp = False):
     def quantize(self, z):
         """ Quantizes z, returns quantized zhat, same shape as z. """

-        shape, device, noise_dropout, preserve_symmetry = z.shape[0], z.device, self.noise_dropout, self.preserve_symmetry
+        shape, device, preserve_symmetry = z.shape[0], z.device, self.preserve_symmetry
         bound_fn = self.symmetry_preserving_bound if preserve_symmetry else self.bound

-        bounded_z = bound_fn(z, hard_clamp = self.bound_hard_clamp)
+        return bound_fn(z, hard_clamp = self.bound_hard_clamp)

-        # determine where to add a random offset elementwise
-        # if using noise dropout
+    def maybe_apply_noise(self, bounded_z):
+        noise_dropout = self.noise_dropout

         if not self.training or noise_dropout == 0.:
             return bounded_z

-        offset_mask = torch.bernoulli(torch.full_like(bounded_z, noise_dropout)).bool()
+        # determine where to add a random offset elementwise
+        # if using noise dropout
+
+        offset_mask = torch.full_like(bounded_z, noise_dropout).bernoulli_().bool()
         offset = torch.rand_like(bounded_z) - 0.5
+
         bounded_z = torch.where(offset_mask, bounded_z + offset, bounded_z)

-        return bounded_z
+        return bounded_z.clamp(-1., 1.)

     def _scale_and_shift(self, zhat_normalized):
         if self.preserve_symmetry:
@@ -268,6 +273,8 @@ def forward(self, z):
         if self.return_indices:
             indices = self.codes_to_indices(codes)

+        codes = self.maybe_apply_noise(codes)
+
         codes = rearrange(codes, 'b n c d -> b n (c d)')

         codes = codes.to(orig_dtype)
```
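
Pulled out of `quantize`, the noise step now runs in `forward` only after indices are computed, so indices always reflect the clean bounded codes, and the new final clamp keeps perturbed values inside the `[-1, 1]` range that the added constructor assert (noise dropout now requires `preserve_symmetry`) guarantees. A minimal standalone sketch of the relocated step, with `noise_dropout` and `training` passed explicitly for illustration rather than read off the module:

```python
import torch

def maybe_apply_noise(bounded_z, noise_dropout = 0.5, training = True):
    # no-op at eval time or when noise dropout is disabled
    if not training or noise_dropout == 0.:
        return bounded_z

    # choose which elements receive a random offset, elementwise
    offset_mask = torch.full_like(bounded_z, noise_dropout).bernoulli_().bool()

    # uniform offset in [-0.5, 0.5), applied only where the mask is set
    offset = torch.rand_like(bounded_z) - 0.5
    bounded_z = torch.where(offset_mask, bounded_z + offset, bounded_z)

    # clamp back into the symmetric code range
    return bounded_z.clamp(-1., 1.)

codes = torch.empty(3, 4, 1, 8).uniform_(-1., 1.)
noised = maybe_apply_noise(codes)
assert noised.shape == codes.shape and noised.abs().max().item() <= 1.
```

The in-place `bernoulli_()` on the probability-filled tensor is behaviorally identical to the old `torch.bernoulli(...)` call; only the allocation pattern changes.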
