Skip to content

Commit b2d552d

Browse files
xytpaisijyang
authored and committed
Update minimax_m2.py (#820)
1 parent 03e3f68 commit b2d552d

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

atom/models/minimax_m2.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -232,7 +232,7 @@ def forward(
232 232   # TP-aware RMSNorm: all-reduce variance across TP ranks so
233 233   # normalization uses the global variance (over 6144/1024 dims)
234 234   # rather than per-rank variance (768/128 dims).
235     - if qkv.shape[0] <= 64 and self.tp_size > 1:
    235 + if qkv.shape[0] <= 256 and self.tp_size > 1:
236 236       q, k, v = tensor_model_parallel_fused_qknorm_allreduce(
237 237           qkv, self.q_norm.weight, self.k_norm.weight, self.rms_norm_eps
238 238       )

0 commit comments

Comments
 (0)