@@ -93,6 +93,7 @@ def __init__(
9393
9494 self .scale = scale
9595
96+ assert not (noise_dropout > 0 and not preserve_symmetry )
9697 self .preserve_symmetry = preserve_symmetry
9798 self .noise_dropout = noise_dropout
9899
@@ -158,22 +159,26 @@ def symmetry_preserving_bound(self, z, hard_clamp = False):
158159 def quantize (self , z ):
159160 """ Quantizes z, returns quantized zhat, same shape as z. """
160161
161- shape , device , noise_dropout , preserve_symmetry = z .shape [0 ], z .device , self . noise_dropout , self .preserve_symmetry
162+ shape , device , preserve_symmetry = z .shape [0 ], z .device , self .preserve_symmetry
162163 bound_fn = self .symmetry_preserving_bound if preserve_symmetry else self .bound
163164
164- bounded_z = bound_fn (z , hard_clamp = self .bound_hard_clamp )
165+ return bound_fn (z , hard_clamp = self .bound_hard_clamp )
165166
166- # determine where to add a random offset elementwise
167- # if using noise dropout
167+ def maybe_apply_noise ( self , bounded_z ):
168+ noise_dropout = self . noise_dropout
168169
169170 if not self .training or noise_dropout == 0. :
170171 return bounded_z
171172
172- offset_mask = torch .bernoulli (torch .full_like (bounded_z , noise_dropout )).bool ()
173+ # determine where to add a random offset elementwise
174+ # if using noise dropout
175+
176+ offset_mask = torch .full_like (bounded_z , noise_dropout ).bernoulli_ ().bool ()
173177 offset = torch .rand_like (bounded_z ) - 0.5
178+
174179 bounded_z = torch .where (offset_mask , bounded_z + offset , bounded_z )
175180
176- return bounded_z
181+ return bounded_z . clamp ( - 1. , 1. )
177182
178183 def _scale_and_shift (self , zhat_normalized ):
179184 if self .preserve_symmetry :
@@ -268,6 +273,8 @@ def forward(self, z):
268273 if self .return_indices :
269274 indices = self .codes_to_indices (codes )
270275
276+ codes = self .maybe_apply_noise (codes )
277+
271278 codes = rearrange (codes , 'b n c d -> b n (c d)' )
272279
273280 codes = codes .to (orig_dtype )
0 commit comments