Skip to content

Commit d6f30d6

Browse files
yanfeich and root authored
[INTEL_HPU] Moe tensor align (#2321)
* list to tensor 0x80 alignment padding
* update gate_moe UT

---------

Co-authored-by: root <root@CT13.sh.intel.com>
1 parent f070906 commit d6f30d6

4 files changed

Lines changed: 409 additions & 164 deletions

File tree

backends/intel_hpu/custom_ops/llama_infer/fused_gate_moe.cc

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -416,8 +416,7 @@ void FusedGateMoeKernel(
416416
const phi::DenseTensor& gate_up_weights,
417417
const phi::DenseTensor& down_weights,
418418
const paddle::optional<phi::DenseTensor>& hidden_states_scales,
419-
const paddle::optional<std::vector<phi::DenseTensor>>&
420-
intermediate_hidden_states_scales,
419+
const paddle::optional<phi::DenseTensor>& intermediate_hidden_states_scales,
421420
const paddle::optional<phi::DenseTensor>& gate_up_weights_scales,
422421
const paddle::optional<phi::DenseTensor>& down_weights_scales,
423422
phi::DenseTensor* final_hidden_states,
@@ -468,9 +467,7 @@ void FusedGateMoeKernel(
468467
ct.AddN(down_weights);
469468

470469
if (intermediate_hidden_states_scales) {
471-
for (const auto& t : intermediate_hidden_states_scales.get()) {
472-
ct.Add(t);
473-
}
470+
ct.AddN(intermediate_hidden_states_scales.get());
474471
}
475472
if (gate_up_weights_scales) {
476473
ct.AddN(gate_up_weights_scales.get());
@@ -514,8 +511,7 @@ void CallFusedGateMoeKernel(
514511
const phi::DenseTensor& gate_up_weights,
515512
const phi::DenseTensor& down_weights,
516513
const paddle::optional<phi::DenseTensor>& hidden_states_scales,
517-
const paddle::optional<std::vector<phi::DenseTensor>>&
518-
intermediate_hidden_states_scales,
514+
const paddle::optional<phi::DenseTensor>& intermediate_hidden_states_scales,
519515
const paddle::optional<phi::DenseTensor>& gate_up_weights_scales,
520516
const paddle::optional<phi::DenseTensor>& down_weights_scales,
521517
phi::DenseTensor* final_hidden_states,
@@ -634,7 +630,7 @@ std::vector<paddle::Tensor> FusedGateMoeForward(
634630
*gate_up_weights_tensor,
635631
*down_weights_tensor,
636632
paddle::optional<phi::DenseTensor>(), /* hidden_states_scale */
637-
paddle::optional<std::vector<phi::DenseTensor>>(), /* intermediate */
633+
paddle::optional<phi::DenseTensor>(), /* intermediate */
638634
paddle::optional<phi::DenseTensor>(), /* gate_up_weights_scales */
639635
paddle::optional<phi::DenseTensor>(), /* down_weights_scales */
640636
final_hidden_states.get(),
@@ -660,8 +656,7 @@ std::vector<paddle::Tensor> FusedGateMoeFP8Forward(
660656
const paddle::Tensor& gate_up_weights,
661657
const paddle::Tensor& down_weights,
662658
const paddle::optional<paddle::Tensor>& hidden_states_scales,
663-
const paddle::optional<std::vector<paddle::Tensor>>&
664-
intermediate_hidden_states_scales,
659+
const paddle::optional<paddle::Tensor>& intermediate_hidden_states_scales,
665660
const paddle::Tensor& gate_up_weights_scales,
666661
const paddle::Tensor& down_weights_scales,
667662
const int top_k,
@@ -701,14 +696,16 @@ std::vector<paddle::Tensor> FusedGateMoeFP8Forward(
701696
paddle::optional<phi::DenseTensor>(*hidden_states_scales_dt);
702697
}
703698

699+
auto intermediate_hidden_states_scales_tensor =
700+
paddle::optional<phi::DenseTensor>();
704701
bool dynamic_scale = true;
705-
std::vector<phi::DenseTensor> scales_vec;
706702
if (intermediate_hidden_states_scales) {
707703
dynamic_scale = false;
708-
for (const auto& t : intermediate_hidden_states_scales.get()) {
709-
scales_vec.push_back(
710-
*static_cast<const phi::DenseTensor*>(t.impl().get()));
711-
}
704+
auto intermediate_hidden_states_scales_dt = static_cast<phi::DenseTensor*>(
705+
intermediate_hidden_states_scales->impl().get());
706+
intermediate_hidden_states_scales_tensor =
707+
paddle::optional<phi::DenseTensor>(
708+
*intermediate_hidden_states_scales_dt);
712709
}
713710
auto gate_up_weights_scales_tensor = paddle::optional<phi::DenseTensor>();
714711
auto gate_up_weights_scales_dt =
@@ -735,7 +732,7 @@ std::vector<paddle::Tensor> FusedGateMoeFP8Forward(
735732
*gate_up_weights_tensor,
736733
*down_weights_tensor,
737734
hidden_states_scales_tensor,
738-
scales_vec,
735+
intermediate_hidden_states_scales_tensor,
739736
gate_up_weights_scales_tensor,
740737
down_weights_scales_tensor,
741738
final_hidden_states.get(),
@@ -816,7 +813,7 @@ std::vector<paddle::Tensor> FusedGateMoeBlockWiseFP8Forward(
816813
*gate_up_weights_tensor,
817814
*down_weights_tensor,
818815
paddle::optional<phi::DenseTensor>(), /* hidden_states_scale */
819-
paddle::optional<std::vector<phi::DenseTensor>>(), /* intermediate */
816+
paddle::optional<phi::DenseTensor>(), /* intermediate */
820817
gate_up_weights_scales_tensor,
821818
down_weights_scales_tensor,
822819
final_hidden_states.get(),
@@ -887,7 +884,7 @@ PD_BUILD_OP(fused_gate_moe_fp8)
887884
"gate_up_weights",
888885
"down_weights",
889886
paddle::Optional("hidden_states_scales"),
890-
paddle::Optional(paddle::Vec("intermediate_hidden_states_scales")),
887+
paddle::Optional("intermediate_hidden_states_scales"),
891888
"gate_up_weights_scales",
892889
"down_weights_scales"})
893890
.Outputs({"final_hidden_states"})

backends/intel_hpu/kernels/funcs.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,12 @@ class ConvertTensors {
352352
if (addr_offset % 0x80 != 0) {
353353
PADDLE_THROW("Tensor list offset is not algined.");
354354
}
355+
if (dims.size() == 2 && dims[0] == 1) {
356+
// [list_size][1] is padded for 0x80 alignment as
357+
// [list_size][1][padded_num].
358+
// now addr_offset = 0x80;
359+
dims.pop_back();
360+
}
355361

356362
if (is_input) {
357363
for (int64_t tensor_idx = 0; tensor_idx < num_list; tensor_idx++) {

backends/intel_hpu/tests/unittests/test_fused_fp8_qkv_rope.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import os
2222

2323
intel_hpus_module_id = os.environ.get("FLAGS_selected_intel_hpus", 4)
24+
paddle.device.set_device(f"intel_hpu:{intel_hpus_module_id}")
2425

2526
paddle.seed(2025)
2627

0 commit comments

Comments
 (0)