Commit 7b1d172

reduction fuse opt in RMS normalization layer for llama after the recent half float update
1 parent 68f7d1f commit 7b1d172

3 files changed

Lines changed: 90 additions & 4 deletions


src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java

Lines changed: 78 additions & 0 deletions
@@ -408,6 +408,84 @@ public static void reductionOneBlockWithLayerFuse(KernelContext context, FloatAr
         }
     }

+    /**
+     * Performs RMS (Root Mean Square) normalization using parallel reduction. It first computes the mean of the squares and the scaling factor across all work groups,
+     * then applies the computed normalization factor to the input and weight elements.
+     *
+     * <p>
+     * Formula: output[i] = weight[i] * (normalizationFactor * x[i])
+     *
+     * Algorithm: 1. Each thread computes the square of its input element. 2. Each work group performs a parallel reduction of the squares. 3. Partial sums are stored per work group.
+     * 4. Every thread combines all partial sums and computes the normalization factor. 5. The normalization factor is applied to the input and weight elements.
+     *
+     * @param context
+     *            Kernel execution context
+     * @param outputFP16
+     *            Half-float array receiving the normalized output
+     * @param x
+     *            Input array to normalize
+     * @param weights
+     *            Weight values for each element
+     * @param temp
+     *            Temporary array holding one partial sum per work group
+     * @param size
+     *            Number of elements to process
+     * @param ermsNorm
+     *            Epsilon value squared for numerical stability
+     * @param localMemSize
+     *            Size of the local memory allocation (must match the work group size)
+     */
+    public static void reductionOneBlockWithLayerFuseFP16(KernelContext context, HalfFloatArray outputFP16, FloatArray x, FloatArray weights, FloatArray temp, int size, float ermsNorm, int localMemSize) {
+        int gid = context.globalIdx;
+        int lid = context.localIdx;
+        int groupId = context.groupIdx;
+        int groupSize = context.localGroupSizeX;
+
+        // Allocate local memory with the provided size
+        float[] localX = context.allocateFloatLocalArray(localMemSize);
+
+        // Load the input value and compute its square
+        if (gid < size) {
+            float v = x.get(gid);
+            localX[lid] = v * v;
+        } else {
+            localX[lid] = 0.0f;
+        }
+
+        // Perform parallel reduction within the work group
+        for (int stride = (groupSize / 2); stride > 0; stride /= 2) {
+            context.localBarrier();
+            if (lid < stride) {
+                localX[lid] += localX[lid + stride];
+            }
+        }
+
+        // Each work group stores its partial sum in its own slot
+        if (lid == 0) {
+            temp.set(groupId, localX[0]);
+        }
+
+        context.globalBarrier();
+
+        // Every thread combines the partial sums from all work groups
+        float localss = 0.0f;
+        int numGroups = (size + groupSize - 1) / groupSize;
+        for (int i = 0; i < numGroups; i++) {
+            localss += temp.get(i);
+        }
+        localss /= size;
+        localss += ermsNorm;
+        localss = 1.0f / TornadoMath.sqrt(localss);
+
+        // Apply the normalization factor and store the result as half floats
+        if (gid < size) {
+            float in = x.get(gid);
+            float w = weights.get(gid);
+            outputFP16.set(gid, new HalfFloat(w * (localss * in)));
+        }
+    }
+
     /**
      * Applies the computed normalization factor to input and weight elements. This is the second phase of RMS normalization.
      * <p>

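For readers skimming the diff, a minimal sequential sketch of what the fused kernel computes may help. This is an illustrative plain-Java reference, not code from this commit: the helper name rmsNormReference and the use of float[] instead of TornadoVM's FloatArray/HalfFloatArray are assumptions for the example.

    static float[] rmsNormReference(float[] x, float[] weights, float eps) {
        int size = x.length;
        // 1) sum of squares, then mean
        float ss = 0.0f;
        for (float v : x) {
            ss += v * v;
        }
        ss /= size;
        // 2) normalization factor: 1 / sqrt(mean(x^2) + eps)
        float factor = (float) (1.0 / Math.sqrt(ss + eps));
        // 3) scale each element by the factor and its weight
        float[] out = new float[size];
        for (int i = 0; i < size; i++) {
            out[i] = weights[i] * (factor * x[i]);
        }
        return out;
    }

The GPU kernel distributes step 1 across work groups (one partial sum per group), has every thread recompute step 2 from the partial sums after a global barrier, and performs step 3 directly into outputFP16 as HalfFloat values, which is what lets the separate attn_rms_apply task be dropped below.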
src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java

Lines changed: 6 additions & 2 deletions
@@ -50,7 +50,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
         for (int i = 0; i < config.numberOfLayers(); i++) {
             // === Attention Block ===
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply_fp16", rmsNormWorker);
+            //tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply_fp16", rmsNormWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQKVWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
@@ -199,6 +199,10 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
         // === Attention Block ===
         // RMS Normalization
         unifiedLayer.task("attn_rms_reduce",
+                TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuseFP16,
+                context, state.wrapXbFP16, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp,
+                config.dim(), config.rmsNormEps(), state.localSize);
+        /*unifiedLayer.task("attn_rms_reduce",
                 TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
                 context, state.temp, state.wrapX,
                 config.dim(), config.rmsNormEps(), state.localSize);
@@ -212,7 +216,7 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
         unifiedLayer.task("attn_rms_apply_fp16",
                 TransformerComputeKernels::mapContextWithQuantize,
                 context, state.wrapXbFP16, state.wrapX,
-                weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp);
+                weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp);*/

         // QKV Projection (fused)
         unifiedLayer.task("qkv_projection",

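Because the fused task reuses the rmsNormWorker grid, the state.localSize handed to the kernel has to agree with that grid's local work size (the javadoc's "must match work group size" note). A rough configuration sketch, assuming TornadoVM's WorkerGrid1D/GridScheduler classes (uk.ac.manchester.tornado.api) and made-up dim/localSize values:

    int dim = 2048;        // hidden size, i.e. config.dim() (example value)
    int localSize = 256;   // work-group size; must equal the localSize/localMemSize passed to the kernel
    WorkerGrid rmsNormWorker = new WorkerGrid1D(dim);   // one thread per element of x
    rmsNormWorker.setLocalWork(localSize, 1, 1);
    GridScheduler scheduler = new GridScheduler();
    scheduler.addWorkerGrid("layer_0.attn_rms_reduce", rmsNormWorker);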
src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java

Lines changed: 6 additions & 2 deletions
@@ -161,6 +161,10 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
         // === Attention Block ===
         // RMS Normalization
         unifiedLayer.task("attn_rms_reduce",
+                TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuse,
+                context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp,
+                config.dim(), config.rmsNormEps(), state.localSize);
+        /*unifiedLayer.task("attn_rms_reduce",
                 TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
                 context, state.temp, state.wrapX,
                 config.dim(), config.rmsNormEps(), state.localSize);
@@ -174,7 +178,7 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
         unifiedLayer.task("attn_rms_apply",
                 TransformerComputeKernelsLayered::reductionOneBlock2WithLayer,
                 context, state.wrapXb, state.wrapX,
-                weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp);
+                weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp);*/

         // QKV Projection (fused with Q8 dequantization)
         unifiedLayer.task("qkv_projection",
@@ -306,7 +310,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
             // --- Attention Block ---
             // RMS Normalization
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply", rmsNormWorker);
+            //tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply", rmsNormWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQkvWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);

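One further consequence of the fused reduction: the kernel stores one partial sum per work group in temp, so that buffer must hold at least ceil(size / groupSize) entries. A small sizing sketch with illustrative values (FloatArray is the TornadoVM float buffer type already used by the kernel):

    int dim = 2048;                                     // config.dim() (example value)
    int localSize = 256;                                // work-group size (example value)
    int numGroups = (dim + localSize - 1) / localSize;  // same formula the kernel uses
    FloatArray temp = new FloatArray(numGroups);        // one partial sum per work group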