Skip to content

Commit 671c793

Browse files
authored
Remove duplicate log_softmax calculation (#158)
Removed duplicate log_softmax call on ref_outputs.
1 parent a108253 commit 671c793

1 file changed

Lines changed: 0 additions & 1 deletion

File tree

src/trainer/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ def compute_kl_divergence(model, target_model, inputs):
1818
with torch.no_grad():
1919
ref_outputs = target_model(**inputs)
2020

21-
ref_probs = F.log_softmax(ref_outputs.logits, dim=-1)
2221
ref_probs = F.log_softmax(ref_outputs.logits, dim=-1)
2322
ref_probs = ref_probs.view(-1, ref_outputs.logits.shape[-1])
2423

0 commit comments

Comments
 (0)