We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent a108253 commit 671c793Copy full SHA for 671c793
1 file changed
src/trainer/utils.py
@@ -18,7 +18,6 @@ def compute_kl_divergence(model, target_model, inputs):
18
with torch.no_grad():
19
ref_outputs = target_model(**inputs)
20
21
- ref_probs = F.log_softmax(ref_outputs.logits, dim=-1)
22
ref_probs = F.log_softmax(ref_outputs.logits, dim=-1)
23
ref_probs = ref_probs.view(-1, ref_outputs.logits.shape[-1])
24
0 commit comments