
Commit a4fa999

make sure to assert out invalid config

1 parent cbf4454 · commit a4fa999

4 files changed
Lines changed: 72 additions & 40 deletions

File tree

pyproject.toml
tests/test_lfq.py
tests/test_readme.py
vector_quantize_pytorch/vector_quantize_pytorch.py

pyproject.toml

Lines changed: 3 additions & 1 deletion
@@ -41,15 +41,17 @@ build-backend = "hatchling.build"
 
 [tool.rye]
 managed = true
+
 dev-dependencies = [
     "ruff>=0.4.2",
     "pytest>=8.2.0",
     "pytest-cov>=5.0.0",
+    "x-transformers>=2.16.1",
 ]
 
 [tool.pytest.ini_options]
 pythonpath = [
-  "."
+    "."
 ]
 
 [tool.hatch.metadata]
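
The x-transformers pin backs the new test_fvq test further down, which imports ContinuousTransformerWrapper. A minimal sketch of a sanity check that the dev dependency resolved — illustrative only, not part of the repo's test suite:

```python
# minimal sketch: confirm the new dev dependency is installed at (or above)
# the pinned minimum, then pull in the classes test_fvq needs
from importlib.metadata import version

assert tuple(map(int, version('x-transformers').split('.')[:2])) >= (2, 16)

from x_transformers import ContinuousTransformerWrapper, Encoder
```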

tests/test_lfq.py

Lines changed: 37 additions & 37 deletions
@@ -1,77 +1,77 @@
 import torch
 import pytest
-from vector_quantize_pytorch import LFQ
 import math
-"""
-testing_strategy:
-subdivisions: using masks, using frac_per_sample_entropy < 1
-"""
+from vector_quantize_pytorch import LFQ
+
+# helpers
 
-torch.manual_seed(0)
+def exists(val):
+    return val is not None
+
+# tests
 
 @pytest.mark.parametrize('frac_per_sample_entropy', (1., 0.5))
-@pytest.mark.parametrize('mask', (torch.tensor([False, False]),
-                                  torch.tensor([True, False]),
-                                  torch.tensor([True, True])))
+@pytest.mark.parametrize('mask', (
+    torch.tensor([False, False]),
+    torch.tensor([True, False]),
+    torch.tensor([True, True])
+))
 def test_masked_lfq(
     frac_per_sample_entropy,
     mask
 ):
-    # you can specify either dim or codebook_size
-    # if both specified, will be validated against each other
-
     quantizer = LFQ(
-        codebook_size = 65536, # codebook size, must be a power of 2
-        dim = 16, # this is the input feature dimension, defaults to log2(codebook_size) if not defined
-        entropy_loss_weight = 0.1, # how much weight to place on entropy loss
-        diversity_gamma = 1., # within entropy loss, how much weight to give to diversity
+        codebook_size = 65536,
+        dim = 16,
+        entropy_loss_weight = 0.1,
+        diversity_gamma = 1.,
         frac_per_sample_entropy = frac_per_sample_entropy
     )
 
     image_feats = torch.randn(2, 16, 32, 32)
 
-    ret, loss_breakdown = quantizer(image_feats, inv_temperature=100., return_loss_breakdown=True, mask=mask) # you may want to experiment with temperature
+    ret, _ = quantizer(image_feats, inv_temperature = 100., return_loss_breakdown = True, mask = mask)
 
     quantized, indices, _ = ret
     assert (quantized == quantizer.indices_to_codes(indices)).all()
 
 @pytest.mark.parametrize('frac_per_sample_entropy', (0.1,))
 @pytest.mark.parametrize('iters', (10,))
 @pytest.mark.parametrize('mask', (None, torch.tensor([True, False])))
-def test_lfq_bruteforce_frac_per_sample_entropy(frac_per_sample_entropy, iters, mask):
+def test_lfq_bruteforce_frac_per_sample_entropy(
+    frac_per_sample_entropy,
+    iters,
+    mask
+):
     image_feats = torch.randn(2, 16, 32, 32)
 
     full_per_sample_entropy_quantizer = LFQ(
-        codebook_size = 65536, # codebook size, must be a power of 2
-        dim = 16, # this is the input feature dimension, defaults to log2(codebook_size) if not defined
-        entropy_loss_weight = 0.1, # how much weight to place on entropy loss
-        diversity_gamma = 1., # within entropy loss, how much weight to give to diversity
+        codebook_size = 65536,
+        dim = 16,
+        entropy_loss_weight = 0.1,
+        diversity_gamma = 1.,
         frac_per_sample_entropy = 1
     )
 
     partial_per_sample_entropy_quantizer = LFQ(
-        codebook_size = 65536, # codebook size, must be a power of 2
-        dim = 16, # this is the input feature dimension, defaults to log2(codebook_size) if not defined
-        entropy_loss_weight = 0.1, # how much weight to place on entropy loss
-        diversity_gamma = 1., # within entropy loss, how much weight to give to diversity
+        codebook_size = 65536,
+        dim = 16,
+        entropy_loss_weight = 0.1,
+        diversity_gamma = 1.,
         frac_per_sample_entropy = frac_per_sample_entropy
     )
 
-    ret, loss_breakdown = full_per_sample_entropy_quantizer(
-        image_feats, inv_temperature=100., return_loss_breakdown=True, mask=mask)
+    ret, loss_breakdown = full_per_sample_entropy_quantizer(image_feats, inv_temperature = 100., return_loss_breakdown = True, mask = mask)
     true_per_sample_entropy = loss_breakdown.per_sample_entropy
 
     per_sample_losses = torch.zeros(iters)
-    for iter in range(iters):
-        ret, loss_breakdown = partial_per_sample_entropy_quantizer(
-            image_feats, inv_temperature=100., return_loss_breakdown=True, mask=mask) # you may want to experiment with temperature
+
+    for i in range(iters):
+        ret, loss_breakdown = partial_per_sample_entropy_quantizer(image_feats, inv_temperature = 100., return_loss_breakdown = True, mask = mask)
 
         quantized, indices, _ = ret
         assert (quantized == partial_per_sample_entropy_quantizer.indices_to_codes(indices)).all()
-        per_sample_losses[iter] = loss_breakdown.per_sample_entropy
-    # 95% confidence interval
-    assert abs(per_sample_losses.mean() - true_per_sample_entropy) \
-        < (1.96*(per_sample_losses.std() / math.sqrt(iters)))
+        per_sample_losses[i] = loss_breakdown.per_sample_entropy
 
-    print("difference: ", abs(per_sample_losses.mean() - true_per_sample_entropy))
-    print("std error:", (1.96*(per_sample_losses.std() / math.sqrt(iters))))
+    # 95% confidence interval
+    assert abs(per_sample_losses.mean() - true_per_sample_entropy) < (1.96 * (per_sample_losses.std() / math.sqrt(iters)))
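
The final assertion carries the test's statistics: with frac_per_sample_entropy < 1 the quantizer estimates per-sample entropy from a random subset, so the test checks that the mean of iters noisy estimates falls within 1.96 standard errors of the exact (frac = 1) value. A self-contained sketch of the same 95% confidence-interval check, with noisy_estimate as a hypothetical stand-in for the partial quantizer:

```python
import math
import torch

torch.manual_seed(0)

true_value = 2.5   # stand-in for the exact (frac_per_sample_entropy = 1) entropy
iters = 10

# hypothetical noisy estimator, playing the role of frac_per_sample_entropy < 1
def noisy_estimate():
    return true_value + 0.05 * torch.randn(()).item()

samples = torch.tensor([noisy_estimate() for _ in range(iters)])

# 95% confidence interval: |mean - true| < 1.96 * std / sqrt(n)
assert abs(samples.mean() - true_value) < 1.96 * samples.std() / math.sqrt(iters)
```

As in the real test, a bound like this is expected to fail roughly 5% of the time on fresh randomness; the fixed seed makes the sketch deterministic.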

tests/test_readme.py

Lines changed: 28 additions & 0 deletions
@@ -506,3 +506,31 @@ def test_vq_3d():
     assert x.shape == quantized.shape
     assert indices.shape == (1, 16, 16, 16)
     assert torch.allclose(quantized, quantizer.get_output_from_indices(indices))
+
+def test_fvq():
+    from vector_quantize_pytorch import VectorQuantize
+    from x_transformers import ContinuousTransformerWrapper, Encoder
+
+    vq_bridge = ContinuousTransformerWrapper(
+        attn_layers = Encoder(
+            dim = 256,
+            depth = 1,
+            heads = 4,
+            pre_norm_has_final_norm = False
+        )
+    )
+
+    vq = VectorQuantize(
+        dim = 256,
+        codebook_size = 512,
+        vq_bridge = vq_bridge,
+        learnable_codebook = True,
+        ema_update = False,
+        in_place_codebook_optimizer = lambda params: torch.optim.SGD(params, lr = 1e-3)
+    )
+
+    x = torch.randn(1, 1024, 256)
+    quantized, indices, commit_loss = vq(x)
+
+    assert quantized.shape == x.shape
+    assert indices.shape == (1, 1024)
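
test_fvq drives the new code path end to end: a vq_bridge (here a one-layer x-transformers encoder) is only valid with learnable_codebook = True and ema_update = False, exactly the combination the commit's new asserts enforce. A minimal sketch of the negative case, reusing the constructor arguments from the test above; the expectation that it raises follows from the asserts added in vector_quantize_pytorch.py below:

```python
import pytest
from vector_quantize_pytorch import VectorQuantize
from x_transformers import ContinuousTransformerWrapper, Encoder

bridge = ContinuousTransformerWrapper(
    attn_layers = Encoder(
        dim = 256,
        depth = 1,
        heads = 4,
        pre_norm_has_final_norm = False
    )
)

# explicitly disabling the learnable codebook while passing a vq_bridge is
# the invalid config the commit now asserts out
with pytest.raises(AssertionError):
    VectorQuantize(
        dim = 256,
        codebook_size = 512,
        vq_bridge = bridge,
        learnable_codebook = False
    )
```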

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 4 additions & 2 deletions
@@ -813,8 +813,8 @@ def __init__(
 
         # defaults
 
-        ema_update = default(ema_update, not directional_reparam)
-        learnable_codebook = default(learnable_codebook, directional_reparam)
+        ema_update = default(ema_update, not directional_reparam and not exists(vq_bridge))
+        learnable_codebook = default(learnable_codebook, directional_reparam or exists(vq_bridge))
         rotation_trick = default(rotation_trick, not directional_reparam and dim > 1) # only use rotation trick if feature dimension greater than 1
 
         # basic variables
@@ -865,6 +865,8 @@ def __init__(
 
         assert not (straight_through and learnable_codebook), 'gumbel straight through not allowed when learning the codebook'
         assert not (ema_update and learnable_codebook), 'learnable codebook not compatible with EMA update'
+        assert not (exists(vq_bridge) and not learnable_codebook), 'learnable_codebook must be set to True if vq_bridge is passed in'
+        assert not (exists(vq_bridge) and ema_update), 'ema_update must be False if vq_bridge is passed in'
 
         assert 0 <= sync_update_v <= 1.
         assert not (sync_update_v > 0. and not learnable_codebook), 'learnable codebook must be turned on'
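
The defaults and the asserts work as a pair: exists and default (the repo's usual helpers, sketched below) resolve a caller who passes a vq_bridge but leaves both flags unset to the one valid combination, while the asserts catch callers who override the flags explicitly. A small sketch of that resolution, assuming those helper definitions:

```python
# the two helpers the diff relies on, as conventionally defined in this repo

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# with a vq_bridge present and both flags left at None, the new defaults give
# learnable_codebook = True and ema_update = False -- the only combination
# the two new asserts accept

vq_bridge = object()            # stand-in for an actual bridge module
directional_reparam = False
ema_update = learnable_codebook = None

ema_update = default(ema_update, not directional_reparam and not exists(vq_bridge))
learnable_codebook = default(learnable_codebook, directional_reparam or exists(vq_bridge))

assert not (exists(vq_bridge) and not learnable_codebook)
assert not (exists(vq_bridge) and ema_update)
```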
