Skip to content

Commit e64c1f4

Browse files
committed
perf: 4x loop unrolling for SIMD kernels
Add 4x unrolled SIMD loops to ILKernelGenerator and SimdKernels:

- Binary ops (Add, Sub, Mul, Div, etc.)
- Comparison ops (==, !=, <, >, <=, >=)
- Mixed-type ops
- Reduction ops
- Unary ops
- SimdKernels.cs (C# fallback paths)

Loop structure:

1. 4x unrolled main loop: processes 4 vectors per iteration
2. Remainder loop: processes the 0-3 whole vectors left over after the unrolled loop
3. Scalar tail: handles the trailing elements that do not fill a whole vector

This reduces loop overhead and improves instruction-level parallelism by allowing the CPU to pipeline more operations.
1 parent f4e423f commit e64c1f4

6 files changed

Lines changed: 785 additions & 93 deletions

File tree

src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Binary.cs

Lines changed: 90 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -228,30 +228,108 @@ private static unsafe ContiguousKernel<T> GenerateContiguousKernelIL<T>(BinaryOp
228228

229229
if (!isScalarOnly)
230230
{
231-
// SIMD-capable operations: generate SIMD loop + tail loop
232-
var locVectorEnd = il.DeclareLocal(typeof(int)); // totalSize - vectorCount
231+
// SIMD-capable operations: generate 4x unrolled SIMD loop + remainder + tail loop
232+
var locVectorEnd = il.DeclareLocal(typeof(int)); // count - vectorCount (for remainder loop)
233+
var locUnrollEnd = il.DeclareLocal(typeof(int)); // count - vectorCount*4 (for 4x unrolled loop)
233234

234235
// Define labels
235-
var lblSimdLoop = il.DefineLabel();
236-
var lblSimdLoopEnd = il.DefineLabel();
236+
var lblUnrollLoop = il.DefineLabel();
237+
var lblUnrollLoopEnd = il.DefineLabel();
238+
var lblRemainderLoop = il.DefineLabel();
239+
var lblRemainderLoopEnd = il.DefineLabel();
237240
var lblTailLoop = il.DefineLabel();
238241
var lblTailLoopEnd = il.DefineLabel();
239242

240243
int vectorCount = GetVectorCount<T>();
244+
int unrollStep = vectorCount * 4;
241245

242-
// vectorEnd = count - vectorCount
246+
// vectorEnd = count - vectorCount (for remainder loop)
243247
il.Emit(OpCodes.Ldarg_3); // count
244248
il.Emit(OpCodes.Ldc_I4, vectorCount);
245249
il.Emit(OpCodes.Sub);
246250
il.Emit(OpCodes.Stloc, locVectorEnd);
247251

248-
// ========== SIMD LOOP ==========
249-
il.MarkLabel(lblSimdLoop);
252+
// unrollEnd = count - vectorCount*4 (for 4x unrolled loop)
253+
il.Emit(OpCodes.Ldarg_3); // count
254+
il.Emit(OpCodes.Ldc_I4, unrollStep);
255+
il.Emit(OpCodes.Sub);
256+
il.Emit(OpCodes.Stloc, locUnrollEnd);
257+
258+
// ========== 4x UNROLLED SIMD LOOP ==========
259+
il.MarkLabel(lblUnrollLoop);
260+
261+
// if (i > unrollEnd) goto UnrollLoopEnd
262+
il.Emit(OpCodes.Ldloc, locI);
263+
il.Emit(OpCodes.Ldloc, locUnrollEnd);
264+
il.Emit(OpCodes.Bgt, lblUnrollLoopEnd);
265+
266+
// Process 4 vectors per iteration
267+
for (int u = 0; u < 4; u++)
268+
{
269+
int offset = vectorCount * u;
270+
271+
// Load lhs vector at (i + offset)
272+
il.Emit(OpCodes.Ldarg_0); // lhs
273+
il.Emit(OpCodes.Ldloc, locI);
274+
if (offset > 0)
275+
{
276+
il.Emit(OpCodes.Ldc_I4, offset);
277+
il.Emit(OpCodes.Add);
278+
}
279+
il.Emit(OpCodes.Conv_I);
280+
il.Emit(OpCodes.Ldc_I4, elementSize);
281+
il.Emit(OpCodes.Mul);
282+
il.Emit(OpCodes.Add);
283+
EmitVectorLoad<T>(il);
284+
285+
// Load rhs vector at (i + offset)
286+
il.Emit(OpCodes.Ldarg_1); // rhs
287+
il.Emit(OpCodes.Ldloc, locI);
288+
if (offset > 0)
289+
{
290+
il.Emit(OpCodes.Ldc_I4, offset);
291+
il.Emit(OpCodes.Add);
292+
}
293+
il.Emit(OpCodes.Conv_I);
294+
il.Emit(OpCodes.Ldc_I4, elementSize);
295+
il.Emit(OpCodes.Mul);
296+
il.Emit(OpCodes.Add);
297+
EmitVectorLoad<T>(il);
298+
299+
// Perform vector operation
300+
EmitVectorOperation<T>(il, op);
301+
302+
// Store result at (i + offset)
303+
il.Emit(OpCodes.Ldarg_2); // result
304+
il.Emit(OpCodes.Ldloc, locI);
305+
if (offset > 0)
306+
{
307+
il.Emit(OpCodes.Ldc_I4, offset);
308+
il.Emit(OpCodes.Add);
309+
}
310+
il.Emit(OpCodes.Conv_I);
311+
il.Emit(OpCodes.Ldc_I4, elementSize);
312+
il.Emit(OpCodes.Mul);
313+
il.Emit(OpCodes.Add);
314+
EmitVectorStore<T>(il);
315+
}
316+
317+
// i += vectorCount * 4
318+
il.Emit(OpCodes.Ldloc, locI);
319+
il.Emit(OpCodes.Ldc_I4, unrollStep);
320+
il.Emit(OpCodes.Add);
321+
il.Emit(OpCodes.Stloc, locI);
322+
323+
il.Emit(OpCodes.Br, lblUnrollLoop);
324+
il.MarkLabel(lblUnrollLoopEnd);
325+
326+
// ========== REMAINDER SIMD LOOP (0-3 vectors) ==========
327+
il.MarkLabel(lblRemainderLoop);
250328

251-
// if (i > vectorEnd) goto SimdLoopEnd
329+
// if (i > vectorEnd) goto RemainderLoopEnd
252330
il.Emit(OpCodes.Ldloc, locI);
253331
il.Emit(OpCodes.Ldloc, locVectorEnd);
254-
il.Emit(OpCodes.Bgt, lblSimdLoopEnd);
332+
il.Emit(OpCodes.Bgt, lblRemainderLoopEnd);
255333

256334
// Load lhs vector: Vector256.Load(lhs + i)
257335
il.Emit(OpCodes.Ldarg_0); // lhs
@@ -289,10 +367,10 @@ private static unsafe ContiguousKernel<T> GenerateContiguousKernelIL<T>(BinaryOp
289367
il.Emit(OpCodes.Add);
290368
il.Emit(OpCodes.Stloc, locI);
291369

292-
il.Emit(OpCodes.Br, lblSimdLoop);
293-
il.MarkLabel(lblSimdLoopEnd);
370+
il.Emit(OpCodes.Br, lblRemainderLoop);
371+
il.MarkLabel(lblRemainderLoopEnd);
294372

295-
// ========== TAIL LOOP ==========
373+
// ========== TAIL LOOP (scalar) ==========
296374
il.MarkLabel(lblTailLoop);
297375

298376
// if (i >= count) goto TailLoopEnd

src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Comparison.cs

Lines changed: 99 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -306,12 +306,14 @@ private static ComparisonKernel GenerateComparisonGeneralKernel(ComparisonKernel
306306
#region Comparison Loop Emission
307307

308308
/// <summary>
309-
/// Emit a SIMD loop for contiguous comparison (adapts to V128/V256/V512).
309+
/// Emit a SIMD loop for contiguous comparison with 4x unrolling (adapts to V128/V256/V512).
310310
/// </summary>
311311
private static void EmitComparisonSimdLoop(ILGenerator il, ComparisonKernelKey key,
312312
int lhsSize, int rhsSize, NPTypeCode comparisonType)
313313
{
314314
int vectorCount = GetVectorCount(comparisonType);
315+
int unrollFactor = 4;
316+
int unrollStep = vectorCount * unrollFactor;
315317
var clrType = GetClrType(comparisonType);
316318
var vectorType = GetVectorType(clrType);
317319

@@ -320,8 +322,21 @@ private static void EmitComparisonSimdLoop(ILGenerator il, ComparisonKernelKey k
320322
// int ndim (6), int totalSize (7)
321323

322324
var locI = il.DeclareLocal(typeof(int));
325+
var locUnrollEnd = il.DeclareLocal(typeof(int));
323326
var locVectorEnd = il.DeclareLocal(typeof(int));
324-
var locMask = il.DeclareLocal(vectorType);
327+
328+
// Declare mask locals for 4x unrolling
329+
var locMask0 = il.DeclareLocal(vectorType);
330+
var locMask1 = il.DeclareLocal(vectorType);
331+
var locMask2 = il.DeclareLocal(vectorType);
332+
var locMask3 = il.DeclareLocal(vectorType);
333+
var maskLocals = new[] { locMask0, locMask1, locMask2, locMask3 };
334+
335+
// unrollEnd = totalSize - unrollStep + 1 (last valid 4x start position)
336+
il.Emit(OpCodes.Ldarg_S, (byte)7); // totalSize
337+
il.Emit(OpCodes.Ldc_I4, unrollStep - 1);
338+
il.Emit(OpCodes.Sub);
339+
il.Emit(OpCodes.Stloc, locUnrollEnd);
325340

326341
// vectorEnd = totalSize - vectorCount + 1 (last valid SIMD start position)
327342
il.Emit(OpCodes.Ldarg_S, (byte)7); // totalSize
@@ -333,16 +348,88 @@ private static void EmitComparisonSimdLoop(ILGenerator il, ComparisonKernelKey k
333348
il.Emit(OpCodes.Ldc_I4_0);
334349
il.Emit(OpCodes.Stloc, locI);
335350

336-
var lblSimdLoop = il.DefineLabel();
337-
var lblSimdEnd = il.DefineLabel();
351+
var lblUnrollLoop = il.DefineLabel();
352+
var lblUnrollEnd = il.DefineLabel();
353+
var lblRemainderLoop = il.DefineLabel();
354+
var lblRemainderEnd = il.DefineLabel();
338355
var lblTailLoop = il.DefineLabel();
339356
var lblTailEnd = il.DefineLabel();
340357

341-
// === SIMD Loop ===
342-
il.MarkLabel(lblSimdLoop);
358+
// === 4x UNROLLED SIMD LOOP ===
359+
il.MarkLabel(lblUnrollLoop);
360+
il.Emit(OpCodes.Ldloc, locI);
361+
il.Emit(OpCodes.Ldloc, locUnrollEnd);
362+
il.Emit(OpCodes.Bgt, lblUnrollEnd);
363+
364+
// Load 4 lhs vectors, 4 rhs vectors, compare, store masks
365+
for (int n = 0; n < unrollFactor; n++)
366+
{
367+
int offset = n * vectorCount;
368+
369+
// Load lhs vector at (i + offset) * lhsSize
370+
il.Emit(OpCodes.Ldarg_0); // lhs
371+
il.Emit(OpCodes.Ldloc, locI);
372+
if (offset > 0)
373+
{
374+
il.Emit(OpCodes.Ldc_I4, offset);
375+
il.Emit(OpCodes.Add);
376+
}
377+
il.Emit(OpCodes.Conv_I);
378+
il.Emit(OpCodes.Ldc_I4, lhsSize);
379+
il.Emit(OpCodes.Mul);
380+
il.Emit(OpCodes.Add);
381+
EmitVectorLoad(il, comparisonType);
382+
383+
// Load rhs vector at (i + offset) * rhsSize
384+
il.Emit(OpCodes.Ldarg_1); // rhs
385+
il.Emit(OpCodes.Ldloc, locI);
386+
if (offset > 0)
387+
{
388+
il.Emit(OpCodes.Ldc_I4, offset);
389+
il.Emit(OpCodes.Add);
390+
}
391+
il.Emit(OpCodes.Conv_I);
392+
il.Emit(OpCodes.Ldc_I4, rhsSize);
393+
il.Emit(OpCodes.Mul);
394+
il.Emit(OpCodes.Add);
395+
EmitVectorLoad(il, comparisonType);
396+
397+
// Compare: produces mask vector
398+
EmitVectorComparison(il, key.Op, comparisonType);
399+
il.Emit(OpCodes.Stloc, maskLocals[n]);
400+
}
401+
402+
// Extract all 4 masks to booleans
403+
for (int n = 0; n < unrollFactor; n++)
404+
{
405+
int offset = n * vectorCount;
406+
407+
// Create a temporary local to hold (i + offset) for extraction
408+
var locIOffset = il.DeclareLocal(typeof(int));
409+
il.Emit(OpCodes.Ldloc, locI);
410+
if (offset > 0)
411+
{
412+
il.Emit(OpCodes.Ldc_I4, offset);
413+
il.Emit(OpCodes.Add);
414+
}
415+
il.Emit(OpCodes.Stloc, locIOffset);
416+
417+
EmitMaskToBoolExtraction(il, comparisonType, vectorCount, locIOffset, maskLocals[n]);
418+
}
419+
420+
// i += unrollStep
421+
il.Emit(OpCodes.Ldloc, locI);
422+
il.Emit(OpCodes.Ldc_I4, unrollStep);
423+
il.Emit(OpCodes.Add);
424+
il.Emit(OpCodes.Stloc, locI);
425+
il.Emit(OpCodes.Br, lblUnrollLoop);
426+
427+
// === REMAINDER SIMD LOOP (0-3 vectors) ===
428+
il.MarkLabel(lblUnrollEnd);
429+
il.MarkLabel(lblRemainderLoop);
343430
il.Emit(OpCodes.Ldloc, locI);
344431
il.Emit(OpCodes.Ldloc, locVectorEnd);
345-
il.Emit(OpCodes.Bgt, lblSimdEnd);
432+
il.Emit(OpCodes.Bgt, lblRemainderEnd);
346433

347434
// Load lhs vector: lhs + i * elemSize
348435
il.Emit(OpCodes.Ldarg_0); // lhs
@@ -364,20 +451,20 @@ private static void EmitComparisonSimdLoop(ILGenerator il, ComparisonKernelKey k
364451

365452
// Compare: produces mask vector
366453
EmitVectorComparison(il, key.Op, comparisonType);
367-
il.Emit(OpCodes.Stloc, locMask);
454+
il.Emit(OpCodes.Stloc, locMask0);
368455

369456
// Extract mask to booleans
370-
EmitMaskToBoolExtraction(il, comparisonType, vectorCount, locI, locMask);
457+
EmitMaskToBoolExtraction(il, comparisonType, vectorCount, locI, locMask0);
371458

372459
// i += vectorCount
373460
il.Emit(OpCodes.Ldloc, locI);
374461
il.Emit(OpCodes.Ldc_I4, vectorCount);
375462
il.Emit(OpCodes.Add);
376463
il.Emit(OpCodes.Stloc, locI);
377-
il.Emit(OpCodes.Br, lblSimdLoop);
464+
il.Emit(OpCodes.Br, lblRemainderLoop);
378465

379-
// === Tail Loop (scalar) ===
380-
il.MarkLabel(lblSimdEnd);
466+
// === SCALAR TAIL LOOP ===
467+
il.MarkLabel(lblRemainderEnd);
381468
il.MarkLabel(lblTailLoop);
382469

383470
// if (i >= totalSize) goto end

0 commit comments

Comments
 (0)