11import torch
22import pytest
3- from vector_quantize_pytorch import LFQ
43import math
5- """
6- testing_strategy:
7- subdivisions: using masks, using frac_per_sample_entropy < 1
8- """
4+ from vector_quantize_pytorch import LFQ
5+
6+ # helpers
97
10- torch .manual_seed (0 )
def exists(val):
    """Return True when `val` is anything other than None."""
    return not (val is None)
10+
11+ # tests
1112
@pytest.mark.parametrize('frac_per_sample_entropy', (1., 0.5))
@pytest.mark.parametrize('mask', (
    torch.tensor([False, False]),
    torch.tensor([True, False]),
    torch.tensor([True, True])
))
def test_masked_lfq(
    frac_per_sample_entropy,
    mask
):
    """Quantize random image features under a per-batch-element mask and
    verify the returned indices decode back to the quantized codes."""

    quantizer = LFQ(
        codebook_size = 65536,
        dim = 16,
        entropy_loss_weight = 0.1,
        diversity_gamma = 1.,
        frac_per_sample_entropy = frac_per_sample_entropy
    )

    feats = torch.randn(2, 16, 32, 32)

    (quantized, indices, _), _ = quantizer(
        feats,
        inv_temperature = 100.,
        return_loss_breakdown = True,
        mask = mask
    )

    # round-trip: indices -> codes must reproduce the quantized output exactly
    assert torch.equal(quantized, quantizer.indices_to_codes(indices))
3737
@pytest.mark.parametrize('frac_per_sample_entropy', (0.1,))
@pytest.mark.parametrize('iters', (10,))
@pytest.mark.parametrize('mask', (None, torch.tensor([True, False])))
def test_lfq_bruteforce_frac_per_sample_entropy(
    frac_per_sample_entropy,
    iters,
    mask
):
    """Check that the sub-sampled per-sample entropy estimate agrees with the
    exact (frac = 1) value.

    Runs `iters` stochastic estimates with `frac_per_sample_entropy < 1` and
    asserts their mean lands within a 95% confidence interval (1.96 standard
    errors) of the entropy computed over the full sample.
    """

    def make_quantizer(frac):
        # single source of truth for the shared config -- only the entropy
        # subsampling fraction differs between the two quantizers under test
        return LFQ(
            codebook_size = 65536,
            dim = 16,
            entropy_loss_weight = 0.1,
            diversity_gamma = 1.,
            frac_per_sample_entropy = frac
        )

    image_feats = torch.randn(2, 16, 32, 32)

    full_per_sample_entropy_quantizer = make_quantizer(1)
    partial_per_sample_entropy_quantizer = make_quantizer(frac_per_sample_entropy)

    # exact entropy over the full sample serves as ground truth
    _, loss_breakdown = full_per_sample_entropy_quantizer(image_feats, inv_temperature = 100., return_loss_breakdown = True, mask = mask)
    true_per_sample_entropy = loss_breakdown.per_sample_entropy

    per_sample_losses = torch.zeros(iters)

    for i in range(iters):
        ret, loss_breakdown = partial_per_sample_entropy_quantizer(image_feats, inv_temperature = 100., return_loss_breakdown = True, mask = mask)

        quantized, indices, _ = ret
        # indices must still round-trip through the codebook
        assert (quantized == partial_per_sample_entropy_quantizer.indices_to_codes(indices)).all()
        per_sample_losses[i] = loss_breakdown.per_sample_entropy

    # 95% confidence interval: mean of stochastic estimates vs exact entropy
    assert abs(per_sample_losses.mean() - true_per_sample_entropy) < (1.96 * (per_sample_losses.std() / math.sqrt(iters)))
0 commit comments