Skip to content

Commit f08f20a

Browse files
authored
ggml-cpu: fuse RMS_NORM + MUL on CPU backend (ggml-org#22423)
1 parent 07eaf91 commit f08f20a

3 files changed

Lines changed: 115 additions & 17 deletions

File tree

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2965,6 +2965,45 @@ struct ggml_cplan ggml_graph_plan(
29652965
return cplan;
29662966
}
29672967

2968+
2969+
// Try to fuse the current node with subsequent nodes for better performance.
2970+
// Returns the number of nodes skipped by fusion (>=1), or 0 if no fusion was applied.
2971+
static bool ggml_cpu_disable_fusion = false; // initialized once in ggml_cpu_init(), read-only afterwards
2972+
2973+
static int ggml_cpu_try_fuse_ops(
2974+
const struct ggml_cgraph * cgraph,
2975+
const int node_n,
2976+
const struct ggml_compute_params * params,
2977+
const struct ggml_cplan * cplan) {
2978+
2979+
if (ggml_cpu_disable_fusion || cplan->use_ref) {
2980+
return 0;
2981+
}
2982+
2983+
struct ggml_tensor * node = cgraph->nodes[node_n];
2984+
2985+
if (node->op == GGML_OP_RMS_NORM) {
2986+
// RMS_NORM + MUL fusion
2987+
const enum ggml_op fuse_ops[] = { GGML_OP_RMS_NORM, GGML_OP_MUL };
2988+
if (ggml_can_fuse(cgraph, node_n, fuse_ops, 2)) {
2989+
struct ggml_tensor * mul_node = cgraph->nodes[node_n + 1];
2990+
const struct ggml_tensor * mul_w = (mul_node->src[0] == node)
2991+
? mul_node->src[1] : mul_node->src[0];
2992+
if (node->src[0]->type == GGML_TYPE_F32 &&
2993+
mul_node->type == GGML_TYPE_F32 &&
2994+
mul_w->type == GGML_TYPE_F32 &&
2995+
mul_w->ne[0] == node->ne[0] &&
2996+
mul_w->nb[0] == sizeof(float)) {
2997+
2998+
ggml_compute_forward_rms_norm_mul_fused(params, node, mul_node);
2999+
return 1;
3000+
}
3001+
}
3002+
}
3003+
3004+
return 0;
3005+
}
3006+
29683007
static thread_ret_t ggml_graph_compute_thread(void * data) {
29693008
struct ggml_compute_state * state = (struct ggml_compute_state *) data;
29703009
struct ggml_threadpool * tp = state->threadpool;
@@ -3001,7 +3040,14 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
30013040
continue;
30023041
}
30033042

3004-
ggml_compute_forward(&params, node);
3043+
// TODO: move fused-op detection into ggml_graph_plan so fusion decisions are made once at planning time
3044+
// Try fused ops, fall back to normal compute
3045+
const int n_fused = ggml_cpu_try_fuse_ops(cgraph, node_n, &params, cplan);
3046+
if (n_fused > 0) {
3047+
node_n += n_fused;
3048+
} else {
3049+
ggml_compute_forward(&params, node);
3050+
}
30053051

30063052
if (state->ith == 0 && cplan->abort_callback &&
30073053
cplan->abort_callback(cplan->abort_callback_data)) {
@@ -3763,6 +3809,11 @@ void ggml_cpu_init(void) {
37633809
ggml_init_riscv_arch_features();
37643810
#endif
37653811

3812+
{
3813+
const char * env = getenv("GGML_CPU_DISABLE_FUSION");
3814+
ggml_cpu_disable_fusion = (env != NULL && atoi(env) == 1);
3815+
}
3816+
37663817
is_first_call = false;
37673818
}
37683819

ggml/src/ggml-cpu/ops.cpp

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3713,11 +3713,27 @@ void ggml_compute_forward_norm(
37133713

37143714
// ggml_compute_forward_group_rms_norm
37153715

3716+
// fusion kinds that can be combined with the rms_norm computation in a single pass.
3717+
// extend this enum when adding new fused variants (e.g. FUSE_ADD, FUSE_MUL_ADD, ...).
3718+
enum ggml_rms_norm_fuse_op {
3719+
GGML_RMS_NORM_FUSE_OP_NONE,
3720+
GGML_RMS_NORM_FUSE_OP_MUL,
3721+
};
3722+
3723+
template <ggml_rms_norm_fuse_op FUSE_OP>
37163724
static void ggml_compute_forward_rms_norm_f32(
37173725
const ggml_compute_params * params,
3718-
ggml_tensor * dst) {
3726+
ggml_tensor * dst_rms_norm,
3727+
ggml_tensor * dst_fused = nullptr) {
37193728

3720-
const ggml_tensor * src0 = dst->src[0];
3729+
const ggml_tensor * src0 = dst_rms_norm->src[0];
3730+
const ggml_tensor * src1 = nullptr;
3731+
ggml_tensor * dst = dst_rms_norm;
3732+
3733+
if constexpr (FUSE_OP == GGML_RMS_NORM_FUSE_OP_MUL) {
3734+
src1 = (dst_fused->src[0] == dst_rms_norm) ? dst_fused->src[1] : dst_fused->src[0];
3735+
dst = dst_fused;
3736+
}
37213737

37223738
GGML_ASSERT(ggml_are_same_shape(src0, dst));
37233739

@@ -3726,11 +3742,10 @@ static void ggml_compute_forward_rms_norm_f32(
37263742
const int ith = params->ith;
37273743
const int nth = params->nth;
37283744

3729-
GGML_TENSOR_UNARY_OP_LOCALS
3745+
GGML_TENSOR_BINARY_OP_LOCALS
37303746

37313747
float eps;
3732-
memcpy(&eps, dst->op_params, sizeof(float));
3733-
3748+
memcpy(&eps, dst_rms_norm->op_params, sizeof(float));
37343749
GGML_ASSERT(eps >= 0.0f);
37353750

37363751
// TODO: optimize
@@ -3740,25 +3755,32 @@ static void ggml_compute_forward_rms_norm_f32(
37403755
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
37413756

37423757
ggml_float sum = 0.0;
3758+
// worth switching to explicit SIMD?
37433759
for (int64_t i00 = 0; i00 < ne00; i00++) {
37443760
sum += (ggml_float)(x[i00] * x[i00]);
37453761
}
37463762

3747-
const float mean = sum/ne00;
3748-
3749-
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3750-
3751-
memcpy(y, x, ne00 * sizeof(float));
3752-
// for (int i00 = 0; i00 < ne00; i00++) {
3753-
// y[i00] = x[i00];
3754-
// }
3755-
3763+
const float mean = sum/ne00;
37563764
const float scale = 1.0f/sqrtf(mean + eps);
37573765

37583766
// if you hit this, likely you got an inf somewhere earlier
37593767
assert(scale > 0.0f);
37603768

3761-
ggml_vec_scale_f32(ne00, y, scale);
3769+
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3770+
3771+
if constexpr (FUSE_OP == GGML_RMS_NORM_FUSE_OP_MUL) {
3772+
const int64_t i11 = i01 % ne11;
3773+
const int64_t i12 = i02 % ne12;
3774+
const int64_t i13 = i03 % ne13;
3775+
const float * w = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
3776+
3777+
for (int64_t i00 = 0; i00 < ne00; i00++) {
3778+
y[i00] = x[i00] * scale * w[i00];
3779+
}
3780+
} else {
3781+
memcpy(y, x, ne00 * sizeof(float));
3782+
ggml_vec_scale_f32(ne00, y, scale);
3783+
}
37623784
}
37633785
}
37643786
}
@@ -3773,7 +3795,31 @@ void ggml_compute_forward_rms_norm(
37733795
switch (src0->type) {
37743796
case GGML_TYPE_F32:
37753797
{
3776-
ggml_compute_forward_rms_norm_f32(params, dst);
3798+
ggml_compute_forward_rms_norm_f32<GGML_RMS_NORM_FUSE_OP_NONE>(params, dst);
3799+
} break;
3800+
default:
3801+
{
3802+
GGML_ABORT("fatal error");
3803+
}
3804+
}
3805+
}
3806+
3807+
// Fused RMS_NORM + MUL: computes dst = rms_norm(src0) * src1 in a single pass.
3808+
// This avoids materializing the intermediate rms_norm result in memory.
3809+
void ggml_compute_forward_rms_norm_mul_fused(
3810+
const ggml_compute_params * params,
3811+
ggml_tensor * dst_rms_norm,
3812+
ggml_tensor * dst_mul) {
3813+
3814+
GGML_ASSERT(dst_mul != nullptr);
3815+
GGML_ASSERT(dst_mul->src[0] == dst_rms_norm || dst_mul->src[1] == dst_rms_norm);
3816+
3817+
const ggml_tensor * src0 = dst_rms_norm->src[0];
3818+
3819+
switch (src0->type) {
3820+
case GGML_TYPE_F32:
3821+
{
3822+
ggml_compute_forward_rms_norm_f32<GGML_RMS_NORM_FUSE_OP_MUL>(params, dst_rms_norm, dst_mul);
37773823
} break;
37783824
default:
37793825
{

ggml/src/ggml-cpu/ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ void ggml_compute_forward_concat(const struct ggml_compute_params * params, stru
4444
void ggml_compute_forward_silu_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
4545
void ggml_compute_forward_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
4646
void ggml_compute_forward_rms_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
47+
void ggml_compute_forward_rms_norm_mul_fused(const struct ggml_compute_params * params, struct ggml_tensor * dst_rms_norm, struct ggml_tensor * dst_mul);
4748
void ggml_compute_forward_rms_norm_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
4849
void ggml_compute_forward_group_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
4950
void ggml_compute_forward_l2_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);

0 commit comments

Comments (0)