Skip to content

Commit bd5f5d7

Browse files
committed
feat(dtypes): Half/SByte/Complex coverage audit + NumPy parity fixes
Audit of "NPTypeCode.Single =>" arrow-switch expressions across 23 files found 11 gaps where Half/SByte/Complex were missing. Fixed each and tightened several related behaviors to strict NumPy 2.x parity. ## Core fixes (Half/SByte/Complex coverage) - np.repeat: add SByte/Half/Complex to RepeatScalarTyped/RepeatArrayTyped switches. Previously threw NotSupportedException for these dtypes. - np.any / np.all axis: add SByte/Half/Complex to axis dispatch (generic ComputePerAxis<T> already supports via unmanaged constraint). - ILKernelGenerator.Reduction.Axis.Arg: add SByte/Half/Complex to argmax/argmin axis dispatch. Added ArgReduceAxisHalfNaN (NumPy first-NaN-wins semantics via double) and ArgReduceAxisComplex (lexicographic real-then-imag, NaN propagates). sbyte added to CompareGreater/Less. - ReductionKernel.GetMinValue/GetMaxValue: add SByte/Half/Complex identities (sbyte.Min/MaxValue, Half.Negative/PositiveInfinity, Complex(inf,0) sentinels for Max/Min identity on empty arrays). - Default.Reduction.Nan ExecuteNanAxisReductionScalar: add Half case + ReduceNanAxisScalarHalf helper covering NanSum/NanProd/NanMin/NanMax. Previously silently returned 0 for Half axis NaN reductions. - ILKernelGenerator.Reduction.Axis.NaN: updated doc comment clarifying Half/Complex route to scalar fallback (resolved by the above fix). - Default.ATan2: add SByte/Half to ConvertToDouble/ConvertToDecimal and Half to result-type switch. Complex excluded (NumPy arctan2 rejects complex inputs — matches np.arctan2 TypeError). - np.can_cast ValueFitsInType: add Half (range-checked ±65504) and Complex (always true from real) to every `case` arm; added `case Half h:` and `case Complex c:`. Full 13×13 can_cast matrix now matches NumPy exactly. - ILKernelGenerator EmitDecimalConversion: added SByte conversion via new CachedMethods.DecimalImplicitFromSByte / DecimalToSByte. Previously sbyte↔decimal IL conversions threw NotSupportedException. - np.sctype2char: fix Boolean '?' (was incorrectly 'b'), add SByte 'b', add Half 'e'. Matches NumPy 2.x np.dtype(x).char. ## Strict-parity fixes discovered during verification - ATan2 auto-promotion now matches NumPy 2.x per-input targeting: bool/i8/u8 → float16, i16/u16 → float32, i32+/i64+/char → float64, float types preserved, binary takes max. Added PromoteATan2Single + PromoteATan2Binary helpers. Previously everything except f32+f32 promoted to double. - common_type_code rewritten to match NumPy exactly: * Boolean input: raises TypeError "non-numeric array" (NumPy parity) * Any Complex → Complex * Any Decimal → Decimal (NumSharp extension) * Any integer/char → Double (forces float64 even if smaller float present) * Otherwise: max pure float (Half < Single < Double) 12×12 matrix now identically matches NumPy. - Empty reduction dtype: sum/prod of empty array now uses GetAccumulatingType() so int/bool → Int64/UInt64, floats preserved. Previously returned input dtype (sum([], sbyte) gave SByte, NumPy gives int64). Fixed in Default.Reduction.Add (HandleEmptyArrayReduction + IsEmpty path) and Default.Reduction.Product (both IsEmpty paths). ## Test additions - test/NumSharp.UnitTest/APIs/np.common_type.BattleTest.cs: Complete rewrite — 77 comprehensive tests covering: - Boolean TypeError (5 tests) - Single integer inputs → Double (9 tests) - Single float preserved (3 tests) - Complex/Decimal (2 tests) - Pure float combos → max float (9 tests) - Integer+Integer combos (7 tests) - Integer+Float combos (10 tests) - Complex combos (9 tests) - Decimal combos with float/int/complex (5 tests) - NDArray / Type overloads (12 tests) - Argument validation (3 tests) - test/NumSharp.UnitTest/Backends/Kernels/BinaryOpTests.cs: 8 new ATan2_* tests pinning Half/SByte/Int16 NumPy parity: ATan2_Float16_ReturnsHalf, ATan2_Int8_ReturnsFloat16, ATan2_UInt8_ReturnsFloat16, ATan2_Int16_ReturnsFloat32, ATan2_Float16_Int8_ReturnsFloat16, ATan2_Float16_Int32_ReturnsFloat64, ATan2_Int16_Float16_ReturnsFloat32. - test/NumSharp.UnitTest/APIs/np.type_checks.BattleTest.cs: Updated Sctype2Char_Boolean to expect '?' (matches NumPy); added Sctype2Char_SByte ('b') and Sctype2Char_Half ('e'). ## Verification methodology Every change verified against NumPy 2.x via python_run reference runs. Side-by-side 13×13 can_cast grid and 12×12 common_type grid both produce identical output to NumPy. Cast correctness (Half↔double) is lossless per IEEE 754 and matches NumPy's internal float16 handling. ## Test results 7192 passed / 0 failed / 11 skipped on both net8.0 and net10.0. (+63 net tests vs pre-audit; rewritten common_type suite replaced 14 older tests with 77 parity-locked ones.) ## Behavioral breaking changes (NumPy parity) - np.sctype2char(Boolean): 'b' → '?' - np.common_type(Boolean): returned Double → now throws TypeError - np.arctan2(i8/u8/bool): returned Double → now returns Half - np.arctan2(i16/u16): returned Double → now returns Single - np.arctan2(f16): returned Double → now returns Half - np.sum/np.prod of empty integer array: returned input dtype → now returns Int64/UInt64 accumulating type
1 parent b36555f commit bd5f5d7

17 files changed

Lines changed: 677 additions & 126 deletions

src/NumSharp.Core/Backends/Default/Math/Default.ATan2.cs

Lines changed: 56 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -49,32 +49,15 @@ private unsafe NDArray ExecuteATan2Op(NDArray y, NDArray x, NPTypeCode? typeCode
4949
var yType = y.GetTypeCode;
5050
var xType = x.GetTypeCode;
5151

52-
// Determine result type using NumPy arctan2 rules:
53-
// - float32 inputs -> float32 output
54-
// - float64 or integer inputs -> float64 output
55-
NPTypeCode resultType;
56-
if (typeCode.HasValue)
57-
{
58-
resultType = typeCode.Value;
59-
}
60-
else
61-
{
62-
// NumPy arctan2 type promotion:
63-
// float32 + float32 -> float32
64-
// anything else -> float64
65-
if (yType == NPTypeCode.Single && xType == NPTypeCode.Single)
66-
{
67-
resultType = NPTypeCode.Single;
68-
}
69-
else if (yType == NPTypeCode.Decimal || xType == NPTypeCode.Decimal)
70-
{
71-
resultType = NPTypeCode.Decimal;
72-
}
73-
else
74-
{
75-
resultType = NPTypeCode.Double;
76-
}
77-
}
52+
// Determine result type using NumPy 2.x arctan2 rules.
53+
// Each input maps to its smallest supporting float target:
54+
// bool / int8 / uint8 -> float16
55+
// int16 / uint16 -> float32
56+
// int32+ / int64+ / char -> float64
57+
// float16 / float32 / float64-> same
58+
// decimal (NumSharp ext.) -> decimal
59+
// The result is the larger of the two promotion targets.
60+
NPTypeCode resultType = typeCode ?? PromoteATan2Binary(yType, xType);
7861

7962
// Handle scalar x scalar case
8063
if (y.Shape.IsScalar && x.Shape.IsScalar)
@@ -119,6 +102,47 @@ private unsafe NDArray ExecuteATan2Op(NDArray y, NDArray x, NPTypeCode? typeCode
119102
return result;
120103
}
121104

105+
/// <summary>
106+
/// Maps a single input dtype to its NumPy arctan2 output target.
107+
/// NumPy 2.x rules: bool/i8/u8 → f16, i16/u16 → f32, i32+/i64+/char → f64,
108+
/// float types preserved, decimal preserved (NumSharp extension).
109+
/// </summary>
110+
private static NPTypeCode PromoteATan2Single(NPTypeCode t) => t switch
111+
{
112+
NPTypeCode.Boolean or NPTypeCode.SByte or NPTypeCode.Byte => NPTypeCode.Half,
113+
NPTypeCode.Int16 or NPTypeCode.UInt16 => NPTypeCode.Single,
114+
NPTypeCode.Int32 or NPTypeCode.UInt32 or NPTypeCode.Int64 or NPTypeCode.UInt64 or NPTypeCode.Char => NPTypeCode.Double,
115+
NPTypeCode.Half => NPTypeCode.Half,
116+
NPTypeCode.Single => NPTypeCode.Single,
117+
NPTypeCode.Double => NPTypeCode.Double,
118+
NPTypeCode.Decimal => NPTypeCode.Decimal,
119+
_ => NPTypeCode.Double,
120+
};
121+
122+
/// <summary>
123+
/// Binary promotion for arctan2: take the "larger" of the two single-input targets.
124+
/// Order: Decimal > Double > Single > Half.
125+
/// </summary>
126+
private static NPTypeCode PromoteATan2Binary(NPTypeCode y, NPTypeCode x)
127+
{
128+
var py = PromoteATan2Single(y);
129+
var px = PromoteATan2Single(x);
130+
if (py == px) return py;
131+
132+
// Decimal dominates (NumSharp extension).
133+
if (py == NPTypeCode.Decimal || px == NPTypeCode.Decimal) return NPTypeCode.Decimal;
134+
135+
// Otherwise: larger float wins (Double > Single > Half).
136+
static int Rank(NPTypeCode t) => t switch
137+
{
138+
NPTypeCode.Half => 1,
139+
NPTypeCode.Single => 2,
140+
NPTypeCode.Double => 3,
141+
_ => 3,
142+
};
143+
return Rank(py) >= Rank(px) ? py : px;
144+
}
145+
122146
/// <summary>
123147
/// Execute scalar x scalar ATan2 operation.
124148
/// </summary>
@@ -135,6 +159,7 @@ private static NDArray ExecuteATan2ScalarScalar(
135159
// Convert to result type
136160
return resultType switch
137161
{
162+
NPTypeCode.Half => NDArray.Scalar((Half)result),
138163
NPTypeCode.Single => NDArray.Scalar((float)result),
139164
NPTypeCode.Double => NDArray.Scalar(result),
140165
NPTypeCode.Decimal => NDArray.Scalar(Utilities.DecimalMath.ATan2(
@@ -152,16 +177,19 @@ private static double ConvertToDouble(NDArray arr, NPTypeCode type)
152177
{
153178
NPTypeCode.Boolean => arr.GetBoolean(Array.Empty<long>()) ? 1.0 : 0.0,
154179
NPTypeCode.Byte => arr.GetByte(Array.Empty<long>()),
180+
NPTypeCode.SByte => arr.GetSByte(Array.Empty<long>()),
155181
NPTypeCode.Int16 => arr.GetInt16(Array.Empty<long>()),
156182
NPTypeCode.UInt16 => arr.GetUInt16(Array.Empty<long>()),
157183
NPTypeCode.Int32 => arr.GetInt32(Array.Empty<long>()),
158184
NPTypeCode.UInt32 => arr.GetUInt32(Array.Empty<long>()),
159185
NPTypeCode.Int64 => arr.GetInt64(Array.Empty<long>()),
160186
NPTypeCode.UInt64 => arr.GetUInt64(Array.Empty<long>()),
161187
NPTypeCode.Char => arr.GetChar(Array.Empty<long>()),
188+
NPTypeCode.Half => (double)arr.GetHalf(Array.Empty<long>()),
162189
NPTypeCode.Single => arr.GetSingle(Array.Empty<long>()),
163190
NPTypeCode.Double => arr.GetDouble(Array.Empty<long>()),
164191
NPTypeCode.Decimal => (double)arr.GetDecimal(Array.Empty<long>()),
192+
// NumPy's arctan2 is real-valued; complex inputs are not supported.
165193
_ => throw new NotSupportedException($"Type {type} not supported")
166194
};
167195
}
@@ -175,13 +203,15 @@ private static decimal ConvertToDecimal(NDArray arr, NPTypeCode type)
175203
{
176204
NPTypeCode.Boolean => arr.GetBoolean(Array.Empty<long>()) ? 1m : 0m,
177205
NPTypeCode.Byte => arr.GetByte(Array.Empty<long>()),
206+
NPTypeCode.SByte => arr.GetSByte(Array.Empty<long>()),
178207
NPTypeCode.Int16 => arr.GetInt16(Array.Empty<long>()),
179208
NPTypeCode.UInt16 => arr.GetUInt16(Array.Empty<long>()),
180209
NPTypeCode.Int32 => arr.GetInt32(Array.Empty<long>()),
181210
NPTypeCode.UInt32 => arr.GetUInt32(Array.Empty<long>()),
182211
NPTypeCode.Int64 => arr.GetInt64(Array.Empty<long>()),
183212
NPTypeCode.UInt64 => arr.GetUInt64(Array.Empty<long>()),
184213
NPTypeCode.Char => arr.GetChar(Array.Empty<long>()),
214+
NPTypeCode.Half => (decimal)(double)arr.GetHalf(Array.Empty<long>()),
185215
NPTypeCode.Single => (decimal)arr.GetSingle(Array.Empty<long>()),
186216
NPTypeCode.Double => (decimal)arr.GetDouble(Array.Empty<long>()),
187217
NPTypeCode.Decimal => arr.GetDecimal(Array.Empty<long>()),

src/NumSharp.Core/Backends/Default/Math/Reduction/Default.Reduction.Add.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ public override NDArray ReduceAdd(NDArray arr, int? axis_, bool keepdims = false
1212

1313
if (shape.IsEmpty)
1414
{
15-
var defaultVal = (typeCode ?? arr.typecode).GetDefaultValue();
15+
// NumPy parity: sum of empty array uses accumulating type (int/bool -> int64/uint64, floats preserved).
16+
var defaultType = typeCode ?? arr.typecode.GetAccumulatingType();
17+
var defaultVal = defaultType.GetDefaultValue();
1618
if (@out is not null) { @out.SetAtIndex(defaultVal, 0); return @out; }
1719
return NDArray.Scalar(defaultVal);
1820
}
@@ -125,7 +127,9 @@ private NDArray HandleEmptyArrayReduction(NDArray arr, int? axis_, bool keepdims
125127
var shape = arr.Shape;
126128
if (axis_ == null)
127129
{
128-
var defaultVal = (typeCode ?? arr.typecode).GetDefaultValue();
130+
// NumPy parity: empty reduction uses accumulating type (int/bool -> int64/uint64, floats preserved).
131+
var defaultType = typeCode ?? arr.typecode.GetAccumulatingType();
132+
var defaultVal = defaultType.GetDefaultValue();
129133
if (@out is not null) { @out.SetAtIndex(defaultVal, 0); return @out; }
130134
var r = NDArray.Scalar(defaultVal);
131135
if (keepdims) { var ks = new long[arr.ndim]; for (int i = 0; i < arr.ndim; i++) ks[i] = 1; r.Storage.Reshape(new Shape(ks)); }

src/NumSharp.Core/Backends/Default/Math/Reduction/Default.Reduction.Nan.cs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,9 @@ private NDArray ExecuteNanAxisReductionScalar(NDArray arr, int axis, bool keepdi
545545
case NPTypeCode.Double:
546546
reduced = ReduceNanAxisScalarDouble(arr, inputBaseOffset, axisSize, shape.strides[axis], op);
547547
break;
548+
case NPTypeCode.Half:
549+
reduced = ReduceNanAxisScalarHalf(arr, inputBaseOffset, axisSize, shape.strides[axis], op);
550+
break;
548551
default:
549552
reduced = 0;
550553
break;
@@ -665,6 +668,61 @@ private static double ReduceNanAxisScalarDouble(NDArray arr, long baseOffset, lo
665668
}
666669
}
667670

671+
/// <summary>
672+
/// Half-typed scalar NaN axis reduction. Uses double accumulator for precision,
673+
/// casts final result back to Half to preserve dtype.
674+
/// </summary>
675+
private static Half ReduceNanAxisScalarHalf(NDArray arr, long baseOffset, long axisSize, long axisStride, ReductionOp op)
676+
{
677+
switch (op)
678+
{
679+
case ReductionOp.NanSum:
680+
{
681+
double sum = 0.0;
682+
for (long i = 0; i < axisSize; i++)
683+
{
684+
double val = (double)(Half)arr.GetAtIndex(baseOffset + i * axisStride);
685+
if (!double.IsNaN(val)) sum += val;
686+
}
687+
return (Half)sum;
688+
}
689+
case ReductionOp.NanProd:
690+
{
691+
double prod = 1.0;
692+
for (long i = 0; i < axisSize; i++)
693+
{
694+
double val = (double)(Half)arr.GetAtIndex(baseOffset + i * axisStride);
695+
if (!double.IsNaN(val)) prod *= val;
696+
}
697+
return (Half)prod;
698+
}
699+
case ReductionOp.NanMin:
700+
{
701+
double minVal = double.PositiveInfinity;
702+
bool foundNonNaN = false;
703+
for (long i = 0; i < axisSize; i++)
704+
{
705+
double val = (double)(Half)arr.GetAtIndex(baseOffset + i * axisStride);
706+
if (!double.IsNaN(val)) { if (val < minVal) minVal = val; foundNonNaN = true; }
707+
}
708+
return foundNonNaN ? (Half)minVal : Half.NaN;
709+
}
710+
case ReductionOp.NanMax:
711+
{
712+
double maxVal = double.NegativeInfinity;
713+
bool foundNonNaN = false;
714+
for (long i = 0; i < axisSize; i++)
715+
{
716+
double val = (double)(Half)arr.GetAtIndex(baseOffset + i * axisStride);
717+
if (!double.IsNaN(val)) { if (val > maxVal) maxVal = val; foundNonNaN = true; }
718+
}
719+
return foundNonNaN ? (Half)maxVal : Half.NaN;
720+
}
721+
default:
722+
return Half.Zero;
723+
}
724+
}
725+
668726
/// <summary>
669727
/// B15: NumPy-parity Complex nansum. Treats any element with NaN in real OR imag
670728
/// as zero (skipped). Sum type is Complex.

src/NumSharp.Core/Backends/Default/Math/Reduction/Default.Reduction.Product.cs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,18 @@ public override NDArray ReduceProduct(NDArray arr, int? axis_, bool keepdims = f
1111
var shape = arr.Shape;
1212

1313
if (shape.IsEmpty)
14-
return NDArray.Scalar((typeCode ?? arr.typecode).GetOneValue());
14+
{
15+
// NumPy parity: prod of empty array uses accumulating type (int/bool -> int64/uint64, floats preserved).
16+
var emptyType = typeCode ?? arr.typecode.GetAccumulatingType();
17+
return NDArray.Scalar(emptyType.GetOneValue());
18+
}
1519

1620
if (shape.size == 0)
1721
{
1822
if (axis_ == null)
1923
{
20-
var r = NDArray.Scalar((typeCode ?? arr.typecode).GetOneValue());
24+
var emptyType = typeCode ?? arr.typecode.GetAccumulatingType();
25+
var r = NDArray.Scalar(emptyType.GetOneValue());
2126
if (keepdims) { var ks = new long[arr.ndim]; for (int i = 0; i < arr.ndim; i++) ks[i] = 1; r.Storage.Reshape(new Shape(ks)); }
2227
return r;
2328
}

src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Reduction.Axis.Arg.cs

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,19 @@ private static AxisReductionKernel CreateAxisArgReductionKernel(AxisReductionKer
3030
{
3131
NPTypeCode.Boolean => CreateAxisArgReductionKernelTyped<bool>(key),
3232
NPTypeCode.Byte => CreateAxisArgReductionKernelTyped<byte>(key),
33+
NPTypeCode.SByte => CreateAxisArgReductionKernelTyped<sbyte>(key),
3334
NPTypeCode.Int16 => CreateAxisArgReductionKernelTyped<short>(key),
3435
NPTypeCode.UInt16 => CreateAxisArgReductionKernelTyped<ushort>(key),
3536
NPTypeCode.Int32 => CreateAxisArgReductionKernelTyped<int>(key),
3637
NPTypeCode.UInt32 => CreateAxisArgReductionKernelTyped<uint>(key),
3738
NPTypeCode.Int64 => CreateAxisArgReductionKernelTyped<long>(key),
3839
NPTypeCode.UInt64 => CreateAxisArgReductionKernelTyped<ulong>(key),
3940
NPTypeCode.Char => CreateAxisArgReductionKernelTyped<char>(key),
41+
NPTypeCode.Half => CreateAxisArgReductionKernelTyped<Half>(key),
4042
NPTypeCode.Single => CreateAxisArgReductionKernelTyped<float>(key),
4143
NPTypeCode.Double => CreateAxisArgReductionKernelTyped<double>(key),
4244
NPTypeCode.Decimal => CreateAxisArgReductionKernelTyped<decimal>(key),
45+
NPTypeCode.Complex => CreateAxisArgReductionKernelTyped<System.Numerics.Complex>(key),
4346
_ => throw new NotSupportedException($"ArgMax/ArgMin not supported for type {key.InputType}")
4447
};
4548
}
@@ -133,6 +136,14 @@ private static unsafe long ArgReduceAxis<T>(T* data, long size, long stride, Red
133136
{
134137
return ArgReduceAxisDoubleNaN((double*)data, size, stride, op);
135138
}
139+
if (typeof(T) == typeof(Half))
140+
{
141+
return ArgReduceAxisHalfNaN((Half*)data, size, stride, op);
142+
}
143+
if (typeof(T) == typeof(System.Numerics.Complex))
144+
{
145+
return ArgReduceAxisComplex((System.Numerics.Complex*)data, size, stride, op);
146+
}
136147
// Handle boolean specially
137148
if (typeof(T) == typeof(bool))
138149
{
@@ -307,6 +318,7 @@ private static unsafe long ArgReduceAxisNumeric<T>(T* data, long size, long stri
307318
private static bool CompareGreater<T>(T a, T b) where T : unmanaged
308319
{
309320
if (typeof(T) == typeof(byte)) return (byte)(object)a > (byte)(object)b;
321+
if (typeof(T) == typeof(sbyte)) return (sbyte)(object)a > (sbyte)(object)b;
310322
if (typeof(T) == typeof(short)) return (short)(object)a > (short)(object)b;
311323
if (typeof(T) == typeof(ushort)) return (ushort)(object)a > (ushort)(object)b;
312324
if (typeof(T) == typeof(int)) return (int)(object)a > (int)(object)b;
@@ -315,7 +327,7 @@ private static bool CompareGreater<T>(T a, T b) where T : unmanaged
315327
if (typeof(T) == typeof(ulong)) return (ulong)(object)a > (ulong)(object)b;
316328
if (typeof(T) == typeof(char)) return (char)(object)a > (char)(object)b;
317329
if (typeof(T) == typeof(decimal)) return (decimal)(object)a > (decimal)(object)b;
318-
// Float/double handled separately with NaN awareness
330+
// Float/double/Half/Complex handled separately
319331
throw new NotSupportedException($"CompareGreater not supported for type {typeof(T)}");
320332
}
321333

@@ -325,6 +337,7 @@ private static bool CompareGreater<T>(T a, T b) where T : unmanaged
325337
private static bool CompareLess<T>(T a, T b) where T : unmanaged
326338
{
327339
if (typeof(T) == typeof(byte)) return (byte)(object)a < (byte)(object)b;
340+
if (typeof(T) == typeof(sbyte)) return (sbyte)(object)a < (sbyte)(object)b;
328341
if (typeof(T) == typeof(short)) return (short)(object)a < (short)(object)b;
329342
if (typeof(T) == typeof(ushort)) return (ushort)(object)a < (ushort)(object)b;
330343
if (typeof(T) == typeof(int)) return (int)(object)a < (int)(object)b;
@@ -333,10 +346,83 @@ private static bool CompareLess<T>(T a, T b) where T : unmanaged
333346
if (typeof(T) == typeof(ulong)) return (ulong)(object)a < (ulong)(object)b;
334347
if (typeof(T) == typeof(char)) return (char)(object)a < (char)(object)b;
335348
if (typeof(T) == typeof(decimal)) return (decimal)(object)a < (decimal)(object)b;
336-
// Float/double handled separately with NaN awareness
349+
// Float/double/Half/Complex handled separately
337350
throw new NotSupportedException($"CompareLess not supported for type {typeof(T)}");
338351
}
339352

353+
/// <summary>
354+
/// ArgMax/ArgMin for Half with NaN awareness.
355+
/// NumPy behavior: first NaN always wins. IL OpCodes.Bgt/Blt don't work on Half;
356+
/// compare via (double) cast.
357+
/// </summary>
358+
private static unsafe long ArgReduceAxisHalfNaN(Half* data, long size, long stride, ReductionOp op)
359+
{
360+
double extreme = (double)data[0];
361+
long extremeIdx = 0;
362+
363+
for (long i = 1; i < size; i++)
364+
{
365+
double val = (double)data[i * stride];
366+
367+
if (double.IsNaN(val) && !double.IsNaN(extreme))
368+
{
369+
extreme = val;
370+
extremeIdx = i;
371+
}
372+
else if (!double.IsNaN(extreme))
373+
{
374+
if (op == ReductionOp.ArgMax)
375+
{
376+
if (val > extreme) { extreme = val; extremeIdx = i; }
377+
}
378+
else
379+
{
380+
if (val < extreme) { extreme = val; extremeIdx = i; }
381+
}
382+
}
383+
}
384+
385+
return extremeIdx;
386+
}
387+
388+
/// <summary>
389+
/// ArgMax/ArgMin for Complex using lexicographic compare (real, then imag).
390+
/// NumPy propagates NaN: a Complex value with NaN in either component wins at its first occurrence.
391+
/// </summary>
392+
private static unsafe long ArgReduceAxisComplex(System.Numerics.Complex* data, long size, long stride, ReductionOp op)
393+
{
394+
var extreme = data[0];
395+
long extremeIdx = 0;
396+
if (double.IsNaN(extreme.Real) || double.IsNaN(extreme.Imaginary))
397+
return 0;
398+
399+
for (long i = 1; i < size; i++)
400+
{
401+
var val = data[i * stride];
402+
if (double.IsNaN(val.Real) || double.IsNaN(val.Imaginary))
403+
return i;
404+
405+
if (op == ReductionOp.ArgMax)
406+
{
407+
if (val.Real > extreme.Real || (val.Real == extreme.Real && val.Imaginary > extreme.Imaginary))
408+
{
409+
extreme = val;
410+
extremeIdx = i;
411+
}
412+
}
413+
else
414+
{
415+
if (val.Real < extreme.Real || (val.Real == extreme.Real && val.Imaginary < extreme.Imaginary))
416+
{
417+
extreme = val;
418+
extremeIdx = i;
419+
}
420+
}
421+
}
422+
423+
return extremeIdx;
424+
}
425+
340426
#endregion
341427
}
342428
}

0 commit comments

Comments
 (0)