Skip to content

Commit e226317

Browse files
committed
update dotnet stuff
1 parent ad293e2 commit e226317

File tree

5 files changed

+217
-315
lines changed

5 files changed

+217
-315
lines changed

ManagedCode.Umap/SIMD.cs

Lines changed: 144 additions & 199 deletions
Original file line numberDiff line numberDiff line change
@@ -1,235 +1,180 @@
11
using System;
2-
using System.Linq;
3-
using System.Numerics;
42
using System.Runtime.CompilerServices;
3+
using System.Runtime.InteropServices;
4+
using System.Runtime.Intrinsics;
55

6-
namespace ManagedCode.Umap
6+
namespace ManagedCode.Umap;
7+
8+
internal static class Simd
79
{
8-
internal static class SIMD<T>
10+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
11+
public static float Magnitude(ReadOnlySpan<float> values) => MathF.Sqrt(DotProduct(values, values));
12+
13+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
14+
public static float Euclidean(ReadOnlySpan<float> lhs, ReadOnlySpan<float> rhs)
915
{
10-
private static readonly int _vs1 = Vector<float>.Count;
11-
private static readonly int _vs2 = 2 * Vector<float>.Count;
12-
private static readonly int _vs3 = 3 * Vector<float>.Count;
13-
private static readonly int _vs4 = 4 * Vector<float>.Count;
16+
if (lhs.Length != rhs.Length)
17+
{
18+
ThrowLengthMismatch();
19+
}
1420

15-
[MethodImpl(MethodImplOptions.AggressiveInlining)]
16-
public static float Magnitude(ref float[] vec) => (float)Math.Sqrt(DotProduct(ref vec, ref vec));
21+
ref float left = ref MemoryMarshal.GetReference(lhs);
22+
ref float right = ref MemoryMarshal.GetReference(rhs);
1723

18-
[MethodImpl(MethodImplOptions.AggressiveInlining)]
19-
public static float Euclidean(ref float[] lhs, ref float[] rhs)
20-
{
21-
float result = 0f;
24+
int length = lhs.Length;
25+
int i = 0;
26+
float sum = 0f;
2227

23-
var count = lhs.Length;
24-
var offset = 0;
25-
Vector<float> diff;
26-
while (count >= _vs4)
28+
if (Vector256.IsHardwareAccelerated && length >= Vector256<float>.Count)
29+
{
30+
Vector256<float> acc = Vector256<float>.Zero;
31+
for (; i <= length - Vector256<float>.Count; i += Vector256<float>.Count)
2732
{
28-
diff = new Vector<float>(lhs, offset) - new Vector<float>(rhs, offset); result += Vector.Dot(diff, diff);
29-
diff = new Vector<float>(lhs, offset + _vs1) - new Vector<float>(rhs, offset + _vs1); result += Vector.Dot(diff, diff);
30-
diff = new Vector<float>(lhs, offset + _vs2) - new Vector<float>(rhs, offset + _vs2); result += Vector.Dot(diff, diff);
31-
diff = new Vector<float>(lhs, offset + _vs3) - new Vector<float>(rhs, offset + _vs3); result += Vector.Dot(diff, diff);
32-
if (count == _vs4)
33-
{
34-
return result;
35-
}
36-
37-
count -= _vs4;
38-
offset += _vs4;
33+
var diff = Vector256.LoadUnsafe(ref left, (uint)i) - Vector256.LoadUnsafe(ref right, (uint)i);
34+
acc += diff * diff;
3935
}
36+
sum += Vector256.Sum(acc);
37+
}
4038

41-
if (count >= _vs2)
42-
{
43-
diff = new Vector<float>(lhs, offset) - new Vector<float>(rhs, offset); result += Vector.Dot(diff, diff);
44-
diff = new Vector<float>(lhs, offset + _vs1) - new Vector<float>(rhs, offset + _vs1); result += Vector.Dot(diff, diff);
45-
if (count == _vs2)
46-
{
47-
return result;
48-
}
49-
50-
count -= _vs2;
51-
offset += _vs2;
52-
}
53-
if (count >= _vs1)
54-
{
55-
diff = new Vector<float>(lhs, offset) - new Vector<float>(rhs, offset); result += Vector.Dot(diff, diff);
56-
if (count == _vs1)
57-
{
58-
return result;
59-
}
60-
61-
count -= _vs1;
62-
offset += _vs1;
63-
}
64-
if (count > 0)
39+
if (Vector128.IsHardwareAccelerated && i <= length - Vector128<float>.Count)
40+
{
41+
Vector128<float> acc = Vector128<float>.Zero;
42+
for (; i <= length - Vector128<float>.Count; i += Vector128<float>.Count)
6543
{
66-
while (count > 0)
67-
{
68-
var d = (lhs[offset] - rhs[offset]);
69-
result += d * d;
70-
offset++; count--;
71-
}
44+
var diff = Vector128.LoadUnsafe(ref left, (uint)i) - Vector128.LoadUnsafe(ref right, (uint)i);
45+
acc += diff * diff;
7246
}
73-
return result;
47+
sum += Vector128.Sum(acc);
7448
}
7549

76-
[MethodImpl(MethodImplOptions.AggressiveInlining)]
77-
public static void Add(ref float[] lhs, float f)
50+
for (; i < length; i++)
7851
{
79-
var count = lhs.Length;
80-
var offset = 0;
81-
var v = new Vector<float>(f);
82-
while (count >= _vs4)
83-
{
84-
(new Vector<float>(lhs, offset) + v).CopyTo(lhs, offset);
85-
(new Vector<float>(lhs, offset + _vs1) + v).CopyTo(lhs, offset + _vs1);
86-
(new Vector<float>(lhs, offset + _vs2) + v).CopyTo(lhs, offset + _vs2);
87-
(new Vector<float>(lhs, offset + _vs3) + v).CopyTo(lhs, offset + _vs3);
88-
if (count == _vs4)
89-
{
90-
return;
91-
}
92-
93-
count -= _vs4;
94-
offset += _vs4;
95-
}
96-
if (count >= _vs2)
97-
{
98-
(new Vector<float>(lhs, offset) + v).CopyTo(lhs, offset);
99-
(new Vector<float>(lhs, offset + _vs1) + v).CopyTo(lhs, offset + _vs1);
100-
if (count == _vs2)
101-
{
102-
return;
103-
}
104-
105-
count -= _vs2;
106-
offset += _vs2;
107-
}
108-
if (count >= _vs1)
109-
{
110-
(new Vector<float>(lhs, offset) + v).CopyTo(lhs, offset);
111-
if (count == _vs1)
112-
{
113-
return;
114-
}
115-
116-
count -= _vs1;
117-
offset += _vs1;
118-
}
119-
if (count > 0)
120-
{
121-
while (count > 0)
122-
{
123-
lhs[offset] += f;
124-
offset++; count--;
125-
}
126-
}
52+
float diff = Unsafe.Add(ref left, i) - Unsafe.Add(ref right, i);
53+
sum += diff * diff;
12754
}
12855

129-
[MethodImpl(MethodImplOptions.AggressiveInlining)]
130-
public static void Multiply(ref float[] lhs, float f)
56+
return sum;
57+
}
58+
59+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
60+
public static float DotProduct(ReadOnlySpan<float> lhs, ReadOnlySpan<float> rhs)
61+
{
62+
if (lhs.Length != rhs.Length)
13163
{
132-
var count = lhs.Length;
133-
var offset = 0;
134-
while (count >= _vs4)
135-
{
136-
(new Vector<float>(lhs, offset) * f).CopyTo(lhs, offset);
137-
(new Vector<float>(lhs, offset + _vs1) * f).CopyTo(lhs, offset + _vs1);
138-
(new Vector<float>(lhs, offset + _vs2) * f).CopyTo(lhs, offset + _vs2);
139-
(new Vector<float>(lhs, offset + _vs3) * f).CopyTo(lhs, offset + _vs3);
140-
if (count == _vs4)
141-
{
142-
return;
143-
}
144-
145-
count -= _vs4;
146-
offset += _vs4;
147-
}
148-
if (count >= _vs2)
149-
{
150-
(new Vector<float>(lhs, offset) * f).CopyTo(lhs, offset);
151-
(new Vector<float>(lhs, offset + _vs1) * f).CopyTo(lhs, offset + _vs1);
152-
if (count == _vs2)
153-
{
154-
return;
155-
}
156-
157-
count -= _vs2;
158-
offset += _vs2;
159-
}
160-
if (count >= _vs1)
64+
ThrowLengthMismatch();
65+
}
66+
67+
ref float left = ref MemoryMarshal.GetReference(lhs);
68+
ref float right = ref MemoryMarshal.GetReference(rhs);
69+
70+
int length = lhs.Length;
71+
int i = 0;
72+
float sum = 0f;
73+
74+
if (Vector256.IsHardwareAccelerated && length >= Vector256<float>.Count)
75+
{
76+
Vector256<float> acc = Vector256<float>.Zero;
77+
for (; i <= length - Vector256<float>.Count; i += Vector256<float>.Count)
16178
{
162-
(new Vector<float>(lhs, offset) * f).CopyTo(lhs, offset);
163-
if (count == _vs1)
164-
{
165-
return;
166-
}
167-
168-
count -= _vs1;
169-
offset += _vs1;
79+
acc += Vector256.LoadUnsafe(ref left, (uint)i) * Vector256.LoadUnsafe(ref right, (uint)i);
17080
}
171-
if (count > 0)
81+
sum += Vector256.Sum(acc);
82+
}
83+
84+
if (Vector128.IsHardwareAccelerated && i <= length - Vector128<float>.Count)
85+
{
86+
Vector128<float> acc = Vector128<float>.Zero;
87+
for (; i <= length - Vector128<float>.Count; i += Vector128<float>.Count)
17288
{
173-
while (count > 0)
174-
{
175-
lhs[offset] *= f;
176-
offset++; count--;
177-
}
89+
acc += Vector128.LoadUnsafe(ref left, (uint)i) * Vector128.LoadUnsafe(ref right, (uint)i);
17890
}
91+
sum += Vector128.Sum(acc);
17992
}
18093

181-
[MethodImpl(MethodImplOptions.AggressiveInlining)]
182-
public static float DotProduct(ref float[] lhs, ref float[] rhs)
94+
for (; i < length; i++)
18395
{
184-
var result = 0f;
185-
var count = lhs.Length;
186-
var offset = 0;
187-
while (count >= _vs4)
96+
sum += Unsafe.Add(ref left, i) * Unsafe.Add(ref right, i);
97+
}
98+
99+
return sum;
100+
}
101+
102+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
103+
public static void Add(Span<float> values, float scalar)
104+
{
105+
if (values.Length == 0)
106+
{
107+
return;
108+
}
109+
110+
ref float start = ref MemoryMarshal.GetReference(values);
111+
int length = values.Length;
112+
int i = 0;
113+
114+
if (Vector256.IsHardwareAccelerated && length >= Vector256<float>.Count)
115+
{
116+
Vector256<float> scalarVec = Vector256.Create(scalar);
117+
for (; i <= length - Vector256<float>.Count; i += Vector256<float>.Count)
188118
{
189-
result += Vector.Dot(new Vector<float>(lhs, offset), new Vector<float>(rhs, offset));
190-
result += Vector.Dot(new Vector<float>(lhs, offset + _vs1), new Vector<float>(rhs, offset + _vs1));
191-
result += Vector.Dot(new Vector<float>(lhs, offset + _vs2), new Vector<float>(rhs, offset + _vs2));
192-
result += Vector.Dot(new Vector<float>(lhs, offset + _vs3), new Vector<float>(rhs, offset + _vs3));
193-
if (count == _vs4)
194-
{
195-
return result;
196-
}
197-
198-
count -= _vs4;
199-
offset += _vs4;
119+
var current = Vector256.LoadUnsafe(ref start, (uint)i);
120+
(current + scalarVec).StoreUnsafe(ref start, (uint)i);
200121
}
201-
if (count >= _vs2)
122+
}
123+
124+
if (Vector128.IsHardwareAccelerated && i <= length - Vector128<float>.Count)
125+
{
126+
Vector128<float> scalarVec = Vector128.Create(scalar);
127+
for (; i <= length - Vector128<float>.Count; i += Vector128<float>.Count)
202128
{
203-
result += Vector.Dot(new Vector<float>(lhs, offset), new Vector<float>(rhs, offset));
204-
result += Vector.Dot(new Vector<float>(lhs, offset + _vs1), new Vector<float>(rhs, offset + _vs1));
205-
if (count == _vs2)
206-
{
207-
return result;
208-
}
209-
210-
count -= _vs2;
211-
offset += _vs2;
129+
var current = Vector128.LoadUnsafe(ref start, (uint)i);
130+
(current + scalarVec).StoreUnsafe(ref start, (uint)i);
212131
}
213-
if (count >= _vs1)
132+
}
133+
134+
for (; i < length; i++)
135+
{
136+
Unsafe.Add(ref start, i) += scalar;
137+
}
138+
}
139+
140+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
141+
public static void Multiply(Span<float> values, float scalar)
142+
{
143+
if (values.Length == 0)
144+
{
145+
return;
146+
}
147+
148+
ref float start = ref MemoryMarshal.GetReference(values);
149+
int length = values.Length;
150+
int i = 0;
151+
152+
if (Vector256.IsHardwareAccelerated && length >= Vector256<float>.Count)
153+
{
154+
Vector256<float> scalarVec = Vector256.Create(scalar);
155+
for (; i <= length - Vector256<float>.Count; i += Vector256<float>.Count)
214156
{
215-
result += Vector.Dot(new Vector<float>(lhs, offset), new Vector<float>(rhs, offset));
216-
if (count == _vs1)
217-
{
218-
return result;
219-
}
220-
221-
count -= _vs1;
222-
offset += _vs1;
157+
var current = Vector256.LoadUnsafe(ref start, (uint)i);
158+
(current * scalarVec).StoreUnsafe(ref start, (uint)i);
223159
}
224-
if (count > 0)
160+
}
161+
162+
if (Vector128.IsHardwareAccelerated && i <= length - Vector128<float>.Count)
163+
{
164+
Vector128<float> scalarVec = Vector128.Create(scalar);
165+
for (; i <= length - Vector128<float>.Count; i += Vector128<float>.Count)
225166
{
226-
while (count > 0)
227-
{
228-
result += lhs[offset] * rhs[offset];
229-
offset++; count--;
230-
}
167+
var current = Vector128.LoadUnsafe(ref start, (uint)i);
168+
(current * scalarVec).StoreUnsafe(ref start, (uint)i);
231169
}
232-
return result;
170+
}
171+
172+
for (; i < length; i++)
173+
{
174+
Unsafe.Add(ref start, i) *= scalar;
233175
}
234176
}
235-
}
177+
178+
[MethodImpl(MethodImplOptions.NoInlining)]
179+
private static void ThrowLengthMismatch() => throw new ArgumentException("Vectors must have the same length.");
180+
}

0 commit comments

Comments
 (0)