@@ -3713,11 +3713,27 @@ void ggml_compute_forward_norm(
37133713
37143714// ggml_compute_forward_rms_norm
37153715
3716+ // fusion kinds that can be combined with the rms_norm computation in a single pass.
3717+ // extend this enum when adding new fused variants (e.g. FUSE_ADD, FUSE_MUL_ADD, ...).
3718+ enum ggml_rms_norm_fuse_op {
3719+ GGML_RMS_NORM_FUSE_OP_NONE, // plain rms_norm: each row is copied to dst and scaled by 1/sqrt(mean(x^2) + eps)
3720+ GGML_RMS_NORM_FUSE_OP_MUL, // rms_norm followed by an element-wise multiply with the other MUL operand, written directly to the MUL output
3721+ };
3722+
3723+ template <ggml_rms_norm_fuse_op FUSE_OP>
37163724static void ggml_compute_forward_rms_norm_f32 (
37173725 const ggml_compute_params * params,
3718- ggml_tensor * dst) {
3726+ ggml_tensor * dst_rms_norm,
3727+ ggml_tensor * dst_fused = nullptr ) {
37193728
3720- const ggml_tensor * src0 = dst->src [0 ];
3729+ const ggml_tensor * src0 = dst_rms_norm->src [0 ];
3730+ const ggml_tensor * src1 = nullptr ;
3731+ ggml_tensor * dst = dst_rms_norm;
3732+
3733+ if constexpr (FUSE_OP == GGML_RMS_NORM_FUSE_OP_MUL) {
3734+ src1 = (dst_fused->src [0 ] == dst_rms_norm) ? dst_fused->src [1 ] : dst_fused->src [0 ];
3735+ dst = dst_fused;
3736+ }
37213737
37223738 GGML_ASSERT (ggml_are_same_shape (src0, dst));
37233739
@@ -3726,11 +3742,10 @@ static void ggml_compute_forward_rms_norm_f32(
37263742 const int ith = params->ith ;
37273743 const int nth = params->nth ;
37283744
3729- GGML_TENSOR_UNARY_OP_LOCALS
3745+ GGML_TENSOR_BINARY_OP_LOCALS
37303746
37313747 float eps;
3732- memcpy (&eps, dst->op_params , sizeof (float ));
3733-
3748+ memcpy (&eps, dst_rms_norm->op_params , sizeof (float ));
37343749 GGML_ASSERT (eps >= 0 .0f );
37353750
37363751 // TODO: optimize
@@ -3740,25 +3755,32 @@ static void ggml_compute_forward_rms_norm_f32(
37403755 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
37413756
37423757 ggml_float sum = 0.0 ;
3758+ // worth switching to explicit SIMD?
37433759 for (int64_t i00 = 0 ; i00 < ne00; i00++) {
37443760 sum += (ggml_float)(x[i00] * x[i00]);
37453761 }
37463762
3747- const float mean = sum/ne00;
3748-
3749- float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3750-
3751- memcpy (y, x, ne00 * sizeof (float ));
3752- // for (int i00 = 0; i00 < ne00; i00++) {
3753- // y[i00] = x[i00];
3754- // }
3755-
3763+ const float mean = sum/ne00;
37563764 const float scale = 1 .0f /sqrtf (mean + eps);
37573765
37583766 // if you hit this, likely you got an inf somewhere earlier
37593767 assert (scale > 0 .0f );
37603768
3761- ggml_vec_scale_f32 (ne00, y, scale);
3769+ float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3770+
3771+ if constexpr (FUSE_OP == GGML_RMS_NORM_FUSE_OP_MUL) {
3772+ const int64_t i11 = i01 % ne11;
3773+ const int64_t i12 = i02 % ne12;
3774+ const int64_t i13 = i03 % ne13;
3775+ const float * w = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
3776+
3777+ for (int64_t i00 = 0 ; i00 < ne00; i00++) {
3778+ y[i00] = x[i00] * scale * w[i00];
3779+ }
3780+ } else {
3781+ memcpy (y, x, ne00 * sizeof (float ));
3782+ ggml_vec_scale_f32 (ne00, y, scale);
3783+ }
37623784 }
37633785 }
37643786 }
@@ -3773,7 +3795,31 @@ void ggml_compute_forward_rms_norm(
37733795 switch (src0->type ) {
37743796 case GGML_TYPE_F32:
37753797 {
3776- ggml_compute_forward_rms_norm_f32 (params, dst);
3798+ ggml_compute_forward_rms_norm_f32<GGML_RMS_NORM_FUSE_OP_NONE>(params, dst);
3799+ } break ;
3800+ default :
3801+ {
3802+ GGML_ABORT (" fatal error" );
3803+ }
3804+ }
3805+ }
3806+
3807+ // Fused RMS_NORM + MUL: computes dst = rms_norm(src0) * src1 in a single pass.
3808+ // This avoids materializing the intermediate rms_norm result in memory.
3809+ void ggml_compute_forward_rms_norm_mul_fused (
3810+ const ggml_compute_params * params,
3811+ ggml_tensor * dst_rms_norm,
3812+ ggml_tensor * dst_mul) {
3813+
3814+ GGML_ASSERT (dst_mul != nullptr );
3815+ GGML_ASSERT (dst_mul->src [0 ] == dst_rms_norm || dst_mul->src [1 ] == dst_rms_norm);
3816+
3817+ const ggml_tensor * src0 = dst_rms_norm->src [0 ];
3818+
3819+ switch (src0->type ) {
3820+ case GGML_TYPE_F32:
3821+ {
3822+ ggml_compute_forward_rms_norm_f32<GGML_RMS_NORM_FUSE_OP_MUL>(params, dst_rms_norm, dst_mul);
37773823 } break ;
37783824 default :
37793825 {
0 commit comments