@@ -159,17 +159,20 @@ def symmetry_preserving_bound(self, z, hard_clamp = False):
159159 def quantize (self , z ):
160160 """ Quantizes z, returns quantized zhat, same shape as z. """
161161
162- shape , device , noise_dropout , preserve_symmetry = z .shape [0 ], z .device , self . noise_dropout , self .preserve_symmetry
162+ shape , device , preserve_symmetry = z .shape [0 ], z .device , self .preserve_symmetry
163163 bound_fn = self .symmetry_preserving_bound if preserve_symmetry else self .bound
164164
165- bounded_z = bound_fn (z , hard_clamp = self .bound_hard_clamp )
165+ return bound_fn (z , hard_clamp = self .bound_hard_clamp )
166166
167- # determine where to add a random offset elementwise
168- # if using noise dropout
167+ def maybe_apply_noise ( self , bounded_z ):
168+ noise_dropout = self . noise_dropout
169169
170170 if not self .training or noise_dropout == 0. :
171171 return bounded_z
172172
173+ # determine where to add a random offset elementwise
174+ # if using noise dropout
175+
173176 offset_mask = torch .full_like (bounded_z , noise_dropout ).bernoulli_ ().bool ()
174177 offset = torch .rand_like (bounded_z ) - 0.5
175178
@@ -270,6 +273,8 @@ def forward(self, z):
270273 if self .return_indices :
271274 indices = self .codes_to_indices (codes )
272275
276+ codes = self .maybe_apply_noise (codes )
277+
273278 codes = rearrange (codes , 'b n c d -> b n (c d)' )
274279
275280 codes = codes .to (orig_dtype )
0 commit comments