using System;
using System.Linq;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;

namespace ManagedCode.Umap;

/// <summary>
/// Vectorized helpers for reductions and in-place scalar arithmetic over spans of <see cref="float"/>.
/// </summary>
/// <remarks>
/// Every routine walks the data in up to three passes: 256-bit vectors when
/// <see cref="Vector256.IsHardwareAccelerated"/> is true, then 128-bit vectors when
/// <see cref="Vector128.IsHardwareAccelerated"/> is true, and finally a scalar loop for whatever
/// elements remain.
/// </remarks>
internal static class Simd
{
    /// <summary>
    /// Computes the Euclidean norm of <paramref name="values"/>: the square root of its dot product with itself.
    /// </summary>
    /// <param name="values">The vector whose magnitude is computed.</param>
    /// <returns>The square root of the sum of squared elements.</returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static float Magnitude(ReadOnlySpan<float> values) => MathF.Sqrt(DotProduct(values, values));

    /// <summary>
    /// Computes the sum of squared element-wise differences between <paramref name="lhs"/> and
    /// <paramref name="rhs"/>.
    /// </summary>
    /// <remarks>
    /// No square root is applied, so the returned value is the squared Euclidean distance.
    /// </remarks>
    /// <param name="lhs">The first vector.</param>
    /// <param name="rhs">The second vector; must have the same length as <paramref name="lhs"/>.</param>
    /// <returns>The accumulated sum of <c>(lhs[i] - rhs[i])^2</c> over all elements.</returns>
    /// <exception cref="ArgumentException">The spans have different lengths.</exception>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static float Euclidean(ReadOnlySpan<float> lhs, ReadOnlySpan<float> rhs)
    {
        if (lhs.Length != rhs.Length)
        {
            ThrowLengthMismatch();
        }

        ref float left = ref MemoryMarshal.GetReference(lhs);
        ref float right = ref MemoryMarshal.GetReference(rhs);

        int length = lhs.Length;
        int i = 0;
        float sum = 0f;

        if (Vector256.IsHardwareAccelerated && length >= Vector256<float>.Count)
        {
            Vector256<float> acc = Vector256<float>.Zero;
            for (; i <= length - Vector256<float>.Count; i += Vector256<float>.Count)
            {
                // i is never negative here, so the cast to an unsigned element offset is lossless.
                var diff = Vector256.LoadUnsafe(ref left, (uint)i) - Vector256.LoadUnsafe(ref right, (uint)i);
                acc += diff * diff;
            }

            sum += Vector256.Sum(acc);
        }

        // Picks up a remaining 128-bit chunk after the 256-bit pass, or does all the vector work
        // when only 128-bit acceleration is available.
        if (Vector128.IsHardwareAccelerated && i <= length - Vector128<float>.Count)
        {
            Vector128<float> acc = Vector128<float>.Zero;
            for (; i <= length - Vector128<float>.Count; i += Vector128<float>.Count)
            {
                var diff = Vector128.LoadUnsafe(ref left, (uint)i) - Vector128.LoadUnsafe(ref right, (uint)i);
                acc += diff * diff;
            }

            sum += Vector128.Sum(acc);
        }

        // Scalar tail for elements that do not fill a whole vector.
        for (; i < length; i++)
        {
            float diff = Unsafe.Add(ref left, i) - Unsafe.Add(ref right, i);
            sum += diff * diff;
        }

        return sum;
    }

    /// <summary>
    /// Computes the dot product of <paramref name="lhs"/> and <paramref name="rhs"/>.
    /// </summary>
    /// <param name="lhs">The first vector.</param>
    /// <param name="rhs">The second vector; must have the same length as <paramref name="lhs"/>.</param>
    /// <returns>The accumulated sum of <c>lhs[i] * rhs[i]</c> over all elements.</returns>
    /// <exception cref="ArgumentException">The spans have different lengths.</exception>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static float DotProduct(ReadOnlySpan<float> lhs, ReadOnlySpan<float> rhs)
    {
        if (lhs.Length != rhs.Length)
        {
            ThrowLengthMismatch();
        }

        ref float left = ref MemoryMarshal.GetReference(lhs);
        ref float right = ref MemoryMarshal.GetReference(rhs);

        int length = lhs.Length;
        int i = 0;
        float sum = 0f;

        if (Vector256.IsHardwareAccelerated && length >= Vector256<float>.Count)
        {
            Vector256<float> acc = Vector256<float>.Zero;
            for (; i <= length - Vector256<float>.Count; i += Vector256<float>.Count)
            {
                acc += Vector256.LoadUnsafe(ref left, (uint)i) * Vector256.LoadUnsafe(ref right, (uint)i);
            }

            sum += Vector256.Sum(acc);
        }

        if (Vector128.IsHardwareAccelerated && i <= length - Vector128<float>.Count)
        {
            Vector128<float> acc = Vector128<float>.Zero;
            for (; i <= length - Vector128<float>.Count; i += Vector128<float>.Count)
            {
                acc += Vector128.LoadUnsafe(ref left, (uint)i) * Vector128.LoadUnsafe(ref right, (uint)i);
            }

            sum += Vector128.Sum(acc);
        }

        // Scalar tail for elements that do not fill a whole vector.
        for (; i < length; i++)
        {
            sum += Unsafe.Add(ref left, i) * Unsafe.Add(ref right, i);
        }

        return sum;
    }

    /// <summary>
    /// Adds <paramref name="scalar"/> to every element of <paramref name="values"/> in place.
    /// </summary>
    /// <param name="values">The elements to update; an empty span is a no-op.</param>
    /// <param name="scalar">The value added to each element.</param>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static void Add(Span<float> values, float scalar)
    {
        if (values.Length == 0)
        {
            return;
        }

        ref float start = ref MemoryMarshal.GetReference(values);
        int length = values.Length;
        int i = 0;

        if (Vector256.IsHardwareAccelerated && length >= Vector256<float>.Count)
        {
            Vector256<float> scalarVec = Vector256.Create(scalar);
            for (; i <= length - Vector256<float>.Count; i += Vector256<float>.Count)
            {
                var current = Vector256.LoadUnsafe(ref start, (uint)i);
                (current + scalarVec).StoreUnsafe(ref start, (uint)i);
            }
        }

        if (Vector128.IsHardwareAccelerated && i <= length - Vector128<float>.Count)
        {
            Vector128<float> scalarVec = Vector128.Create(scalar);
            for (; i <= length - Vector128<float>.Count; i += Vector128<float>.Count)
            {
                var current = Vector128.LoadUnsafe(ref start, (uint)i);
                (current + scalarVec).StoreUnsafe(ref start, (uint)i);
            }
        }

        // Scalar tail for elements that do not fill a whole vector.
        for (; i < length; i++)
        {
            Unsafe.Add(ref start, i) += scalar;
        }
    }

    /// <summary>
    /// Multiplies every element of <paramref name="values"/> by <paramref name="scalar"/> in place.
    /// </summary>
    /// <param name="values">The elements to update; an empty span is a no-op.</param>
    /// <param name="scalar">The factor applied to each element.</param>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static void Multiply(Span<float> values, float scalar)
    {
        if (values.Length == 0)
        {
            return;
        }

        ref float start = ref MemoryMarshal.GetReference(values);
        int length = values.Length;
        int i = 0;

        if (Vector256.IsHardwareAccelerated && length >= Vector256<float>.Count)
        {
            Vector256<float> scalarVec = Vector256.Create(scalar);
            for (; i <= length - Vector256<float>.Count; i += Vector256<float>.Count)
            {
                var current = Vector256.LoadUnsafe(ref start, (uint)i);
                (current * scalarVec).StoreUnsafe(ref start, (uint)i);
            }
        }

        if (Vector128.IsHardwareAccelerated && i <= length - Vector128<float>.Count)
        {
            Vector128<float> scalarVec = Vector128.Create(scalar);
            for (; i <= length - Vector128<float>.Count; i += Vector128<float>.Count)
            {
                var current = Vector128.LoadUnsafe(ref start, (uint)i);
                (current * scalarVec).StoreUnsafe(ref start, (uint)i);
            }
        }

        // Scalar tail for elements that do not fill a whole vector.
        for (; i < length; i++)
        {
            Unsafe.Add(ref start, i) *= scalar;
        }
    }

    /// <summary>
    /// Throws the <see cref="ArgumentException"/> used when two input spans differ in length.
    /// Marked non-inlined so the throw stays out of the aggressively inlined callers.
    /// </summary>
    [MethodImpl(MethodImplOptions.NoInlining)]
    private static void ThrowLengthMismatch() => throw new ArgumentException("Vectors must have the same length.");
}