Skip to content

Commit c9234e2

Browse files
Merge pull request #2 from SciSharp/master
sync
2 parents b643bda + b4765e5 commit c9234e2

7 files changed

Lines changed: 487 additions & 25 deletions

File tree

src/NumSharp.Core/Logic/np.all.cs

Lines changed: 121 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,23 @@ public static bool all(NDArray a)
2626
#else
2727

2828
#region Compute
29-
switch (a.typecode)
30-
{
31-
case NPTypeCode.Boolean: return _all_linear<bool>(a.MakeGeneric<bool>());
32-
case NPTypeCode.Byte: return _all_linear<byte>(a.MakeGeneric<byte>());
33-
case NPTypeCode.Int16: return _all_linear<short>(a.MakeGeneric<short>());
34-
case NPTypeCode.UInt16: return _all_linear<ushort>(a.MakeGeneric<ushort>());
35-
case NPTypeCode.Int32: return _all_linear<int>(a.MakeGeneric<int>());
36-
case NPTypeCode.UInt32: return _all_linear<uint>(a.MakeGeneric<uint>());
37-
case NPTypeCode.Int64: return _all_linear<long>(a.MakeGeneric<long>());
38-
case NPTypeCode.UInt64: return _all_linear<ulong>(a.MakeGeneric<ulong>());
39-
case NPTypeCode.Char: return _all_linear<char>(a.MakeGeneric<char>());
40-
case NPTypeCode.Double: return _all_linear<double>(a.MakeGeneric<double>());
41-
case NPTypeCode.Single: return _all_linear<float>(a.MakeGeneric<float>());
42-
case NPTypeCode.Decimal: return _all_linear<decimal>(a.MakeGeneric<decimal>());
43-
default:
44-
throw new NotSupportedException();
45-
}
29+
switch (a.typecode)
30+
{
31+
case NPTypeCode.Boolean: return _all_linear<bool>(a.MakeGeneric<bool>());
32+
case NPTypeCode.Byte: return _all_linear<byte>(a.MakeGeneric<byte>());
33+
case NPTypeCode.Int16: return _all_linear<short>(a.MakeGeneric<short>());
34+
case NPTypeCode.UInt16: return _all_linear<ushort>(a.MakeGeneric<ushort>());
35+
case NPTypeCode.Int32: return _all_linear<int>(a.MakeGeneric<int>());
36+
case NPTypeCode.UInt32: return _all_linear<uint>(a.MakeGeneric<uint>());
37+
case NPTypeCode.Int64: return _all_linear<long>(a.MakeGeneric<long>());
38+
case NPTypeCode.UInt64: return _all_linear<ulong>(a.MakeGeneric<ulong>());
39+
case NPTypeCode.Char: return _all_linear<char>(a.MakeGeneric<char>());
40+
case NPTypeCode.Double: return _all_linear<double>(a.MakeGeneric<double>());
41+
case NPTypeCode.Single: return _all_linear<float>(a.MakeGeneric<float>());
42+
case NPTypeCode.Decimal: return _all_linear<decimal>(a.MakeGeneric<decimal>());
43+
default:
44+
throw new NotSupportedException();
45+
}
4646
#endregion
4747
#endif
4848
}
@@ -51,12 +51,113 @@ public static bool all(NDArray a)
5151
/// Test whether all array elements along a given axis evaluate to True.
5252
/// </summary>
5353
/// <param name="a">Input array or object that can be converted to an array.</param>
54-
/// <param name="axis">Axis or axes along which a logical OR reduction is performed. The default (axis = None) is to perform a logical OR over all the dimensions of the input array. axis may be negative, in which case it counts from the last to the first axis.</param>
54+
/// <param name="axis">Axis or axes along which a logical AND reduction is performed. The default (axis = None) is to perform a logical OR over all the dimensions of the input array. axis may be negative, in which case it counts from the last to the first axis.</param>
5555
/// <returns>A new boolean or ndarray is returned unless out is specified, in which case a reference to out is returned.</returns>
5656
/// <remarks>https://docs.scipy.org/doc/numpy/reference/generated/numpy.all.html</remarks>
57-
public static NDArray<bool> all(NDArray nd, int axis)
57+
public static NDArray<bool> all(NDArray nd, int axis, bool keepdims = false)
5858
{
59-
throw new NotImplementedException(); //TODO
59+
if (axis < 0)
60+
axis = nd.ndim + axis;
61+
if (axis < 0 || axis >= nd.ndim)
62+
{
63+
throw new ArgumentOutOfRangeException(nameof(axis));
64+
}
65+
if (nd.ndim == 0)
66+
{
67+
throw new ArgumentException("Can't operate with zero array");
68+
}
69+
if (nd == null)
70+
{
71+
throw new ArgumentException("Can't operate with null array");
72+
}
73+
74+
int[] inputShape = nd.shape;
75+
int[] outputShape = new int[keepdims ? inputShape.Length : inputShape.Length - 1];
76+
int outputIndex = 0;
77+
for (int i = 0; i < inputShape.Length; i++)
78+
{
79+
if (i != axis)
80+
{
81+
outputShape[outputIndex++] = inputShape[i];
82+
}
83+
else if (keepdims)
84+
{
85+
outputShape[outputIndex++] = 1; // keep axis but length is one.
86+
}
87+
}
88+
89+
NDArray<bool> resultArray = (NDArray<bool>)zeros<bool>(outputShape);
90+
Span<bool> resultSpan = resultArray.GetData().AsSpan<bool>();
91+
92+
int axisSize = inputShape[axis];
93+
94+
// It help to build an index
95+
int preAxisStride = 1;
96+
for (int i = 0; i < axis; i++)
97+
{
98+
preAxisStride *= inputShape[i];
99+
}
100+
101+
int postAxisStride = 1;
102+
for (int i = axis + 1; i < inputShape.Length; i++)
103+
{
104+
postAxisStride *= inputShape[i];
105+
}
106+
107+
108+
// Operate different logic by TypeCode
109+
bool computationSuccess = false;
110+
switch (nd.typecode)
111+
{
112+
case NPTypeCode.Boolean: computationSuccess = ComputeAllPerAxis<bool>(nd.MakeGeneric<bool>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
113+
case NPTypeCode.Byte: computationSuccess = ComputeAllPerAxis<byte>(nd.MakeGeneric<byte>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
114+
case NPTypeCode.Int16: computationSuccess = ComputeAllPerAxis<short>(nd.MakeGeneric<short>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
115+
case NPTypeCode.UInt16: computationSuccess = ComputeAllPerAxis<ushort>(nd.MakeGeneric<ushort>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
116+
case NPTypeCode.Int32: computationSuccess = ComputeAllPerAxis<int>(nd.MakeGeneric<int>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
117+
case NPTypeCode.UInt32: computationSuccess = ComputeAllPerAxis<uint>(nd.MakeGeneric<uint>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
118+
case NPTypeCode.Int64: computationSuccess = ComputeAllPerAxis<long>(nd.MakeGeneric<long>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
119+
case NPTypeCode.UInt64: computationSuccess = ComputeAllPerAxis<ulong>(nd.MakeGeneric<ulong>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
120+
case NPTypeCode.Char: computationSuccess = ComputeAllPerAxis<char>(nd.MakeGeneric<char>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
121+
case NPTypeCode.Double: computationSuccess = ComputeAllPerAxis<double>(nd.MakeGeneric<double>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
122+
case NPTypeCode.Single: computationSuccess = ComputeAllPerAxis<float>(nd.MakeGeneric<float>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
123+
case NPTypeCode.Decimal: computationSuccess = ComputeAllPerAxis<decimal>(nd.MakeGeneric<decimal>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
124+
default:
125+
throw new NotSupportedException($"Type {nd.typecode} is not supported");
126+
}
127+
128+
if (!computationSuccess)
129+
{
130+
throw new InvalidOperationException("Failed to compute all() along the specified axis");
131+
}
132+
133+
return resultArray;
134+
}
135+
136+
private static bool ComputeAllPerAxis<T>(NDArray<T> nd, int axis, int preAxisStride, int postAxisStride, int axisSize, Span<bool> resultSpan) where T : unmanaged
137+
{
138+
Span<T> inputSpan = nd.GetData().AsSpan<T>();
139+
140+
141+
for (int o = 0; o < resultSpan.Length; o++)
142+
{
143+
int blockIndex = o / postAxisStride;
144+
int inBlockIndex = o % postAxisStride;
145+
int inputStartIndex = blockIndex * axisSize * postAxisStride + inBlockIndex;
146+
147+
bool currentResult = true;
148+
for (int a = 0; a < axisSize; a++)
149+
{
150+
int inputIndex = inputStartIndex + a * postAxisStride;
151+
if (inputSpan[inputIndex].Equals(default(T)))
152+
{
153+
currentResult = false;
154+
break;
155+
}
156+
}
157+
resultSpan[o] = currentResult;
158+
}
159+
160+
return true;
60161
}
61162

62163
private static bool _all_linear<T>(NDArray<T> nd) where T : unmanaged

src/NumSharp.Core/Logic/np.any.cs

Lines changed: 104 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,111 @@ public static bool any(NDArray a)
5555
/// <param name="axis">Axis or axes along which a logical OR reduction is performed. The default (axis = None) is to perform a logical OR over all the dimensions of the input array. axis may be negative, in which case it counts from the last to the first axis.</param>
5656
/// <returns>A new boolean or ndarray is returned unless out is specified, in which case a reference to out is returned.</returns>
5757
/// <remarks>https://docs.scipy.org/doc/numpy/reference/generated/numpy.any.html</remarks>
58-
public static NDArray<bool> any(NDArray nd, int axis)
58+
public static NDArray<bool> any(NDArray nd, int axis, bool keepdims)
5959
{
60-
throw new NotImplementedException(); //TODO
60+
if (axis < 0)
61+
axis = nd.ndim + axis;
62+
if (axis < 0 || axis >= nd.ndim)
63+
{
64+
throw new ArgumentOutOfRangeException(nameof(axis));
65+
}
66+
if (nd.ndim == 0)
67+
{
68+
throw new ArgumentException("Can't operate with zero array");
69+
}
70+
if (nd == null)
71+
{
72+
throw new ArgumentException("Can't operate with null array");
73+
}
74+
75+
int[] inputShape = nd.shape;
76+
int[] outputShape = new int[keepdims ? inputShape.Length : inputShape.Length - 1];
77+
int outputIndex = 0;
78+
for (int i = 0; i < inputShape.Length; i++)
79+
{
80+
if (i != axis)
81+
{
82+
outputShape[outputIndex++] = inputShape[i];
83+
}
84+
else if (keepdims)
85+
{
86+
outputShape[outputIndex++] = 1; // keep axis but length is one.
87+
}
88+
}
89+
90+
NDArray<bool> resultArray = (NDArray<bool>)zeros<bool>(outputShape);
91+
Span<bool> resultSpan = resultArray.GetData().AsSpan<bool>();
92+
93+
int axisSize = inputShape[axis];
94+
95+
// It help to build an index
96+
int preAxisStride = 1;
97+
for (int i = 0; i < axis; i++)
98+
{
99+
preAxisStride *= inputShape[i];
100+
}
101+
102+
int postAxisStride = 1;
103+
for (int i = axis + 1; i < inputShape.Length; i++)
104+
{
105+
postAxisStride *= inputShape[i];
106+
}
107+
108+
109+
// Operate different logic by TypeCode
110+
bool computationSuccess = false;
111+
switch (nd.typecode)
112+
{
113+
case NPTypeCode.Boolean: computationSuccess = ComputeAnyPerAxis<bool>(nd.MakeGeneric<bool>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
114+
case NPTypeCode.Byte: computationSuccess = ComputeAnyPerAxis<byte>(nd.MakeGeneric<byte>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
115+
case NPTypeCode.Int16: computationSuccess = ComputeAnyPerAxis<short>(nd.MakeGeneric<short>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
116+
case NPTypeCode.UInt16: computationSuccess = ComputeAnyPerAxis<ushort>(nd.MakeGeneric<ushort>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
117+
case NPTypeCode.Int32: computationSuccess = ComputeAnyPerAxis<int>(nd.MakeGeneric<int>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
118+
case NPTypeCode.UInt32: computationSuccess = ComputeAnyPerAxis<uint>(nd.MakeGeneric<uint>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
119+
case NPTypeCode.Int64: computationSuccess = ComputeAnyPerAxis<long>(nd.MakeGeneric<long>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
120+
case NPTypeCode.UInt64: computationSuccess = ComputeAnyPerAxis<ulong>(nd.MakeGeneric<ulong>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
121+
case NPTypeCode.Char: computationSuccess = ComputeAnyPerAxis<char>(nd.MakeGeneric<char>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
122+
case NPTypeCode.Double: computationSuccess = ComputeAnyPerAxis<double>(nd.MakeGeneric<double>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
123+
case NPTypeCode.Single: computationSuccess = ComputeAnyPerAxis<float>(nd.MakeGeneric<float>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
124+
case NPTypeCode.Decimal: computationSuccess = ComputeAnyPerAxis<decimal>(nd.MakeGeneric<decimal>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
125+
default:
126+
throw new NotSupportedException($"Type {nd.typecode} is not supported");
127+
}
128+
129+
130+
if (!computationSuccess)
131+
{
132+
throw new InvalidOperationException("Failed to compute all() along the specified axis");
133+
}
134+
135+
return resultArray;
136+
}
137+
138+
private static bool ComputeAnyPerAxis<T>(NDArray<T> nd, int axis, int preAxisStride, int postAxisStride, int axisSize, Span<bool> resultSpan) where T : unmanaged
139+
{
140+
Span<T> inputSpan = nd.GetData().AsSpan<T>();
141+
142+
143+
for (int o = 0; o < resultSpan.Length; o++)
144+
{
145+
int blockIndex = o / postAxisStride;
146+
int inBlockIndex = o % postAxisStride;
147+
int inputStartIndex = blockIndex * axisSize * postAxisStride + inBlockIndex;
148+
149+
bool currentResult = true;
150+
for (int a = 0; a < axisSize; a++)
151+
{
152+
int inputIndex = inputStartIndex + a * postAxisStride;
153+
if (inputSpan[inputIndex].Equals(default(T)))
154+
{
155+
currentResult = true;
156+
break;
157+
}
158+
}
159+
resultSpan[o] = currentResult;
160+
}
161+
162+
return false;
61163
}
62164

63165
private static bool _any_linear<T>(NDArray<T> nd) where T : unmanaged

src/NumSharp.Core/Utilities/ArrayConvert.cs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Numerics;
33
using System.Runtime.CompilerServices;
4+
using System.Text.RegularExpressions;
45
using System.Threading.Tasks;
56
using NumSharp.Backends;
67

@@ -4072,12 +4073,13 @@ public static Complex[] ToComplex(Decimal[] sourceArray)
40724073
Parallel.For(0, length, i => new Complex(Converts.ToDouble(sourceArray[i]), 0d));
40734074
return output;
40744075
}
4075-
4076+
40764077
/// <summary>
40774078
/// Converts <see cref="String"/> array to a <see cref="Complex"/> array.
40784079
/// </summary>
40794080
/// <param name="sourceArray">The array to convert</param>
40804081
/// <returns>Converted array of type Complex</returns>
4082+
/// <exception cref="FormatException">A string in sourceArray has an invalid complex format</exception>
40814083
[MethodImpl(MethodImplOptions.AggressiveInlining)]
40824084
public static Complex[] ToComplex(String[] sourceArray)
40834085
{
@@ -4087,7 +4089,16 @@ public static Complex[] ToComplex(String[] sourceArray)
40874089
var length = sourceArray.Length;
40884090
var output = new Complex[length];
40894091

4090-
Parallel.For(0, length, i => new Complex(Converts.ToDouble(sourceArray[i]), 0d));
4092+
Parallel.For(0, length, i =>
4093+
{
4094+
string input = sourceArray[i]?.Trim() ?? string.Empty;
4095+
if (string.IsNullOrEmpty(input))
4096+
{
4097+
output[i] = Complex.Zero; // NullString save as zero.
4098+
return;
4099+
}
4100+
var match = py.complex(sourceArray[i]);
4101+
});
40914102
return output;
40924103
}
40934104
#endif

src/NumSharp.Core/Utilities/py.cs

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
namespace NumSharp.Utilities
1+
using System;
2+
using System.Numerics;
3+
using System.Text.RegularExpressions;
4+
5+
namespace NumSharp.Utilities
26
{
37
/// <summary>
48
/// Implements Python utility functions that are often used in connection with numpy
@@ -12,5 +16,47 @@ public static int[] range(int n)
1216
a[i] = i;
1317
return a;
1418
}
19+
/// <summary>
20+
/// 解析单个Python风格的复数字符串为Complex对象
21+
/// </summary>
22+
private static readonly Regex _pythonComplexRegex = new Regex(
23+
@"^(?<real>-?\d+(\.\d+)?)?((?<imagSign>\+|-)?(?<imag>\d+(\.\d+)?)?)?j$|^(?<onlyReal>-?\d+(\.\d+)?)$",
24+
RegexOptions.IgnoreCase | RegexOptions.Compiled | RegexOptions.ExplicitCapture);
25+
public static Complex complex(string input)
26+
{
27+
var match = _pythonComplexRegex.Match(input);
28+
if (!match.Success)
29+
throw new FormatException($"Invalid Python complex format: '{input}'. Expected format like '10+5j', '3-2j', '4j' or '5'.");
30+
31+
// 解析仅实部的场景
32+
if (match.Groups["onlyReal"].Success)
33+
{
34+
double real = double.Parse(match.Groups["onlyReal"].Value);
35+
return new Complex(real, 0);
36+
}
37+
38+
// 解析实部(默认0)
39+
double realPart = 0;
40+
if (double.TryParse(match.Groups["real"].Value, out double r))
41+
realPart = r;
42+
43+
// 解析虚部(处理特殊情况:j / -j / +j)
44+
double imagPart = 0;
45+
string imagStr = match.Groups["imag"].Value;
46+
string imagSign = match.Groups["imagSign"].Value;
47+
48+
if (string.IsNullOrEmpty(imagStr) && !string.IsNullOrEmpty(input.TrimEnd('j', 'J')))
49+
{
50+
// 处理仅虚部的情况:j → 1j, -j → -1j, +j → 1j
51+
imagStr = "1";
52+
}
53+
54+
if (double.TryParse(imagStr, out double im))
55+
{
56+
imagPart = im * (imagSign == "-" ? -1 : 1);
57+
}
58+
59+
return new Complex(realPart, imagPart);
60+
}
1561
}
1662
}

0 commit comments

Comments
 (0)