|
| 1 | +using System; |
| 2 | +using NumSharp.Backends.Kernels; |
| 3 | +using NumSharp.Generic; |
| 4 | + |
| 5 | +namespace NumSharp |
| 6 | +{ |
| 7 | + public static partial class np |
| 8 | + { |
| 9 | + /// <summary> |
| 10 | + /// Equivalent to <see cref="nonzero(NDArray)"/>: returns the indices where |
| 11 | + /// <paramref name="condition"/> is non-zero. |
| 12 | + /// </summary> |
| 13 | + /// <param name="condition">Input array. Non-zero entries yield their indices.</param> |
| 14 | + /// <returns>Tuple of arrays with indices where condition is non-zero, one per dimension.</returns> |
| 15 | + /// <remarks>https://numpy.org/doc/stable/reference/generated/numpy.where.html</remarks> |
| 16 | + public static NDArray<long>[] where(NDArray condition) |
| 17 | + { |
| 18 | + return nonzero(condition); |
| 19 | + } |
| 20 | + |
| 21 | + /// <summary> |
| 22 | + /// Return elements chosen from `x` or `y` depending on `condition`. |
| 23 | + /// </summary> |
| 24 | + /// <param name="condition">Where True, yield `x`, otherwise yield `y`.</param> |
| 25 | + /// <param name="x">Values from which to choose where condition is True.</param> |
| 26 | + /// <param name="y">Values from which to choose where condition is False.</param> |
| 27 | + /// <returns>An array with elements from `x` where `condition` is True, and elements from `y` elsewhere.</returns> |
| 28 | + /// <remarks>https://numpy.org/doc/stable/reference/generated/numpy.where.html</remarks> |
| 29 | + public static NDArray where(NDArray condition, NDArray x, NDArray y) |
| 30 | + { |
| 31 | + return where_internal(condition, x, y); |
| 32 | + } |
| 33 | + |
| 34 | + /// <summary> |
| 35 | + /// Return elements chosen from `x` or `y` depending on `condition`. |
| 36 | + /// Scalar overload for x. |
| 37 | + /// </summary> |
| 38 | + public static NDArray where(NDArray condition, object x, NDArray y) |
| 39 | + { |
| 40 | + return where_internal(condition, asanyarray(x), y); |
| 41 | + } |
| 42 | + |
| 43 | + /// <summary> |
| 44 | + /// Return elements chosen from `x` or `y` depending on `condition`. |
| 45 | + /// Scalar overload for y. |
| 46 | + /// </summary> |
| 47 | + public static NDArray where(NDArray condition, NDArray x, object y) |
| 48 | + { |
| 49 | + return where_internal(condition, x, asanyarray(y)); |
| 50 | + } |
| 51 | + |
| 52 | + /// <summary> |
| 53 | + /// Return elements chosen from `x` or `y` depending on `condition`. |
| 54 | + /// Scalar overload for both x and y. |
| 55 | + /// </summary> |
| 56 | + public static NDArray where(NDArray condition, object x, object y) |
| 57 | + { |
| 58 | + return where_internal(condition, asanyarray(x), asanyarray(y)); |
| 59 | + } |
| 60 | + |
| 61 | + /// <summary> |
| 62 | + /// Internal implementation of np.where. |
| 63 | + /// </summary> |
| 64 | + private static NDArray where_internal(NDArray condition, NDArray x, NDArray y) |
| 65 | + { |
| 66 | + // Skip broadcast_arrays (which allocates 3 NDArrays + helper arrays) when all three |
| 67 | + // already share a shape — the frequent case of np.where(mask, arr, other_arr). |
| 68 | + NDArray cond, xArr, yArr; |
| 69 | + if (condition.Shape == x.Shape && x.Shape == y.Shape) |
| 70 | + { |
| 71 | + cond = condition; |
| 72 | + xArr = x; |
| 73 | + yArr = y; |
| 74 | + } |
| 75 | + else |
| 76 | + { |
| 77 | + var broadcasted = broadcast_arrays(condition, x, y); |
| 78 | + cond = broadcasted[0]; |
| 79 | + xArr = broadcasted[1]; |
| 80 | + yArr = broadcasted[2]; |
| 81 | + } |
| 82 | + |
| 83 | + // When x and y already agree, skip the NEP50 promotion lookup. Otherwise defer to |
| 84 | + // _FindCommonType which handles the scalar+array NEP50 rules. |
| 85 | + var outType = x.GetTypeCode == y.GetTypeCode |
| 86 | + ? x.GetTypeCode |
| 87 | + : _FindCommonType(x, y); |
| 88 | + |
| 89 | + if (xArr.GetTypeCode != outType) |
| 90 | + xArr = xArr.astype(outType, copy: false); |
| 91 | + if (yArr.GetTypeCode != outType) |
| 92 | + yArr = yArr.astype(outType, copy: false); |
| 93 | + |
| 94 | + // Use cond.shape (dimensions only) not cond.Shape (which may have broadcast strides) |
| 95 | + var result = empty(cond.shape, outType); |
| 96 | + |
| 97 | + // Handle empty arrays - nothing to iterate |
| 98 | + if (result.size == 0) |
| 99 | + return result; |
| 100 | + |
| 101 | + // IL Kernel fast path: all arrays contiguous, bool condition, SIMD enabled |
| 102 | + // Broadcasted arrays (stride=0) are NOT contiguous, so they use iterator path. |
| 103 | + bool canUseKernel = ILKernelGenerator.Enabled && |
| 104 | + cond.typecode == NPTypeCode.Boolean && |
| 105 | + cond.Shape.IsContiguous && |
| 106 | + xArr.Shape.IsContiguous && |
| 107 | + yArr.Shape.IsContiguous; |
| 108 | + |
| 109 | + if (canUseKernel) |
| 110 | + { |
| 111 | + WhereKernelDispatch(cond, xArr, yArr, result, outType); |
| 112 | + return result; |
| 113 | + } |
| 114 | + |
| 115 | + // Iterator fallback for non-contiguous/broadcasted arrays |
| 116 | + switch (outType) |
| 117 | + { |
| 118 | + case NPTypeCode.Boolean: |
| 119 | + WhereImpl<bool>(cond, xArr, yArr, result); |
| 120 | + break; |
| 121 | + case NPTypeCode.Byte: |
| 122 | + WhereImpl<byte>(cond, xArr, yArr, result); |
| 123 | + break; |
| 124 | + case NPTypeCode.Int16: |
| 125 | + WhereImpl<short>(cond, xArr, yArr, result); |
| 126 | + break; |
| 127 | + case NPTypeCode.UInt16: |
| 128 | + WhereImpl<ushort>(cond, xArr, yArr, result); |
| 129 | + break; |
| 130 | + case NPTypeCode.Int32: |
| 131 | + WhereImpl<int>(cond, xArr, yArr, result); |
| 132 | + break; |
| 133 | + case NPTypeCode.UInt32: |
| 134 | + WhereImpl<uint>(cond, xArr, yArr, result); |
| 135 | + break; |
| 136 | + case NPTypeCode.Int64: |
| 137 | + WhereImpl<long>(cond, xArr, yArr, result); |
| 138 | + break; |
| 139 | + case NPTypeCode.UInt64: |
| 140 | + WhereImpl<ulong>(cond, xArr, yArr, result); |
| 141 | + break; |
| 142 | + case NPTypeCode.Char: |
| 143 | + WhereImpl<char>(cond, xArr, yArr, result); |
| 144 | + break; |
| 145 | + case NPTypeCode.Single: |
| 146 | + WhereImpl<float>(cond, xArr, yArr, result); |
| 147 | + break; |
| 148 | + case NPTypeCode.Double: |
| 149 | + WhereImpl<double>(cond, xArr, yArr, result); |
| 150 | + break; |
| 151 | + case NPTypeCode.Decimal: |
| 152 | + WhereImpl<decimal>(cond, xArr, yArr, result); |
| 153 | + break; |
| 154 | + default: |
| 155 | + throw new NotSupportedException($"Type {outType} not supported for np.where"); |
| 156 | + } |
| 157 | + |
| 158 | + return result; |
| 159 | + } |
| 160 | + |
| 161 | + private static void WhereImpl<T>(NDArray cond, NDArray x, NDArray y, NDArray result) where T : unmanaged |
| 162 | + { |
| 163 | + // Use iterators for proper handling of broadcasted/strided arrays |
| 164 | + using var condIter = cond.AsIterator<bool>(); |
| 165 | + using var xIter = x.AsIterator<T>(); |
| 166 | + using var yIter = y.AsIterator<T>(); |
| 167 | + using var resultIter = result.AsIterator<T>(); |
| 168 | + |
| 169 | + while (condIter.HasNext()) |
| 170 | + { |
| 171 | + var c = condIter.MoveNext(); |
| 172 | + var xVal = xIter.MoveNext(); |
| 173 | + var yVal = yIter.MoveNext(); |
| 174 | + resultIter.MoveNextReference() = c ? xVal : yVal; |
| 175 | + } |
| 176 | + } |
| 177 | + |
| 178 | + /// <summary> |
| 179 | + /// IL Kernel dispatch for contiguous arrays. |
| 180 | + /// Uses IL-generated kernels with SIMD optimization. |
| 181 | + /// </summary> |
| 182 | + private static unsafe void WhereKernelDispatch(NDArray cond, NDArray x, NDArray y, NDArray result, NPTypeCode outType) |
| 183 | + { |
| 184 | + var condPtr = (bool*)cond.Address; |
| 185 | + var count = result.size; |
| 186 | + |
| 187 | + switch (outType) |
| 188 | + { |
| 189 | + case NPTypeCode.Boolean: |
| 190 | + ILKernelGenerator.WhereExecute(condPtr, (bool*)x.Address, (bool*)y.Address, (bool*)result.Address, count); |
| 191 | + break; |
| 192 | + case NPTypeCode.Byte: |
| 193 | + ILKernelGenerator.WhereExecute(condPtr, (byte*)x.Address, (byte*)y.Address, (byte*)result.Address, count); |
| 194 | + break; |
| 195 | + case NPTypeCode.Int16: |
| 196 | + ILKernelGenerator.WhereExecute(condPtr, (short*)x.Address, (short*)y.Address, (short*)result.Address, count); |
| 197 | + break; |
| 198 | + case NPTypeCode.UInt16: |
| 199 | + ILKernelGenerator.WhereExecute(condPtr, (ushort*)x.Address, (ushort*)y.Address, (ushort*)result.Address, count); |
| 200 | + break; |
| 201 | + case NPTypeCode.Int32: |
| 202 | + ILKernelGenerator.WhereExecute(condPtr, (int*)x.Address, (int*)y.Address, (int*)result.Address, count); |
| 203 | + break; |
| 204 | + case NPTypeCode.UInt32: |
| 205 | + ILKernelGenerator.WhereExecute(condPtr, (uint*)x.Address, (uint*)y.Address, (uint*)result.Address, count); |
| 206 | + break; |
| 207 | + case NPTypeCode.Int64: |
| 208 | + ILKernelGenerator.WhereExecute(condPtr, (long*)x.Address, (long*)y.Address, (long*)result.Address, count); |
| 209 | + break; |
| 210 | + case NPTypeCode.UInt64: |
| 211 | + ILKernelGenerator.WhereExecute(condPtr, (ulong*)x.Address, (ulong*)y.Address, (ulong*)result.Address, count); |
| 212 | + break; |
| 213 | + case NPTypeCode.Char: |
| 214 | + ILKernelGenerator.WhereExecute(condPtr, (char*)x.Address, (char*)y.Address, (char*)result.Address, count); |
| 215 | + break; |
| 216 | + case NPTypeCode.Single: |
| 217 | + ILKernelGenerator.WhereExecute(condPtr, (float*)x.Address, (float*)y.Address, (float*)result.Address, count); |
| 218 | + break; |
| 219 | + case NPTypeCode.Double: |
| 220 | + ILKernelGenerator.WhereExecute(condPtr, (double*)x.Address, (double*)y.Address, (double*)result.Address, count); |
| 221 | + break; |
| 222 | + case NPTypeCode.Decimal: |
| 223 | + ILKernelGenerator.WhereExecute(condPtr, (decimal*)x.Address, (decimal*)y.Address, (decimal*)result.Address, count); |
| 224 | + break; |
| 225 | + default: |
| 226 | + throw new NotSupportedException($"Type {outType} not supported for np.where"); |
| 227 | + } |
| 228 | + } |
| 229 | + } |
| 230 | +} |
0 commit comments