Skip to content

Commit b75afb4

Browse files
[rocm-libraries] ROCm/rocm-libraries#6118 (commit 2c7dcf7)
projects/composablekernel: add SwigluStep support for MoE blockscale (#6118)

## Summary
- add `swiglustep_and_mul` to the composablekernel MoE blockscale activation enum
- implement the corresponding blockscale epilogue path for `SwigluStep`
- keep existing `silu` and `gelu` paths unchanged

## Scope
This PR covers the classic composablekernel blockscale MoE path under `projects/composablekernel`. This is separate from the `ck_tile` / FlatMM path being discussed in ROCm/rocm-libraries#5992.

## Motivation
`Step-3.5-Flash-FP8` uses `SwigluStep` in its MoE MLP path. The dependent AITER change needs native support for this activation in the classic composablekernel MoE blockscale path.

## Validation
- patch is limited to two composablekernel files under `projects/composablekernel`
- existing `silu` / `gelu` paths are unchanged
- dependent AITER runtime validation hit the classic CK 2-stage path with AITER MoE enabled
1 parent eaaed3e commit b75afb4

2 files changed

Lines changed: 41 additions & 2 deletions

File tree

include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -28,8 +28,9 @@ namespace ck {
2828

2929
enum Activation
3030
{
31-
gelu_and_mul = 0,
32-
silu_and_mul = 1
31+
gelu_and_mul = 0,
32+
silu_and_mul = 1,
33+
swiglustep_and_mul = 2
3334
};
3435

3536
template <typename ALayout,

include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp

Lines changed: 38 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1592,6 +1592,25 @@ struct GridwiseMoeGemmBlockScale
15921592
tensor_operation::element_wise::Silu{}(gate, gate);
15931593
c_thread_buf(cidx) = gate * up;
15941594
}
1595+
else if constexpr(ActivationOperation == Activation::swiglustep_and_mul)
1596+
{
1597+
float gate = c_thread_buf[cidx];
1598+
float up = c_thread_buf_up[cidx];
1599+
if constexpr(MulRoutedWeight)
1600+
{
1601+
gate = gate * topk_weight;
1602+
up = up * topk_weight;
1603+
}
1604+
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
1605+
{
1606+
gate *= 16;
1607+
up *= 16;
1608+
}
1609+
tensor_operation::element_wise::Silu{}(gate, gate);
1610+
gate = gate < 7.0f ? gate : 7.0f;
1611+
up = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f;
1612+
c_thread_buf(cidx) = gate * up;
1613+
}
15951614
else if(ActivationOperation == Activation::gelu_and_mul)
15961615
{
15971616
float gate = c_thread_buf[cidx];
@@ -2118,6 +2137,25 @@ struct GridwiseMoeGemmBlockScale
21182137
tensor_operation::element_wise::Silu{}(gate, gate);
21192138
c_thread_buf(cidx) = gate * up;
21202139
}
2140+
else if constexpr(ActivationOperation == Activation::swiglustep_and_mul)
2141+
{
2142+
float gate = c_thread_buf[cidx];
2143+
float up = c_thread_buf_up[cidx];
2144+
if constexpr(MulRoutedWeight)
2145+
{
2146+
gate = gate * topk_weight;
2147+
up = up * topk_weight;
2148+
}
2149+
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
2150+
{
2151+
gate *= 16;
2152+
up *= 16;
2153+
}
2154+
tensor_operation::element_wise::Silu{}(gate, gate);
2155+
gate = gate < 7.0f ? gate : 7.0f;
2156+
up = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f;
2157+
c_thread_buf(cidx) = gate * up;
2158+
}
21212159
else if(ActivationOperation == Activation::gelu_and_mul)
21222160
{
21232161
float gate = c_thread_buf[cidx];

0 commit comments

Comments
 (0)