@@ -71,19 +71,34 @@ LlamaMLP::LlamaMLP(std::shared_ptr<infinilm::config::ModelConfig> model_config,
7171}
7272
// NOTE(review): this listing is a rendered diff hunk, not plain source. Lines marked '-' are the
// removed implementation, lines marked '+' are its replacement, and the leading digits are
// old/new line numbers rather than code.
//
// LlamaMLP::forward runs the MLP on hidden_states and returns the down-projected result:
// gate/up projection -> gated SiLU activation -> down projection.
// After the change the activation path is chosen by the input tensor's device type:
//   - MOORE: keep the combined gate_up projection and call the fused silu_and_mul op.
//   - any other device: split the projection into (gate, up) and call swiglu (the previous behavior).
// NOTE(review): the MOORE path relies on the first half of gate_up being the gate part and the
// second half the up part; the layout produced by gate_up_proj_ is not visible here — TODO confirm.
7373infinicore::Tensor LlamaMLP::forward (const infinicore::Tensor &hidden_states) const {
74- // 1. Project to gate and up
75- auto hidden_states_mutable = hidden_states;
76- auto [gate, up] = gate_up_proj_->forward_split (hidden_states_mutable);
74+ infinicore::Device::Type dev_type = hidden_states->device ().getType (); // device type of the input selects the activation path
75+ if (dev_type == infinicore::Device::Type::MOORE){
76+ // 1. Project to a single combined gate_up tensor (forward, not forward_split)
77+ auto hidden_states_mutable = hidden_states; // non-const copy of the const-ref argument; presumably needed by forward()'s signature — TODO confirm
78+ auto gate_up = gate_up_proj_->forward (hidden_states_mutable);
7779
78- // 2. Apply SwiGLU: silu(gate) * up
79- // Note: swiglu kernel expects (up, gate) and computes gate * sigmoid(gate) * up
80- // So we pass (up, gate) to get the correct result: gate * sigmoid(gate ) * up
81- auto intermediate = infinicore::op::swiglu (up, gate );
80+ // 2. Apply the fused silu_and_mul operator, which
81+ // applies SiLU to the first half of gate_up and multiplies it by the second half.
82+ // Mathematically equivalent to: result = SiLU(gate_up[..., :d]) * gate_up[..., d:]
83+ auto intermediate = infinicore::op::silu_and_mul (gate_up );
8284
83- // 3. Project down
84- auto output = down_proj_->forward (intermediate);
85+ // 3. Project down
86+ auto output = down_proj_->forward (intermediate);
87+ return output;
88+ } else {
89+ // 1. Project to gate and up
90+ auto hidden_states_mutable = hidden_states; // same non-const copy as in the MOORE branch
91+ auto [gate, up] = gate_up_proj_->forward_split (hidden_states_mutable);
8592
86- return output;
93+ // 2. Apply SwiGLU: silu(gate) * up
94+ // Note: swiglu kernel expects (up, gate) and computes gate * sigmoid(gate) * up
95+ // So we pass (up, gate) to get the correct result: gate * sigmoid(gate) * up
96+ auto intermediate = infinicore::op::swiglu (up, gate);
97+
98+ // 3. Project down
99+ auto output = down_proj_->forward (intermediate);
100+ return output;
101+ }
87102}
88103
89104} // namespace infinilm::models::llama
0 commit comments