Skip to content

Commit 7a4e13c

Browse files
committed
Need unreduced KL loss for the metacontroller project
1 parent fb78d77 commit 7a4e13c

2 files changed

Lines changed: 12 additions & 4 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.27.19"
3+
version = "1.27.20"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/binary_mapper.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def forward(
7777
calc_aux_loss = None,
7878
deterministic = None,
7979
return_indices = False,
80+
reduce_aux_kl_loss = True
8081
):
8182
deterministic = default(deterministic, self.deterministic_on_eval and not self.training)
8283

@@ -112,7 +113,14 @@ def forward(
112113
# calculate negative entropy
113114

114115
kl_div = self.bits * NAT - binary_entropy(logits)
115-
aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold).mean()
116+
aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold)
117+
118+
# able to return unreduced kl loss, for use in another project (metacontroller)
119+
120+
if reduce_aux_kl_loss:
121+
aux_kl_loss = aux_kl_loss.mean()
122+
else:
123+
aux_kl_loss = inverse_pack_lead_dims(aux_kl_loss, '*')
116124

117125
# maybe straight through
118126

@@ -150,11 +158,11 @@ def forward(
150158

151159
logits = torch.randn(3, 4, 8)
152160

153-
sparse_one_hot, indices, aux_loss = binary_mapper(logits, return_indices = True)
161+
sparse_one_hot, indices, aux_loss = binary_mapper(logits, return_indices = True, reduce_aux_kl_loss = False)
154162

155163
assert sparse_one_hot.shape == (3, 4, 2 ** 8)
156164
assert indices.shape == (3, 4)
157-
assert aux_loss.numel() == 1
165+
assert aux_loss.shape == (3, 4)
158166

159167
binary_mapper.eval()
160168
sparse_one_hot1, _ = binary_mapper(logits, deterministic = True)

0 commit comments

Comments
 (0)