Skip to content

Commit d304ba0

Browse files
committed
issue/224 - feat: use muDNN silu_and_mul to replace elementwise swiglu in moore gpu
1 parent 28945a9 commit d304ba0

File tree

1 file changed

+25
-10
lines changed

1 file changed

+25
-10
lines changed

csrc/models/llama/llama_mlp.cpp

Lines changed: 25 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -71,19 +71,34 @@ LlamaMLP::LlamaMLP(std::shared_ptr<infinilm::config::ModelConfig> model_config,
71 71
}
72 72

73 73
infinicore::Tensor LlamaMLP::forward(const infinicore::Tensor &hidden_states) const {
74-
// 1. Project to gate and up
75-
auto hidden_states_mutable = hidden_states;
76-
auto [gate, up] = gate_up_proj_->forward_split(hidden_states_mutable);
74+
infinicore::Device::Type dev_type = hidden_states->device().getType();
75+
if(dev_type == infinicore::Device::Type::MOORE){
76+
// 1. Project to a single combined gate_up tensor
77+
auto hidden_states_mutable = hidden_states;
78+
auto gate_up = gate_up_proj_->forward(hidden_states_mutable);
77 79

78-
// 2. Apply SwiGLU: silu(gate) * up
79-
// Note: swiglu kernel expects (up, gate) and computes gate * sigmoid(gate) * up
80-
// So we pass (up, gate) to get the correct result: gate * sigmoid(gate) * up
81-
auto intermediate = infinicore::op::swiglu(up, gate);
80+
// 2. Apply the fused silu_and_mul operator
81+
// applies SiLU to the first half, and multiplies it by the second half.
82+
// Mathematically equivalent to: result = SiLU(gate_up[..., :d]) * gate_up[..., d:]
83+
auto intermediate = infinicore::op::silu_and_mul(gate_up);
82 84

83-
// 3. Project down
84-
auto output = down_proj_->forward(intermediate);
85+
// 3. Project down
86+
auto output = down_proj_->forward(intermediate);
87+
return output;
88+
} else{
89+
// 1. Project to gate and up
90+
auto hidden_states_mutable = hidden_states;
91+
auto [gate, up] = gate_up_proj_->forward_split(hidden_states_mutable);
85 92

86-
return output;
93+
// 2. Apply SwiGLU: silu(gate) * up
94+
// Note: swiglu kernel expects (up, gate) and computes gate * sigmoid(gate) * up
95+
// So we pass (up, gate) to get the correct result: gate * sigmoid(gate) * up
96+
auto intermediate = infinicore::op::swiglu(up, gate);
97+
98+
// 3. Project down
99+
auto output = down_proj_->forward(intermediate);
100+
return output;
101+
}
87 102
}
88 103

89 104
} // namespace infinilm::models::llama

0 commit comments

Comments
 (0)