Skip to content

Commit 6079495

Browse files
authored
Merge pull request #606 from SciSharp/np_where
[API] Support np.where via ILKernelGenerator
2 parents 7021008 + a5862bd commit 6079495

8 files changed

Lines changed: 3841 additions & 5 deletions

File tree

src/NumSharp.Core/APIs/np.where.cs

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
using System;
2+
using NumSharp.Backends.Kernels;
3+
using NumSharp.Generic;
4+
5+
namespace NumSharp
6+
{
7+
public static partial class np
8+
{
9+
/// <summary>
10+
/// Equivalent to <see cref="nonzero(NDArray)"/>: returns the indices where
11+
/// <paramref name="condition"/> is non-zero.
12+
/// </summary>
13+
/// <param name="condition">Input array. Non-zero entries yield their indices.</param>
14+
/// <returns>Tuple of arrays with indices where condition is non-zero, one per dimension.</returns>
15+
/// <remarks>https://numpy.org/doc/stable/reference/generated/numpy.where.html</remarks>
16+
public static NDArray<long>[] where(NDArray condition)
17+
{
18+
return nonzero(condition);
19+
}
20+
21+
/// <summary>
22+
/// Return elements chosen from `x` or `y` depending on `condition`.
23+
/// </summary>
24+
/// <param name="condition">Where True, yield `x`, otherwise yield `y`.</param>
25+
/// <param name="x">Values from which to choose where condition is True.</param>
26+
/// <param name="y">Values from which to choose where condition is False.</param>
27+
/// <returns>An array with elements from `x` where `condition` is True, and elements from `y` elsewhere.</returns>
28+
/// <remarks>https://numpy.org/doc/stable/reference/generated/numpy.where.html</remarks>
29+
public static NDArray where(NDArray condition, NDArray x, NDArray y)
30+
{
31+
return where_internal(condition, x, y);
32+
}
33+
34+
/// <summary>
35+
/// Return elements chosen from `x` or `y` depending on `condition`.
36+
/// Scalar overload for x.
37+
/// </summary>
38+
public static NDArray where(NDArray condition, object x, NDArray y)
39+
{
40+
return where_internal(condition, asanyarray(x), y);
41+
}
42+
43+
/// <summary>
44+
/// Return elements chosen from `x` or `y` depending on `condition`.
45+
/// Scalar overload for y.
46+
/// </summary>
47+
public static NDArray where(NDArray condition, NDArray x, object y)
48+
{
49+
return where_internal(condition, x, asanyarray(y));
50+
}
51+
52+
/// <summary>
53+
/// Return elements chosen from `x` or `y` depending on `condition`.
54+
/// Scalar overload for both x and y.
55+
/// </summary>
56+
public static NDArray where(NDArray condition, object x, object y)
57+
{
58+
return where_internal(condition, asanyarray(x), asanyarray(y));
59+
}
60+
61+
/// <summary>
62+
/// Internal implementation of np.where.
63+
/// </summary>
64+
private static NDArray where_internal(NDArray condition, NDArray x, NDArray y)
65+
{
66+
// Skip broadcast_arrays (which allocates 3 NDArrays + helper arrays) when all three
67+
// already share a shape — the frequent case of np.where(mask, arr, other_arr).
68+
NDArray cond, xArr, yArr;
69+
if (condition.Shape == x.Shape && x.Shape == y.Shape)
70+
{
71+
cond = condition;
72+
xArr = x;
73+
yArr = y;
74+
}
75+
else
76+
{
77+
var broadcasted = broadcast_arrays(condition, x, y);
78+
cond = broadcasted[0];
79+
xArr = broadcasted[1];
80+
yArr = broadcasted[2];
81+
}
82+
83+
// When x and y already agree, skip the NEP50 promotion lookup. Otherwise defer to
84+
// _FindCommonType which handles the scalar+array NEP50 rules.
85+
var outType = x.GetTypeCode == y.GetTypeCode
86+
? x.GetTypeCode
87+
: _FindCommonType(x, y);
88+
89+
if (xArr.GetTypeCode != outType)
90+
xArr = xArr.astype(outType, copy: false);
91+
if (yArr.GetTypeCode != outType)
92+
yArr = yArr.astype(outType, copy: false);
93+
94+
// Use cond.shape (dimensions only) not cond.Shape (which may have broadcast strides)
95+
var result = empty(cond.shape, outType);
96+
97+
// Handle empty arrays - nothing to iterate
98+
if (result.size == 0)
99+
return result;
100+
101+
// IL Kernel fast path: all arrays contiguous, bool condition, SIMD enabled
102+
// Broadcasted arrays (stride=0) are NOT contiguous, so they use iterator path.
103+
bool canUseKernel = ILKernelGenerator.Enabled &&
104+
cond.typecode == NPTypeCode.Boolean &&
105+
cond.Shape.IsContiguous &&
106+
xArr.Shape.IsContiguous &&
107+
yArr.Shape.IsContiguous;
108+
109+
if (canUseKernel)
110+
{
111+
WhereKernelDispatch(cond, xArr, yArr, result, outType);
112+
return result;
113+
}
114+
115+
// Iterator fallback for non-contiguous/broadcasted arrays
116+
switch (outType)
117+
{
118+
case NPTypeCode.Boolean:
119+
WhereImpl<bool>(cond, xArr, yArr, result);
120+
break;
121+
case NPTypeCode.Byte:
122+
WhereImpl<byte>(cond, xArr, yArr, result);
123+
break;
124+
case NPTypeCode.Int16:
125+
WhereImpl<short>(cond, xArr, yArr, result);
126+
break;
127+
case NPTypeCode.UInt16:
128+
WhereImpl<ushort>(cond, xArr, yArr, result);
129+
break;
130+
case NPTypeCode.Int32:
131+
WhereImpl<int>(cond, xArr, yArr, result);
132+
break;
133+
case NPTypeCode.UInt32:
134+
WhereImpl<uint>(cond, xArr, yArr, result);
135+
break;
136+
case NPTypeCode.Int64:
137+
WhereImpl<long>(cond, xArr, yArr, result);
138+
break;
139+
case NPTypeCode.UInt64:
140+
WhereImpl<ulong>(cond, xArr, yArr, result);
141+
break;
142+
case NPTypeCode.Char:
143+
WhereImpl<char>(cond, xArr, yArr, result);
144+
break;
145+
case NPTypeCode.Single:
146+
WhereImpl<float>(cond, xArr, yArr, result);
147+
break;
148+
case NPTypeCode.Double:
149+
WhereImpl<double>(cond, xArr, yArr, result);
150+
break;
151+
case NPTypeCode.Decimal:
152+
WhereImpl<decimal>(cond, xArr, yArr, result);
153+
break;
154+
default:
155+
throw new NotSupportedException($"Type {outType} not supported for np.where");
156+
}
157+
158+
return result;
159+
}
160+
161+
private static void WhereImpl<T>(NDArray cond, NDArray x, NDArray y, NDArray result) where T : unmanaged
162+
{
163+
// Use iterators for proper handling of broadcasted/strided arrays
164+
using var condIter = cond.AsIterator<bool>();
165+
using var xIter = x.AsIterator<T>();
166+
using var yIter = y.AsIterator<T>();
167+
using var resultIter = result.AsIterator<T>();
168+
169+
while (condIter.HasNext())
170+
{
171+
var c = condIter.MoveNext();
172+
var xVal = xIter.MoveNext();
173+
var yVal = yIter.MoveNext();
174+
resultIter.MoveNextReference() = c ? xVal : yVal;
175+
}
176+
}
177+
178+
/// <summary>
179+
/// IL Kernel dispatch for contiguous arrays.
180+
/// Uses IL-generated kernels with SIMD optimization.
181+
/// </summary>
182+
private static unsafe void WhereKernelDispatch(NDArray cond, NDArray x, NDArray y, NDArray result, NPTypeCode outType)
183+
{
184+
var condPtr = (bool*)cond.Address;
185+
var count = result.size;
186+
187+
switch (outType)
188+
{
189+
case NPTypeCode.Boolean:
190+
ILKernelGenerator.WhereExecute(condPtr, (bool*)x.Address, (bool*)y.Address, (bool*)result.Address, count);
191+
break;
192+
case NPTypeCode.Byte:
193+
ILKernelGenerator.WhereExecute(condPtr, (byte*)x.Address, (byte*)y.Address, (byte*)result.Address, count);
194+
break;
195+
case NPTypeCode.Int16:
196+
ILKernelGenerator.WhereExecute(condPtr, (short*)x.Address, (short*)y.Address, (short*)result.Address, count);
197+
break;
198+
case NPTypeCode.UInt16:
199+
ILKernelGenerator.WhereExecute(condPtr, (ushort*)x.Address, (ushort*)y.Address, (ushort*)result.Address, count);
200+
break;
201+
case NPTypeCode.Int32:
202+
ILKernelGenerator.WhereExecute(condPtr, (int*)x.Address, (int*)y.Address, (int*)result.Address, count);
203+
break;
204+
case NPTypeCode.UInt32:
205+
ILKernelGenerator.WhereExecute(condPtr, (uint*)x.Address, (uint*)y.Address, (uint*)result.Address, count);
206+
break;
207+
case NPTypeCode.Int64:
208+
ILKernelGenerator.WhereExecute(condPtr, (long*)x.Address, (long*)y.Address, (long*)result.Address, count);
209+
break;
210+
case NPTypeCode.UInt64:
211+
ILKernelGenerator.WhereExecute(condPtr, (ulong*)x.Address, (ulong*)y.Address, (ulong*)result.Address, count);
212+
break;
213+
case NPTypeCode.Char:
214+
ILKernelGenerator.WhereExecute(condPtr, (char*)x.Address, (char*)y.Address, (char*)result.Address, count);
215+
break;
216+
case NPTypeCode.Single:
217+
ILKernelGenerator.WhereExecute(condPtr, (float*)x.Address, (float*)y.Address, (float*)result.Address, count);
218+
break;
219+
case NPTypeCode.Double:
220+
ILKernelGenerator.WhereExecute(condPtr, (double*)x.Address, (double*)y.Address, (double*)result.Address, count);
221+
break;
222+
case NPTypeCode.Decimal:
223+
ILKernelGenerator.WhereExecute(condPtr, (decimal*)x.Address, (decimal*)y.Address, (decimal*)result.Address, count);
224+
break;
225+
default:
226+
throw new NotSupportedException($"Type {outType} not supported for np.where");
227+
}
228+
}
229+
}
230+
}

0 commit comments

Comments
 (0)