Skip to content

Commit fb78d77

Browse files
committed
move application of noise when noise dropout is on for fsq
1 parent 008c07b commit fb78d77

2 files changed

Lines changed: 10 additions & 5 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.27.16"
3+
version = "1.27.19"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/finite_scalar_quantization.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,17 +159,20 @@ def symmetry_preserving_bound(self, z, hard_clamp = False):
159159
def quantize(self, z):
160160
""" Quantizes z, returns quantized zhat, same shape as z. """
161161

162-
shape, device, noise_dropout, preserve_symmetry = z.shape[0], z.device, self.noise_dropout, self.preserve_symmetry
162+
shape, device, preserve_symmetry = z.shape[0], z.device, self.preserve_symmetry
163163
bound_fn = self.symmetry_preserving_bound if preserve_symmetry else self.bound
164164

165-
bounded_z = bound_fn(z, hard_clamp = self.bound_hard_clamp)
165+
return bound_fn(z, hard_clamp = self.bound_hard_clamp)
166166

167-
# determine where to add a random offset elementwise
168-
# if using noise dropout
167+
def maybe_apply_noise(self, bounded_z):
168+
noise_dropout = self.noise_dropout
169169

170170
if not self.training or noise_dropout == 0.:
171171
return bounded_z
172172

173+
# determine where to add a random offset elementwise
174+
# if using noise dropout
175+
173176
offset_mask = torch.full_like(bounded_z, noise_dropout).bernoulli_().bool()
174177
offset = torch.rand_like(bounded_z) - 0.5
175178

@@ -270,6 +273,8 @@ def forward(self, z):
270273
if self.return_indices:
271274
indices = self.codes_to_indices(codes)
272275

276+
codes = self.maybe_apply_noise(codes)
277+
273278
codes = rearrange(codes, 'b n c d -> b n (c d)')
274279

275280
codes = codes.to(orig_dtype)

0 commit comments

Comments
 (0)