@@ -416,8 +416,7 @@ void FusedGateMoeKernel(
416416 const phi::DenseTensor& gate_up_weights,
417417 const phi::DenseTensor& down_weights,
418418 const paddle::optional<phi::DenseTensor>& hidden_states_scales,
419- const paddle::optional<std::vector<phi::DenseTensor>>&
420- intermediate_hidden_states_scales,
419+ const paddle::optional<phi::DenseTensor>& intermediate_hidden_states_scales,
421420 const paddle::optional<phi::DenseTensor>& gate_up_weights_scales,
422421 const paddle::optional<phi::DenseTensor>& down_weights_scales,
423422 phi::DenseTensor* final_hidden_states,
@@ -468,9 +467,7 @@ void FusedGateMoeKernel(
468467 ct.AddN (down_weights);
469468
470469 if (intermediate_hidden_states_scales) {
471- for (const auto & t : intermediate_hidden_states_scales.get ()) {
472- ct.Add (t);
473- }
470+ ct.AddN (intermediate_hidden_states_scales.get ());
474471 }
475472 if (gate_up_weights_scales) {
476473 ct.AddN (gate_up_weights_scales.get ());
@@ -514,8 +511,7 @@ void CallFusedGateMoeKernel(
514511 const phi::DenseTensor& gate_up_weights,
515512 const phi::DenseTensor& down_weights,
516513 const paddle::optional<phi::DenseTensor>& hidden_states_scales,
517- const paddle::optional<std::vector<phi::DenseTensor>>&
518- intermediate_hidden_states_scales,
514+ const paddle::optional<phi::DenseTensor>& intermediate_hidden_states_scales,
519515 const paddle::optional<phi::DenseTensor>& gate_up_weights_scales,
520516 const paddle::optional<phi::DenseTensor>& down_weights_scales,
521517 phi::DenseTensor* final_hidden_states,
@@ -634,7 +630,7 @@ std::vector<paddle::Tensor> FusedGateMoeForward(
634630 *gate_up_weights_tensor,
635631 *down_weights_tensor,
636632 paddle::optional<phi::DenseTensor>(), /* hidden_states_scale */
637- paddle::optional<std::vector< phi::DenseTensor> >(), /* intermediate */
633+ paddle::optional<phi::DenseTensor>(), /* intermediate */
638634 paddle::optional<phi::DenseTensor>(), /* gate_up_weights_scales */
639635 paddle::optional<phi::DenseTensor>(), /* down_weights_scales */
640636 final_hidden_states.get (),
@@ -660,8 +656,7 @@ std::vector<paddle::Tensor> FusedGateMoeFP8Forward(
660656 const paddle::Tensor& gate_up_weights,
661657 const paddle::Tensor& down_weights,
662658 const paddle::optional<paddle::Tensor>& hidden_states_scales,
663- const paddle::optional<std::vector<paddle::Tensor>>&
664- intermediate_hidden_states_scales,
659+ const paddle::optional<paddle::Tensor>& intermediate_hidden_states_scales,
665660 const paddle::Tensor& gate_up_weights_scales,
666661 const paddle::Tensor& down_weights_scales,
667662 const int top_k,
@@ -701,14 +696,16 @@ std::vector<paddle::Tensor> FusedGateMoeFP8Forward(
701696 paddle::optional<phi::DenseTensor>(*hidden_states_scales_dt);
702697 }
703698
699+ auto intermediate_hidden_states_scales_tensor =
700+ paddle::optional<phi::DenseTensor>();
704701 bool dynamic_scale = true ;
705- std::vector<phi::DenseTensor> scales_vec;
706702 if (intermediate_hidden_states_scales) {
707703 dynamic_scale = false ;
708- for (const auto & t : intermediate_hidden_states_scales.get ()) {
709- scales_vec.push_back (
710- *static_cast <const phi::DenseTensor*>(t.impl ().get ()));
711- }
704+ auto intermediate_hidden_states_scales_dt = static_cast <phi::DenseTensor*>(
705+ intermediate_hidden_states_scales->impl ().get ());
706+ intermediate_hidden_states_scales_tensor =
707+ paddle::optional<phi::DenseTensor>(
708+ *intermediate_hidden_states_scales_dt);
712709 }
713710 auto gate_up_weights_scales_tensor = paddle::optional<phi::DenseTensor>();
714711 auto gate_up_weights_scales_dt =
@@ -735,7 +732,7 @@ std::vector<paddle::Tensor> FusedGateMoeFP8Forward(
735732 *gate_up_weights_tensor,
736733 *down_weights_tensor,
737734 hidden_states_scales_tensor,
738- scales_vec ,
735+ intermediate_hidden_states_scales_tensor ,
739736 gate_up_weights_scales_tensor,
740737 down_weights_scales_tensor,
741738 final_hidden_states.get (),
@@ -816,7 +813,7 @@ std::vector<paddle::Tensor> FusedGateMoeBlockWiseFP8Forward(
816813 *gate_up_weights_tensor,
817814 *down_weights_tensor,
818815 paddle::optional<phi::DenseTensor>(), /* hidden_states_scale */
819- paddle::optional<std::vector< phi::DenseTensor> >(), /* intermediate */
816+ paddle::optional<phi::DenseTensor>(), /* intermediate */
820817 gate_up_weights_scales_tensor,
821818 down_weights_scales_tensor,
822819 final_hidden_states.get (),
@@ -887,7 +884,7 @@ PD_BUILD_OP(fused_gate_moe_fp8)
887884 " gate_up_weights" ,
888885 " down_weights" ,
889886 paddle::Optional (" hidden_states_scales" ),
890- paddle::Optional (paddle::Vec ( " intermediate_hidden_states_scales" ) ),
887+ paddle::Optional (" intermediate_hidden_states_scales" ),
891888 " gate_up_weights_scales" ,
892889 " down_weights_scales" })
893890 .Outputs({" final_hidden_states" })
0 commit comments