Skip to content

Commit d8f4eb2

Browse files
committed
address #240
1 parent 8a25e16 commit d8f4eb2

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

vector_quantize_pytorch/finite_scalar_quantization.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def __init__(
9393

9494
self.scale = scale
9595

96+
assert not (noise_dropout > 0 and not preserve_symmetry)
9697
self.preserve_symmetry = preserve_symmetry
9798
self.noise_dropout = noise_dropout
9899

@@ -169,11 +170,12 @@ def quantize(self, z):
169170
if not self.training or noise_dropout == 0.:
170171
return bounded_z
171172

172-
offset_mask = torch.bernoulli(torch.full_like(bounded_z, noise_dropout)).bool()
173+
offset_mask = torch.full_like(bounded_z, noise_dropout).bernoulli_().bool()
173174
offset = torch.rand_like(bounded_z) - 0.5
175+
174176
bounded_z = torch.where(offset_mask, bounded_z + offset, bounded_z)
175177

176-
return bounded_z
178+
return bounded_z.clamp(-1., 1.)
177179

178180
def _scale_and_shift(self, zhat_normalized):
179181
if self.preserve_symmetry:

0 commit comments

Comments
 (0)