diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index e38388de4..b53aa7337 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -756,7 +756,7 @@ def create_weights( if layer.has_bias: w13_bias = atom_parameter( - torch.empty( + torch.zeros( num_experts, 2 * intermediate_size_per_partition_after_pad, dtype=torch.bfloat16, @@ -793,7 +793,7 @@ def create_weights( if layer.has_bias: w2_bias = atom_parameter( - torch.empty( + torch.zeros( num_experts, hidden_size, dtype=torch.bfloat16,