Skip to content

Commit 574a0d8

Browse files
committed
refactor(npfunc): replace ~400 NPTypeCode switch cases with NpFunc generic dispatch
NpFunc is a reflection-cached generic dispatch utility that bridges runtime NPTypeCode values to compile-time generic type parameters. Hot path (cache hit) runs at ~32ns via Delegate[] array indexed by NPTypeCode ordinal. Cold path uses MakeGenericMethod + CreateDelegate, cached after first call per (method, typeCode) pair. Core NpFunc changes: - Dynamic table sizing: Delegate[] sized from max NPTypeCode enum value (was hardcoded [32], broke for NPTypeCode.Complex=128) - Overloads for 0-6 args × void/returning × 1-3 NPTypeCodes + 1-2 Types - SmartMatchTypes for multi-type dispatch (1→broadcast, N=N→positional, M<N→type-identity matching) - Per-arity ConcurrentDictionary caches for multi-type dispatch Files refactored (12 files, ~400 cases eliminated): Previous session (5 files, ~196 cases): - Default.ClipNDArray.cs: 6 dispatch methods for contiguous/general clip - Default.Clip.cs: 3 dispatch methods for scalar clip with ChangeType - Default.NonZero.cs: 3 dispatch methods for nonzero/count operations - Default.BooleanMask.cs: 1 dispatch method for masked copy - Default.Shift.cs: 2 dispatch methods for array/scalar shift This session (7 files, ~202 cases): - NDIteratorExtensions.cs: 5 overloads → 5 dispatch methods creating NDIterator<T> from NDArray/UnmanagedStorage/IArraySlice - Default.Reduction.CumAdd.cs: axis dispatch via CumSumAxisKernel<T>, elementwise via IAdditionOperators<T,T,T> with default(T) init - Default.Reduction.CumMul.cs: axis dispatch via CumProdAxisKernel<T>, elementwise via IMultiplyOperators + T.MultiplicativeIdentity init - np.where.cs: iterator fallback + IL kernel dispatch via pointer cast - np.random.randint.cs: int/long fill via INumberBase<T>.CreateTruncating - NDArray.NOT.cs: IEquatable<T>.Equals(default) unifies bool NOT and numeric ==0 comparison into single generic method - Default.LogicalReduction.cs: direct dispatch to ExecuteLogicalAxis<T> Net: -1243 lines removed across 12 files, replacing repetitive per-type switch cases with single generic dispatch methods.
1 parent d364e7f commit 574a0d8

14 files changed

Lines changed: 591 additions & 1834 deletions

File tree

src/NumSharp.Core/APIs/np.where.cs

Lines changed: 7 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using NumSharp.Backends.Iteration;
33
using NumSharp.Backends.Kernels;
44
using NumSharp.Generic;
5+
using NumSharp.Utilities;
56

67
namespace NumSharp
78
{
@@ -120,47 +121,7 @@ private static NDArray where_internal(NDArray condition, NDArray x, NDArray y)
120121
}
121122

122123
// Iterator fallback for non-contiguous/broadcasted arrays
123-
switch (outType)
124-
{
125-
case NPTypeCode.Boolean:
126-
WhereImpl<bool>(cond, xArr, yArr, result);
127-
break;
128-
case NPTypeCode.Byte:
129-
WhereImpl<byte>(cond, xArr, yArr, result);
130-
break;
131-
case NPTypeCode.Int16:
132-
WhereImpl<short>(cond, xArr, yArr, result);
133-
break;
134-
case NPTypeCode.UInt16:
135-
WhereImpl<ushort>(cond, xArr, yArr, result);
136-
break;
137-
case NPTypeCode.Int32:
138-
WhereImpl<int>(cond, xArr, yArr, result);
139-
break;
140-
case NPTypeCode.UInt32:
141-
WhereImpl<uint>(cond, xArr, yArr, result);
142-
break;
143-
case NPTypeCode.Int64:
144-
WhereImpl<long>(cond, xArr, yArr, result);
145-
break;
146-
case NPTypeCode.UInt64:
147-
WhereImpl<ulong>(cond, xArr, yArr, result);
148-
break;
149-
case NPTypeCode.Char:
150-
WhereImpl<char>(cond, xArr, yArr, result);
151-
break;
152-
case NPTypeCode.Single:
153-
WhereImpl<float>(cond, xArr, yArr, result);
154-
break;
155-
case NPTypeCode.Double:
156-
WhereImpl<double>(cond, xArr, yArr, result);
157-
break;
158-
case NPTypeCode.Decimal:
159-
WhereImpl<decimal>(cond, xArr, yArr, result);
160-
break;
161-
default:
162-
throw new NotSupportedException($"Type {outType} not supported for np.where");
163-
}
124+
NpFunc.Invoke(outType, WhereImpl<int>, cond, xArr, yArr, result);
164125

165126
return result;
166127
}
@@ -200,50 +161,13 @@ private static void WhereImpl<T>(NDArray cond, NDArray x, NDArray y, NDArray res
200161
/// </summary>
201162
private static unsafe void WhereKernelDispatch(NDArray cond, NDArray x, NDArray y, NDArray result, NPTypeCode outType)
202163
{
203-
var condPtr = (bool*)cond.Address;
164+
var condPtr = (nint)cond.Address;
204165
var count = result.size;
205166

206-
switch (outType)
207-
{
208-
case NPTypeCode.Boolean:
209-
ILKernelGenerator.WhereExecute(condPtr, (bool*)x.Address, (bool*)y.Address, (bool*)result.Address, count);
210-
break;
211-
case NPTypeCode.Byte:
212-
ILKernelGenerator.WhereExecute(condPtr, (byte*)x.Address, (byte*)y.Address, (byte*)result.Address, count);
213-
break;
214-
case NPTypeCode.Int16:
215-
ILKernelGenerator.WhereExecute(condPtr, (short*)x.Address, (short*)y.Address, (short*)result.Address, count);
216-
break;
217-
case NPTypeCode.UInt16:
218-
ILKernelGenerator.WhereExecute(condPtr, (ushort*)x.Address, (ushort*)y.Address, (ushort*)result.Address, count);
219-
break;
220-
case NPTypeCode.Int32:
221-
ILKernelGenerator.WhereExecute(condPtr, (int*)x.Address, (int*)y.Address, (int*)result.Address, count);
222-
break;
223-
case NPTypeCode.UInt32:
224-
ILKernelGenerator.WhereExecute(condPtr, (uint*)x.Address, (uint*)y.Address, (uint*)result.Address, count);
225-
break;
226-
case NPTypeCode.Int64:
227-
ILKernelGenerator.WhereExecute(condPtr, (long*)x.Address, (long*)y.Address, (long*)result.Address, count);
228-
break;
229-
case NPTypeCode.UInt64:
230-
ILKernelGenerator.WhereExecute(condPtr, (ulong*)x.Address, (ulong*)y.Address, (ulong*)result.Address, count);
231-
break;
232-
case NPTypeCode.Char:
233-
ILKernelGenerator.WhereExecute(condPtr, (char*)x.Address, (char*)y.Address, (char*)result.Address, count);
234-
break;
235-
case NPTypeCode.Single:
236-
ILKernelGenerator.WhereExecute(condPtr, (float*)x.Address, (float*)y.Address, (float*)result.Address, count);
237-
break;
238-
case NPTypeCode.Double:
239-
ILKernelGenerator.WhereExecute(condPtr, (double*)x.Address, (double*)y.Address, (double*)result.Address, count);
240-
break;
241-
case NPTypeCode.Decimal:
242-
ILKernelGenerator.WhereExecute(condPtr, (decimal*)x.Address, (decimal*)y.Address, (decimal*)result.Address, count);
243-
break;
244-
default:
245-
throw new NotSupportedException($"Type {outType} not supported for np.where");
246-
}
167+
NpFunc.Invoke(outType, WhereKernelExecute<int>, condPtr, (nint)x.Address, (nint)y.Address, (nint)result.Address, count);
247168
}
169+
170+
private static unsafe void WhereKernelExecute<T>(nint condPtr, nint xAddr, nint yAddr, nint resultAddr, long count) where T : unmanaged
171+
=> ILKernelGenerator.WhereExecute((bool*)condPtr, (T*)xAddr, (T*)yAddr, (T*)resultAddr, count);
248172
}
249173
}

src/NumSharp.Core/Backends/Default/Indexing/Default.BooleanMask.cs

Lines changed: 5 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,15 @@
22
using NumSharp.Backends.Iteration;
33
using NumSharp.Backends.Kernels;
44
using NumSharp.Generic;
5+
using NumSharp.Utilities;
56

67
namespace NumSharp.Backends
78
{
89
public partial class DefaultEngine
910
{
11+
private static unsafe void CopyMaskedDispatch<T>(nint arr, nint mask, nint result, long size) where T : unmanaged
12+
=> ILKernelGenerator.CopyMaskedElementsHelper((T*)arr, (bool*)mask, (T*)result, size);
13+
1014
/// <summary>
1115
/// Apply a boolean mask to select elements from an array.
1216
/// </summary>
@@ -45,57 +49,7 @@ private unsafe NDArray BooleanMaskSimd(NDArray arr, NDArray<bool> mask)
4549
// Create result array
4650
var result = new NDArray(arr.dtype, new Shape(trueCount));
4751

48-
// Copy elements where mask is true
49-
switch (arr.typecode)
50-
{
51-
case NPTypeCode.Boolean:
52-
ILKernelGenerator.CopyMaskedElementsHelper((bool*)arr.Address, (bool*)mask.Address, (bool*)result.Address, size);
53-
break;
54-
case NPTypeCode.Byte:
55-
ILKernelGenerator.CopyMaskedElementsHelper((byte*)arr.Address, (bool*)mask.Address, (byte*)result.Address, size);
56-
break;
57-
case NPTypeCode.SByte:
58-
ILKernelGenerator.CopyMaskedElementsHelper((sbyte*)arr.Address, (bool*)mask.Address, (sbyte*)result.Address, size);
59-
break;
60-
case NPTypeCode.Int16:
61-
ILKernelGenerator.CopyMaskedElementsHelper((short*)arr.Address, (bool*)mask.Address, (short*)result.Address, size);
62-
break;
63-
case NPTypeCode.UInt16:
64-
ILKernelGenerator.CopyMaskedElementsHelper((ushort*)arr.Address, (bool*)mask.Address, (ushort*)result.Address, size);
65-
break;
66-
case NPTypeCode.Int32:
67-
ILKernelGenerator.CopyMaskedElementsHelper((int*)arr.Address, (bool*)mask.Address, (int*)result.Address, size);
68-
break;
69-
case NPTypeCode.UInt32:
70-
ILKernelGenerator.CopyMaskedElementsHelper((uint*)arr.Address, (bool*)mask.Address, (uint*)result.Address, size);
71-
break;
72-
case NPTypeCode.Int64:
73-
ILKernelGenerator.CopyMaskedElementsHelper((long*)arr.Address, (bool*)mask.Address, (long*)result.Address, size);
74-
break;
75-
case NPTypeCode.UInt64:
76-
ILKernelGenerator.CopyMaskedElementsHelper((ulong*)arr.Address, (bool*)mask.Address, (ulong*)result.Address, size);
77-
break;
78-
case NPTypeCode.Char:
79-
ILKernelGenerator.CopyMaskedElementsHelper((char*)arr.Address, (bool*)mask.Address, (char*)result.Address, size);
80-
break;
81-
case NPTypeCode.Half:
82-
ILKernelGenerator.CopyMaskedElementsHelper((Half*)arr.Address, (bool*)mask.Address, (Half*)result.Address, size);
83-
break;
84-
case NPTypeCode.Single:
85-
ILKernelGenerator.CopyMaskedElementsHelper((float*)arr.Address, (bool*)mask.Address, (float*)result.Address, size);
86-
break;
87-
case NPTypeCode.Double:
88-
ILKernelGenerator.CopyMaskedElementsHelper((double*)arr.Address, (bool*)mask.Address, (double*)result.Address, size);
89-
break;
90-
case NPTypeCode.Decimal:
91-
ILKernelGenerator.CopyMaskedElementsHelper((decimal*)arr.Address, (bool*)mask.Address, (decimal*)result.Address, size);
92-
break;
93-
case NPTypeCode.Complex:
94-
ILKernelGenerator.CopyMaskedElementsHelper((System.Numerics.Complex*)arr.Address, (bool*)mask.Address, (System.Numerics.Complex*)result.Address, size);
95-
break;
96-
default:
97-
throw new NotSupportedException($"Type {arr.typecode} not supported for boolean masking");
98-
}
52+
NpFunc.Invoke(arr.typecode, CopyMaskedDispatch<int>, (nint)arr.Address, (nint)mask.Address, (nint)result.Address, size);
9953

10054
return result;
10155
}

src/NumSharp.Core/Backends/Default/Indexing/Default.NonZero.cs

Lines changed: 13 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -4,46 +4,24 @@
44
using NumSharp.Backends.Iteration;
55
using NumSharp.Backends.Kernels;
66
using NumSharp.Backends.Unmanaged;
7+
using NumSharp.Utilities;
78

89
namespace NumSharp.Backends
910
{
1011
public partial class DefaultEngine
1112
{
12-
/// <summary>
13-
/// Return the indices of non-zero elements.
14-
/// </summary>
15-
/// <remarks>
16-
/// NumPy-aligned behavior:
17-
/// - Returns tuple of arrays, one per dimension
18-
/// - For empty arrays, returns empty arrays with correct dtype (int)
19-
/// - Iterates in C-order (row-major)
20-
/// - Handles contiguous and strided arrays efficiently
21-
/// </remarks>
22-
/// <param name="nd">Input array</param>
23-
/// <returns>Array of NDArray&lt;long&gt;, one per dimension containing indices of non-zero elements</returns>
13+
private static NDArray<long>[] NonZeroDispatch<T>(NDArray nd) where T : unmanaged
14+
=> nonzeros<T>(nd.MakeGeneric<T>());
15+
16+
private static long CountNonZeroDispatch<T>(NDArray nd) where T : unmanaged
17+
=> count_nonzero<T>(nd.MakeGeneric<T>());
18+
19+
private static void CountNonZeroAxisDispatch<T>(NDArray nd, NDArray result, int axis) where T : unmanaged
20+
=> count_nonzero_axis<T>(nd.MakeGeneric<T>(), result, axis);
21+
2422
public override NDArray<long>[] NonZero(NDArray nd)
2523
{
26-
// Type dispatch to generic implementation
27-
switch (nd.typecode)
28-
{
29-
case NPTypeCode.Boolean: return nonzeros<bool>(nd.MakeGeneric<bool>());
30-
case NPTypeCode.Byte: return nonzeros<byte>(nd.MakeGeneric<byte>());
31-
case NPTypeCode.SByte: return nonzeros<sbyte>(nd.MakeGeneric<sbyte>());
32-
case NPTypeCode.Int16: return nonzeros<short>(nd.MakeGeneric<short>());
33-
case NPTypeCode.UInt16: return nonzeros<ushort>(nd.MakeGeneric<ushort>());
34-
case NPTypeCode.Int32: return nonzeros<int>(nd.MakeGeneric<int>());
35-
case NPTypeCode.UInt32: return nonzeros<uint>(nd.MakeGeneric<uint>());
36-
case NPTypeCode.Int64: return nonzeros<long>(nd.MakeGeneric<long>());
37-
case NPTypeCode.UInt64: return nonzeros<ulong>(nd.MakeGeneric<ulong>());
38-
case NPTypeCode.Char: return nonzeros<char>(nd.MakeGeneric<char>());
39-
case NPTypeCode.Half: return nonzeros<Half>(nd.MakeGeneric<Half>());
40-
case NPTypeCode.Double: return nonzeros<double>(nd.MakeGeneric<double>());
41-
case NPTypeCode.Single: return nonzeros<float>(nd.MakeGeneric<float>());
42-
case NPTypeCode.Decimal: return nonzeros<decimal>(nd.MakeGeneric<decimal>());
43-
case NPTypeCode.Complex: return nonzeros<System.Numerics.Complex>(nd.MakeGeneric<System.Numerics.Complex>());
44-
default:
45-
throw new NotSupportedException($"NonZero not supported for type {nd.typecode}");
46-
}
24+
return NpFunc.Invoke(nd.typecode, NonZeroDispatch<int>, nd);
4725
}
4826

4927
/// <summary>
@@ -84,27 +62,7 @@ public override long CountNonZero(NDArray nd)
8462
if (nd.size == 0)
8563
return 0;
8664

87-
// Type dispatch to generic implementation
88-
switch (nd.typecode)
89-
{
90-
case NPTypeCode.Boolean: return count_nonzero<bool>(nd.MakeGeneric<bool>());
91-
case NPTypeCode.Byte: return count_nonzero<byte>(nd.MakeGeneric<byte>());
92-
case NPTypeCode.SByte: return count_nonzero<sbyte>(nd.MakeGeneric<sbyte>());
93-
case NPTypeCode.Int16: return count_nonzero<short>(nd.MakeGeneric<short>());
94-
case NPTypeCode.UInt16: return count_nonzero<ushort>(nd.MakeGeneric<ushort>());
95-
case NPTypeCode.Int32: return count_nonzero<int>(nd.MakeGeneric<int>());
96-
case NPTypeCode.UInt32: return count_nonzero<uint>(nd.MakeGeneric<uint>());
97-
case NPTypeCode.Int64: return count_nonzero<long>(nd.MakeGeneric<long>());
98-
case NPTypeCode.UInt64: return count_nonzero<ulong>(nd.MakeGeneric<ulong>());
99-
case NPTypeCode.Char: return count_nonzero<char>(nd.MakeGeneric<char>());
100-
case NPTypeCode.Half: return count_nonzero<Half>(nd.MakeGeneric<Half>());
101-
case NPTypeCode.Double: return count_nonzero<double>(nd.MakeGeneric<double>());
102-
case NPTypeCode.Single: return count_nonzero<float>(nd.MakeGeneric<float>());
103-
case NPTypeCode.Decimal: return count_nonzero<decimal>(nd.MakeGeneric<decimal>());
104-
case NPTypeCode.Complex: return count_nonzero<System.Numerics.Complex>(nd.MakeGeneric<System.Numerics.Complex>());
105-
default:
106-
throw new NotSupportedException($"CountNonZero not supported for type {nd.typecode}");
107-
}
65+
return NpFunc.Invoke(nd.typecode, CountNonZeroDispatch<int>, nd);
10866
}
10967

11068
/// <summary>
@@ -141,27 +99,7 @@ public override NDArray CountNonZero(NDArray nd, int axis, bool keepdims = false
14199
return result;
142100
}
143101

144-
// Type dispatch
145-
switch (nd.typecode)
146-
{
147-
case NPTypeCode.Boolean: count_nonzero_axis<bool>(nd.MakeGeneric<bool>(), result, axis); break;
148-
case NPTypeCode.Byte: count_nonzero_axis<byte>(nd.MakeGeneric<byte>(), result, axis); break;
149-
case NPTypeCode.SByte: count_nonzero_axis<sbyte>(nd.MakeGeneric<sbyte>(), result, axis); break;
150-
case NPTypeCode.Int16: count_nonzero_axis<short>(nd.MakeGeneric<short>(), result, axis); break;
151-
case NPTypeCode.UInt16: count_nonzero_axis<ushort>(nd.MakeGeneric<ushort>(), result, axis); break;
152-
case NPTypeCode.Int32: count_nonzero_axis<int>(nd.MakeGeneric<int>(), result, axis); break;
153-
case NPTypeCode.UInt32: count_nonzero_axis<uint>(nd.MakeGeneric<uint>(), result, axis); break;
154-
case NPTypeCode.Int64: count_nonzero_axis<long>(nd.MakeGeneric<long>(), result, axis); break;
155-
case NPTypeCode.UInt64: count_nonzero_axis<ulong>(nd.MakeGeneric<ulong>(), result, axis); break;
156-
case NPTypeCode.Char: count_nonzero_axis<char>(nd.MakeGeneric<char>(), result, axis); break;
157-
case NPTypeCode.Half: count_nonzero_axis<Half>(nd.MakeGeneric<Half>(), result, axis); break;
158-
case NPTypeCode.Double: count_nonzero_axis<double>(nd.MakeGeneric<double>(), result, axis); break;
159-
case NPTypeCode.Single: count_nonzero_axis<float>(nd.MakeGeneric<float>(), result, axis); break;
160-
case NPTypeCode.Decimal: count_nonzero_axis<decimal>(nd.MakeGeneric<decimal>(), result, axis); break;
161-
case NPTypeCode.Complex: count_nonzero_axis<System.Numerics.Complex>(nd.MakeGeneric<System.Numerics.Complex>(), result, axis); break;
162-
default:
163-
throw new NotSupportedException($"CountNonZero not supported for type {nd.typecode}");
164-
}
102+
NpFunc.Invoke(nd.typecode, CountNonZeroAxisDispatch<int>, nd, result, axis);
165103

166104
if (keepdims)
167105
{

src/NumSharp.Core/Backends/Default/Logic/Default.LogicalReduction.cs

Lines changed: 2 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using NumSharp.Backends.Iteration;
33
using NumSharp.Generic;
4+
using NumSharp.Utilities;
45

56
namespace NumSharp.Backends
67
{
@@ -33,47 +34,7 @@ private NDArray<bool> ReduceLogicalAxis(NDArray nd, int axis, bool keepdims, boo
3334
if (result.size == 0 || nd.Shape.dimensions[axis] == 0)
3435
return result;
3536

36-
switch (nd.GetTypeCode)
37-
{
38-
case NPTypeCode.Boolean:
39-
ExecuteLogicalAxis<bool>(nd, result, axis, reduceAll);
40-
break;
41-
case NPTypeCode.Byte:
42-
ExecuteLogicalAxis<byte>(nd, result, axis, reduceAll);
43-
break;
44-
case NPTypeCode.Int16:
45-
ExecuteLogicalAxis<short>(nd, result, axis, reduceAll);
46-
break;
47-
case NPTypeCode.UInt16:
48-
ExecuteLogicalAxis<ushort>(nd, result, axis, reduceAll);
49-
break;
50-
case NPTypeCode.Int32:
51-
ExecuteLogicalAxis<int>(nd, result, axis, reduceAll);
52-
break;
53-
case NPTypeCode.UInt32:
54-
ExecuteLogicalAxis<uint>(nd, result, axis, reduceAll);
55-
break;
56-
case NPTypeCode.Int64:
57-
ExecuteLogicalAxis<long>(nd, result, axis, reduceAll);
58-
break;
59-
case NPTypeCode.UInt64:
60-
ExecuteLogicalAxis<ulong>(nd, result, axis, reduceAll);
61-
break;
62-
case NPTypeCode.Char:
63-
ExecuteLogicalAxis<char>(nd, result, axis, reduceAll);
64-
break;
65-
case NPTypeCode.Single:
66-
ExecuteLogicalAxis<float>(nd, result, axis, reduceAll);
67-
break;
68-
case NPTypeCode.Double:
69-
ExecuteLogicalAxis<double>(nd, result, axis, reduceAll);
70-
break;
71-
case NPTypeCode.Decimal:
72-
ExecuteLogicalAxis<decimal>(nd, result, axis, reduceAll);
73-
break;
74-
default:
75-
throw new NotSupportedException($"Type {nd.GetTypeCode} not supported for logical reduction.");
76-
}
37+
NpFunc.Invoke(nd.GetTypeCode, ExecuteLogicalAxis<int>, nd, result, axis, reduceAll);
7738

7839
return result;
7940
}

0 commit comments

Comments
 (0)