
Commit f4ff5d1

make binary mapper policy optimizable
1 parent d56cbed

11 files changed

Lines changed: 81 additions & 33 deletions
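
The substantive change is in `vector_quantize_pytorch/binary_mapper.py`: the auxiliary KL computation moves out of `forward` into a reusable `calc_aux_loss` method, and a new `log_prob` method returns the joint log-probability of sampled codes under the factorized Bernoulli distribution that the logits define. That log-probability is exactly what a score-function (REINFORCE) estimator needs, hence "policy optimizable". The other files shown carry only a version bump and whitespace cleanup. Below is a minimal sketch of the intended pattern; the import path, the `bits` constructor argument, and the reward are assumptions, while the `log_prob` / `calc_aux_loss` calls and shapes come from the test code in the diff:

```python
import torch
from vector_quantize_pytorch.binary_mapper import BinaryMapper  # assumed import path

binary_mapper = BinaryMapper(bits = 8)  # assumed constructor; `bits` is implied by `self.bits`

# any number of leading dimensions, last dimension = bits
logits = torch.randn(3, 4, 8, requires_grad = True)

# forward appears to return (one_hot, aux_kl_loss); only the sampled codes are needed here
one_hot, _ = binary_mapper(logits)

# joint log-probability of the sampled codes, summed over bits -> shape (3, 4)
log_prob = binary_mapper.log_prob(logits, one_hot = one_hot)

# hypothetical per-sample reward from some downstream objective
reward = torch.randn(3, 4)

# score-function (REINFORCE) estimator: the discrete sample is not differentiable,
# so the policy is improved through the log-probability of the actions taken
policy_loss = -(reward * log_prob).mean()

# new method: reduced hinge-KL penalty keeping the per-bit Bernoullis near uniform
aux_loss = binary_mapper.calc_aux_loss(logits)

(policy_loss + aux_loss).backward()
```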

README.md

Lines changed: 10 additions & 10 deletions
@@ -353,7 +353,7 @@ xhat, indices = quantizer(x)
 assert torch.all(xhat == quantizer.indices_to_codes(indices))
 ```
 
-An improvised Residual FSQ, for an attempt to improve audio encoding.
+An improvised Residual FSQ, for an attempt to improve audio encoding.
 
 Credit goes to [@sekstini](https://github.com/sekstini) for originally incepting the idea [here](https://github.com/lucidrains/vector-quantize-pytorch/pull/74#issuecomment-1742048597)
 
@@ -506,7 +506,7 @@ from vector_quantize_pytorch import LatentQuantize
 quantizer = LatentQuantize(
     levels = [5, 5, 8], # number of levels per codebook dimension
     dim = 16,           # input dim
-    commitment_loss_weight=0.1,
+    commitment_loss_weight=0.1,
     quantization_loss_weight=0.1,
 )
 
@@ -530,7 +530,7 @@ from vector_quantize_pytorch import LatentQuantize
 quantizer = LatentQuantize(
     levels = [5, 5, 8],
     dim = 16,
-    commitment_loss_weight=0.1,
+    commitment_loss_weight=0.1,
     quantization_loss_weight=0.1,
 )
 
@@ -720,7 +720,7 @@ assert loss.item() >= 0
 
 ```bibtex
 @misc{hsu2023disentanglement,
-    title = {Disentanglement via Latent Quantization},
+    title = {Disentanglement via Latent Quantization},
     author = {Kyle Hsu and Will Dorrell and James C. R. Whittington and Jiajun Wu and Chelsea Finn},
     year = {2023},
     eprint = {2305.18378},
@@ -782,36 +782,36 @@ assert loss.item() >= 0
 
 ```bibtex
 @misc{vali2025diveqdifferentiablevectorquantization,
-    title = {DiVeQ: Differentiable Vector Quantization Using the Reparameterization Trick},
+    title = {DiVeQ: Differentiable Vector Quantization Using the Reparameterization Trick},
     author = {Mohammad Hassan Vali and Tom Bäckström and Arno Solin},
     year = {2025},
     eprint = {2509.26469},
     archivePrefix = {arXiv},
     primaryClass = {cs.LG},
-    url = {https://arxiv.org/abs/2509.26469},
+    url = {https://arxiv.org/abs/2509.26469},
 }
 ```
 
 ```bibtex
 @misc{fleuret2025freetransformer,
-    title = {The Free Transformer},
+    title = {The Free Transformer},
     author = {François Fleuret},
     year = {2025},
     eprint = {2510.17558},
     archivePrefix = {arXiv},
     primaryClass = {cs.LG},
-    url = {https://arxiv.org/abs/2510.17558},
+    url = {https://arxiv.org/abs/2510.17558},
 }
 ```
 
 ```bibtex
 @misc{chang2025scalabletrainingvectorquantizednetworks,
-    title = {Scalable Training for Vector-Quantized Networks with 100% Codebook Utilization},
+    title = {Scalable Training for Vector-Quantized Networks with 100% Codebook Utilization},
     author = {Yifan Chang and Jie Qin and Limeng Qiao and Xiaofeng Wang and Zheng Zhu and Lin Ma and Xingang Wang},
     year = {2025},
     eprint = {2509.10140},
     archivePrefix = {arXiv},
     primaryClass = {cs.CV},
-    url = {https://arxiv.org/abs/2509.10140},
+    url = {https://arxiv.org/abs/2509.10140},
 }
 ```

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.27.21"
+version = "1.28.0"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

ruff.toml

Lines changed: 0 additions & 1 deletion
@@ -21,4 +21,3 @@ convention = "numpy"
 [format]
 docstring-code-format = true
 docstring-code-line-length = 20
-

tests/test_beam.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ def test_topk_and_manual_ema_update():
         dim = 256,
         codebook_size = 512
     )
-
+
     vq2.load_state_dict(vq1.state_dict())
 
     x = torch.randn(1, 1024, 256)

tests/test_lfq.py

Lines changed: 1 addition & 1 deletion
@@ -74,4 +74,4 @@ def test_lfq_bruteforce_frac_per_sample_entropy(
         per_sample_losses[i] = loss_breakdown.per_sample_entropy
 
     # 95% confidence interval
-    assert abs(per_sample_losses.mean() - true_per_sample_entropy) < (1.96 * (per_sample_losses.std() / math.sqrt(iters)))
+    assert abs(per_sample_losses.mean() - true_per_sample_entropy) < (1.96 * (per_sample_losses.std() / math.sqrt(iters)))
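
The assertion above is a standard large-sample check: the Monte-Carlo estimates collected in `per_sample_losses` should agree with the brute-force entropy up to a 95% normal confidence interval,

$$
\left|\bar{x} - H_{\text{true}}\right| < 1.96\,\frac{s}{\sqrt{n}},
$$

with $\bar{x}$ and $s$ the mean and standard deviation of the estimates and $n$ = `iters`. By construction, even a correct implementation fails such a check roughly 5% of the time, so occasional flakes in this test are expected.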

vector_quantize_pytorch/binary_mapper.py

Lines changed: 61 additions & 12 deletions
@@ -69,6 +69,58 @@ def __init__(
 
         self.deterministic_on_eval = deterministic_on_eval
 
+    def binary_entropy(self, logits):
+        return binary_entropy(logits)
+
+    def calc_aux_loss(
+        self,
+        logits,
+        reduce_aux_kl_loss = True
+    ):
+        logits, inverse_pack_lead_dims = pack_with_inverse(logits, '* bits')
+        kl_div = self.bits * NAT - self.binary_entropy(logits)
+        aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold)
+
+        if reduce_aux_kl_loss:
+            return aux_kl_loss.mean()
+
+        return inverse_pack_lead_dims(aux_kl_loss, '*')
+
+    def log_prob(
+        self,
+        logits,
+        *,
+        indices = None,
+        one_hot = None,
+        sum_bits = True
+    ):
+        assert exists(indices) ^ exists(one_hot), 'either indices or one_hot must be provided'
+
+        if exists(one_hot):
+            indices = one_hot.argmax(dim=-1)
+
+        # allow for any number of leading dimensions
+
+        logits, inverse_pack_lead_dims = pack_with_inverse(logits, '* bits')
+        indices, _ = pack_with_inverse(indices, '*')
+
+        # sampled bits representation
+
+        sampled_bits = self.codes[indices]
+
+        # calculate log probability
+
+        log_probs_1 = F.logsigmoid(logits)
+        log_probs_0 = F.logsigmoid(-logits)
+
+        log_probs = torch.where(sampled_bits, log_probs_1, log_probs_0)
+
+        if not sum_bits:
+            return inverse_pack_lead_dims(log_probs)
+
+        log_probs = log_probs.sum(dim = -1)
+        return inverse_pack_lead_dims(log_probs, '*')
+
     def forward(
         self,
         logits,
@@ -86,6 +138,7 @@ def forward(
 
         assert logits.shape[-1] == self.bits, f'logits must have a last dimension of {self.bits}'
 
+        orig_logits = logits
         # allow for any number of leading dimensions
 
         logits, inverse_pack_lead_dims = pack_with_inverse(logits, '* bits')
@@ -110,17 +163,7 @@
         aux_kl_loss = self.zero
 
         if calc_aux_loss:
-            # calculate negative entropy
-
-            kl_div = self.bits * NAT - binary_entropy(logits)
-            aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold)
-
-            # able to return unreduced kl loss, for use in another project (metacontroller)
-
-            if reduce_aux_kl_loss:
-                aux_kl_loss = aux_kl_loss.mean()
-            else:
-                aux_kl_loss = inverse_pack_lead_dims(aux_kl_loss, '*')
+            aux_kl_loss = self.calc_aux_loss(orig_logits, reduce_aux_kl_loss = reduce_aux_kl_loss)
 
         # maybe straight through
 
@@ -164,7 +207,13 @@
     assert indices.shape == (3, 4)
     assert aux_loss.shape == (3, 4)
 
+    joint_log_prob = binary_mapper.log_prob(logits, indices = indices)
+    assert joint_log_prob.shape == (3, 4)
+
+    joint_log_prob_one_hot = binary_mapper.log_prob(logits, one_hot = sparse_one_hot)
+    assert torch.allclose(joint_log_prob, joint_log_prob_one_hot)
+
     binary_mapper.eval()
     sparse_one_hot1, _ = binary_mapper(logits, deterministic = True)
     sparse_one_hot2, _ = binary_mapper(logits, deterministic = True)
-    assert torch.allclose(sparse_one_hot1, sparse_one_hot2)
+    assert torch.allclose(sparse_one_hot1, sparse_one_hot2)
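
For reference, the quantity that `calc_aux_loss` computes (previously inline in `forward`) is the KL divergence between the factorized Bernoulli distribution $q$ defined by the logits and the uniform distribution over all $2^B$ codes, assuming `NAT` is $\ln 2$ and `binary_entropy` returns the summed Bernoulli entropies in nats:

$$
\mathrm{KL}\left(q \,\Vert\, \mathcal{U}\right) = \sum_{c} q(c)\,\ln\frac{q(c)}{2^{-B}} = B \ln 2 - H(q), \qquad H(q) = \sum_{i=1}^{B} H_b\left(\sigma(\ell_i)\right),
$$

where $H_b$ is the binary entropy function and $\ell_i$ the logit of the $i$-th bit. The `F.relu(kl_div - self.kl_loss_threshold)` hinge then acts as a free-bits-style allowance: the divergence is only penalized once it exceeds `kl_loss_threshold`, discouraging the sampling distribution from collapsing while still letting it carry information.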

vector_quantize_pytorch/finite_scalar_quantization.py

Lines changed: 2 additions & 2 deletions
@@ -145,7 +145,7 @@ def bound(self, z, eps = 1e-3, hard_clamp = False):
         return round_ste(bounded_z) / half_width
 
     # symmetry-preserving and noise-approximated quantization, section 3.2 in https://arxiv.org/abs/2411.19842
-
+
     def symmetry_preserving_bound(self, z, hard_clamp = False):
         """ QL(x) = 2 / (L - 1) * [(L - 1) * (tanh(x) + 1) / 2 + 0.5] - 1 """
         maybe_tanh = tanh if not hard_clamp else partial(clamp, min = -1., max = 1.)
@@ -186,7 +186,7 @@ def _scale_and_shift(self, zhat_normalized):
 
         half_width = self._levels // 2
         return (zhat_normalized * half_width) + half_width
-
+
     def _scale_and_shift_inverse(self, zhat):
         if self.preserve_symmetry:
             return zhat * (2. / (self._levels - 1)) - 1.
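
Spelling out the docstring formula in the context above: reading the square bracket as a floor (so that adding $0.5$ implements round-to-nearest, presumably realized with a straight-through `round_ste` as in `bound`), the symmetry-preserving quantizer is

$$
Q_L(x) = \frac{2}{L-1}\left\lfloor \frac{(L-1)\left(\tanh(x)+1\right)}{2} + \frac{1}{2} \right\rfloor - 1,
$$

which maps $\tanh(x)\in(-1,1)$ onto $L$ levels spaced symmetrically about $0$ in $[-1,1]$; see section 3.2 of [arXiv:2411.19842](https://arxiv.org/abs/2411.19842).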

vector_quantize_pytorch/latent_quantization.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ def __init__(
                 (default is 1)
             codebook_dim (int): the dimension of the codebook.
                 If levels is a list, codebook_dim is the length of the list.
-                (default to -1)
+                (default to -1)
             keep_num_codebooks_dim (Optional[bool]): Whether to keep the number of codebooks dimension in the output tensor. If not provided, it is set to True if num_codebooks > 1, otherwise False.
             optimize_values (Optional[bool]): Whether to optimize the values of the codebook. If not provided, it is set to True.
         """

vector_quantize_pytorch/residual_sim_vq.py

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ def __init__(
     @property
     def codebook_size(self):
         return first(self.layers).codebook_size
-
+
     @property
     def codebook_dim(self):
         return first(self.layers).codebook_dim

vector_quantize_pytorch/residual_vq.py

Lines changed: 2 additions & 2 deletions
@@ -263,7 +263,7 @@ def __init__(
         self.register_buffer('beam_score_weights', tensor(beam_score_quantizer_weights), persistent = False)
 
         # setting up the MLPs for implicit neural codebooks
-
+
         self.mlps = None
 
         if implicit_neural_codebook:
@@ -285,7 +285,7 @@ def __init__(
 
         for vq in rest_vq:
             vq._codebook = codebook
-
+
     @property
     def codebook_size(self):
         return self.layers[0].codebook_size
