@@ -77,6 +77,7 @@ def forward(
         calc_aux_loss = None,
         deterministic = None,
         return_indices = False,
+        reduce_aux_kl_loss = True
     ):
         deterministic = default(deterministic, self.deterministic_on_eval and not self.training)

@@ -112,7 +113,14 @@ def forward(
         # calculate negative entropy

         kl_div = self.bits * NAT - binary_entropy(logits)
-        aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold).mean()
+        aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold)
+
+        # able to return unreduced kl loss, for use in another project (metacontroller)
+
+        if reduce_aux_kl_loss:
+            aux_kl_loss = aux_kl_loss.mean()
+        else:
+            aux_kl_loss = inverse_pack_lead_dims(aux_kl_loss, '*')

         # maybe straight through

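With this change the auxiliary KL term stays per-position unless `reduce_aux_kl_loss = True`. Below is a minimal standalone sketch of that computation, not the repo's implementation: it assumes `NAT` is ln 2 and that `binary_entropy` sums Bernoulli entropies over the bit dimension, and the `kl_loss_threshold` value is purely illustrative.

```python
import torch
import torch.nn.functional as F

# natural log of 2 - one bit of entropy expressed in nats (assumption)
NAT = torch.log(torch.tensor(2.)).item()

def binary_entropy(logits):
    # entropy of independent Bernoulli(sigmoid(logit)) bits, summed over the last dim (assumption)
    prob = logits.sigmoid()
    eps = 1e-20
    per_bit = -(prob * (prob + eps).log() + (1. - prob) * (1. - prob + eps).log())
    return per_bit.sum(dim = -1)

bits = 8
kl_loss_threshold = 0.5   # hypothetical threshold value

logits = torch.randn(3, 4, bits)

# divergence from the maximum-entropy binary code, one value per position
kl_div = bits * NAT - binary_entropy(logits)

# hinge - only penalize the amount above the threshold
aux_kl_loss = F.relu(kl_div - kl_loss_threshold)

print(aux_kl_loss.shape)   # torch.Size([3, 4]) when left unreduced
print(aux_kl_loss.mean())  # scalar, the default reduced loss
```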
@@ -150,11 +158,11 @@ def forward(

    logits = torch.randn(3, 4, 8)

-    sparse_one_hot, indices, aux_loss = binary_mapper(logits, return_indices = True)
+    sparse_one_hot, indices, aux_loss = binary_mapper(logits, return_indices = True, reduce_aux_kl_loss = False)

    assert sparse_one_hot.shape == (3, 4, 2 ** 8)
    assert indices.shape == (3, 4)
-    assert aux_loss.numel() == 1
+    assert aux_loss.shape == (3, 4)

    binary_mapper.eval()
    sparse_one_hot1, _ = binary_mapper(logits, deterministic = True)
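Continuing from the `aux_loss` of shape `(3, 4)` returned in the test above, a hypothetical downstream consumer (not part of this commit) could weight the per-position penalty before reducing it; the weights here are illustrative.

```python
# hypothetical: weight the per-position kl penalty before reducing,
# e.g. by a metacontroller's per-token importance scores

per_token_weight = torch.rand(3, 4)   # illustrative weights, one per (batch, seq) position

weighted_aux_kl_loss = (aux_loss * per_token_weight).mean()
```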