We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
moe_normalize_expert_weights
top_k=1
1 parent bcb4979, commit 04e4f1f (Copy full SHA for 04e4f1f)
1 file changed
megablocks/layers/router.py
@@ -45,10 +45,9 @@ def jitter(self, x):
45
46
def _top_k(self, scores):
47
if self.args.moe_top_k == 1:
48
- return scores.max(dim=-1)
+ return scores.max(dim=-1,keepdim=True)
49
return torch.topk(scores, self.args.moe_top_k, dim=-1)
50
51
-
52
def forward(self, x):
53
if self.training and self.args.moe_jitter_eps is not None:
54
x = x * self.jitter(x)
0 commit comments