@@ -76,7 +76,8 @@ def __init__(
7676 force_quantization_f32 = True ,
7777 preserve_symmetry = False ,
7878 noise_dropout = 0. ,
79- bound_hard_clamp = False # for residual fsq, if input is pre-softclamped to the right range
79+ bound_hard_clamp = False , # for residual fsq, if input is pre-softclamped to the right range
80+ orthogonal_rotation = False # increase codebook utilization. ensure levels are symmetric! https://arxiv.org/abs/2307.13304v2
8081 ):
8182 super ().__init__ ()
8283
@@ -132,6 +133,17 @@ def __init__(
132133
133134 self .bound_hard_clamp = bound_hard_clamp
134135
136+ self .orthogonal_rotation = orthogonal_rotation
137+
138+ if orthogonal_rotation :
139+ is_symmetric = len (set (levels )) == 1
140+ if not is_symmetric :
141+ print ('orthogonal_rotation is not recommended for FSQ with asymmetric levels (i.e. where the number of bins differ across dimensions)' )
142+
143+ orthogonal_rot = torch .empty (codebook_dim , codebook_dim )
144+ nn .init .orthogonal_ (orthogonal_rot )
145+ self .register_buffer ('orthogonal_rot' , orthogonal_rot )
146+
135147 def bound (self , z , eps = 1e-3 , hard_clamp = False ):
136148 """ Bound `z`, an array of shape (..., d). """
137149 maybe_tanh = tanh if not hard_clamp else partial (clamp , min = - 1. , max = 1. )
@@ -219,6 +231,9 @@ def indices_to_codes(self, indices):
219231
220232 codes = self ._indices_to_codes (indices )
221233
234+ if self .orthogonal_rotation :
235+ codes = codes @ self .orthogonal_rot .t ()
236+
222237 if self .keep_num_codebooks_dim :
223238 codes = rearrange (codes , '... c d -> ... (c d)' )
224239
@@ -253,6 +268,9 @@ def forward(self, z):
253268
254269 z = rearrange (z , 'b n (c d) -> b n c d' , c = self .num_codebooks )
255270
271+ if self .orthogonal_rotation :
272+ z = z @ self .orthogonal_rot
273+
256274 # whether to force quantization step to be full precision or not
257275
258276 force_f32 = self .force_quantization_f32
@@ -275,6 +293,9 @@ def forward(self, z):
275293
276294 codes = self .maybe_apply_noise (codes )
277295
296+ if self .orthogonal_rotation :
297+ codes = codes @ self .orthogonal_rot .t ()
298+
278299 codes = rearrange (codes , 'b n c d -> b n (c d)' )
279300
280301 codes = codes .to (orig_dtype )
0 commit comments