Skip to content

Commit 8422058

Browse files
authored
Merge pull request #84 from beehive-lab/hotfix/addnorm-non-nvidia
[fix] Normalization compute step for non-nvidia hardware
2 parents 9925e01 + 7b830ea commit 8422058

File tree

5 files changed, with 42 additions and 0 deletions.

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -234,6 +234,12 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) {
234234
phi3Config.rmsNormEps(), // epsilon
235235
phi3State.localSize); // local memory size
236236

237+
if (shouldUseFinalNormalization()) {
238+
unifiedLayer.task("attn_rms_finalize",
239+
TransformerComputeKernelsLayered::reductionFinalNormalization,
240+
context, state.temp, config.dim(), config.rmsNormEps());
241+
}
242+
237243
unifiedLayer.task("attn_rms_qkv_projection", Phi3Kernels::fusedRmsNormQKVMatmulDirect,
238244
context, phi3State.wrapX, // input
239245
phi3State.wrapQ, // output Q

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -271,6 +271,15 @@ TaskGraph setupSingleQwen2FFNLayer(Qwen2TornadoWeights weights, int layerIndex)
271271
config.rmsNormEps(), // epsilon
272272
qwen2State.localSize); // local memory size
273273

274+
if (shouldUseFinalNormalization()) {
275+
unifiedLayer.task("attn_rms_finalize",
276+
TransformerComputeKernelsLayered::reductionFinalNormalization,
277+
context,
278+
state.temp,
279+
config.dim(),
280+
config.rmsNormEps());
281+
}
282+
274283
// Fused RMS Apply + QKV Projection
275284
unifiedLayer.task("attn_rms_qkv_projection",
276285
Qwen3Kernels::fusedRmsNormQKVMatmul,

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -264,6 +264,15 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
264264
qwen3Config.rmsNormEps(), // epsilon
265265
qwen3State.localSize); // local memory size
266266

267+
if (shouldUseFinalNormalization()) {
268+
unifiedLayer.task("attn_rms_finalize",
269+
TransformerComputeKernelsLayered::reductionFinalNormalization,
270+
context,
271+
state.temp,
272+
config.dim(),
273+
config.rmsNormEps());
274+
}
275+
267276
// Fused RMS Apply + QKV Projection
268277
unifiedLayer.task("attn_rms_qkv_projection",
269278
Qwen3Kernels::fusedRmsNormQKVMatmul,

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -236,6 +236,15 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex
236236
phi3Config.rmsNormEps(), // epsilon
237237
phi3State.localSize); // local memory size
238238

239+
if (shouldUseFinalNormalization()) {
240+
unifiedLayer.task("attn_rms_finalize",
241+
TransformerComputeKernelsLayered::reductionFinalNormalization,
242+
context,
243+
state.temp,
244+
config.dim(),
245+
config.rmsNormEps());
246+
}
247+
239248
// Fused: RMS apply + Q8 QKV matmul + direct Q/K/V split
240249
unifiedLayer.task("attn_rms_qkv_projection_q8",
241250
TransformerComputeKernelsLayered::fusedRmsNormQKVMatmulQ8,

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -190,6 +190,15 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
190190
config.rmsNormEps(), // epsilon
191191
qwen3State.localSize); // local memory size
192192

193+
if (shouldUseFinalNormalization()) {
194+
unifiedLayer.task("attn_rms_finalize",
195+
TransformerComputeKernelsLayered::reductionFinalNormalization,
196+
context,
197+
state.temp,
198+
config.dim(),
199+
config.rmsNormEps());
200+
}
201+
193202
// Fused RMS Apply + QKV Projection
194203
unifiedLayer.task("attn_rms_qkv_projection",
195204
Qwen3Kernels::fusedRmsNormQKVMatmulQ8_0,

0 commit comments

Comments (0)