2121import java .lang .invoke .MethodType ;
2222import java .util .concurrent .ThreadLocalRandom ;
2323import java .util .concurrent .TimeUnit ;
24+ import org .apache .lucene .internal .vectorization .VectorUtilSupport ;
25+ import org .apache .lucene .internal .vectorization .VectorizationProvider ;
2426import org .apache .lucene .util .VectorUtil ;
2527import org .openjdk .jmh .annotations .*;
2628
3638 value = 3 ,
3739 jvmArgsAppend = {"-Xmx2g" , "-Xms2g" , "-XX:+AlwaysPreTouch" })
3840public class VectorUtilBenchmark {
41+
42+ /**
43+ * Used to get a MethodHandle of PanamaVectorUtilSupport.dotProduct(MemorySegment a, MemorySegment
44+ * b). The method above will use a native C implementation of dotProduct if it is enabled via
45+ * {@link org.apache.lucene.util.Constants#NATIVE_DOT_PRODUCT_ENABLED} AND both MemorySegment
46+ * arguments are backed by off-heap memory. A reflection based approach is necessary to avoid
47+ * taking a direct dependency on preview APIs in Panama which may be blocked at compile time.
48+ *
49+ * @return MethodHandle PanamaVectorUtilSupport.DotProduct(MemorySegment a, MemorySegment b)
50+ */
51+ private static MethodHandle nativeDotProductHandle (String methodName ) {
52+ if (Runtime .version ().feature () < 21 ) {
53+ return null ;
54+ }
55+ try {
56+ final VectorUtilSupport vectorUtilSupport =
57+ VectorizationProvider .getInstance ().getVectorUtilSupport ();
58+ if (vectorUtilSupport .getClass ().getName ().endsWith ("PanamaVectorUtilSupport" )) {
59+ MethodHandles .Lookup lookup = MethodHandles .lookup ();
60+ // A method type that computes dot-product between two off-heap vectors
61+ // provided as native MemorySegment and returns an int score.
62+ final var MemorySegment = "java.lang.foreign.MemorySegment" ;
63+ final var methodType =
64+ MethodType .methodType (
65+ int .class , lookup .findClass (MemorySegment ), lookup .findClass (MemorySegment ));
66+ var mh = lookup .findStatic (vectorUtilSupport .getClass (), methodName , methodType );
67+ // Erase the type of receiver to Object so that mh.invokeExact(a, b) does not throw
68+ // WrongMethodException.
69+ // Here 'a' and 'b' are off-heap vectors of type MemorySegment constructed via reflection
70+ // API.
71+ // This minimizes the reflection overhead and brings us very close to the performance of
72+ // direct method invocation.
73+ mh = mh .asType (mh .type ().changeParameterType (0 , Object .class ));
74+ mh = mh .asType (mh .type ().changeParameterType (1 , Object .class ));
75+ return mh ;
76+ }
77+ } catch (Throwable e ) {
78+ throw new RuntimeException (e );
79+ }
80+ return null ;
81+ }
82+
83+ /**
84+ * Get randomly initialized byte-vectors of given size in off-heap MemorySegment
85+ *
86+ * @param size dimension of byte-vector
87+ * @return Object MemorySegment
88+ */
89+ private static Object getOffHeapByteVector (int size ) {
90+ try {
91+ VectorizationProvider vectorizationProvider = VectorizationProvider .getInstance ();
92+ if (vectorizationProvider .getClass ().getName ().endsWith ("PanamaVectorizationProvider" )) {
93+ MethodHandles .Lookup lookup = MethodHandles .lookup ();
94+ // A method type that accepts numBytes and returns an off-heap vector of size 'numBytes'
95+ // where each byte is randomly initialized
96+ final var methodType =
97+ MethodType .methodType (lookup .findClass ("java.lang.foreign.MemorySegment" ), int .class );
98+ // The class is expected to be "PanamaVectorUtilSupport" with a static method
99+ // "MemorySegment offHeapByteVector(int numBytes)" that returns the off-heap vector as a
100+ // MemorySegment
101+ Class <?> vectorUtilSupportClass = vectorizationProvider .getVectorUtilSupport ().getClass ();
102+ final MethodHandle offHeapByteVector =
103+ lookup .findStatic (vectorUtilSupportClass , "offHeapByteVector" , methodType );
104+ return offHeapByteVector .invoke (size );
105+ }
106+ } catch (Throwable e ) {
107+ throw new RuntimeException (e );
108+ }
109+ return null ;
110+ }
111+
// Handle bound to the static method named "dotProduct" on the active vector support class.
// May be null: nativeDotProductHandle returns null on Java feature versions below 21 or when
// the active support class is not PanamaVectorUtilSupport.
112+ private static final MethodHandle NATIVE_DOT_PRODUCT = nativeDotProductHandle ("dotProduct" );
// Handle bound to the static method named "simpleNativeDotProduct"; null under the same
// conditions as NATIVE_DOT_PRODUCT.
113+ private static final MethodHandle SIMPLE_NATIVE_DOT_PRODUCT =
114+ nativeDotProductHandle ("simpleNativeDotProduct" );
115+
39116 static void compressBytes (byte [] raw , byte [] compressed ) {
40117 for (int i = 0 ; i < compressed .length ; ++i ) {
41118 int v = (raw [i ] << 4 ) | raw [compressed .length + i ];
@@ -52,8 +129,8 @@ static void compressBytes(byte[] raw, byte[] compressed) {
52129 private float [] floatsB ;
53130 private int expectedhalfByteDotProduct ;
54131
55- private Object nativeBytesA ;
56- private Object nativeBytesB ;
132+ private Object offHeapBytesA ;
133+ private Object offHeapBytesB ;
57134
58135 /** private Object nativeBytesA; private Object nativeBytesB; */
59136 @ Param ({"1" , "128" , "207" , "256" , "300" , "512" , "702" , "1024" })
@@ -94,70 +171,26 @@ public void init() {
94171 // Java 21+ specific initialization
95172 final int runtimeVersion = Runtime .version ().feature ();
96173 if (runtimeVersion >= 21 ) {
97- // Reflection based code to eliminate the use of Preview classes in JMH benchmarks
98- try {
99- final Class <?> vectorUtilSupportClass = VectorUtil .getVectorUtilSupportClass ();
100- final var className = "org.apache.lucene.internal.vectorization.PanamaVectorUtilSupport" ;
101- if (vectorUtilSupportClass .getName ().equals (className ) == false ) {
102- nativeBytesA = null ;
103- nativeBytesB = null ;
104- } else {
105- MethodHandles .Lookup lookup = MethodHandles .lookup ();
106- final var MemorySegment = "java.lang.foreign.MemorySegment" ;
107- final var methodType =
108- MethodType .methodType (lookup .findClass (MemorySegment ), byte [].class );
109- MethodHandle nativeMemorySegment =
110- lookup .findStatic (vectorUtilSupportClass , "nativeMemorySegment" , methodType );
111- byte [] a = new byte [size ];
112- byte [] b = new byte [size ];
113- for (int i = 0 ; i < size ; ++i ) {
114- a [i ] = (byte ) random .nextInt (128 );
115- b [i ] = (byte ) random .nextInt (128 );
116- }
117- nativeBytesA = nativeMemorySegment .invoke (a );
118- nativeBytesB = nativeMemorySegment .invoke (b );
119- }
120- } catch (Throwable e ) {
121- throw new RuntimeException (e );
122- }
123- /*
124- Arena offHeap = Arena.ofAuto();
125- nativeBytesA = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment());
126- nativeBytesB = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment());
127- for (int i = 0; i < size; ++i) {
128- nativeBytesA.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128));
129- nativeBytesB.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128));
130- }*/
174+ offHeapBytesA = getOffHeapByteVector (size );
175+ offHeapBytesB = getOffHeapByteVector (size );
176+ }
177+ }
178+
179+ @ Benchmark
180+ @ Fork (jvmArgsPrepend = {"--add-modules=jdk.incubator.vector" })
181+ public int dot8s () {
182+ try {
183+ return (int ) NATIVE_DOT_PRODUCT .invokeExact (offHeapBytesA , offHeapBytesB );
184+ } catch (Throwable e ) {
185+ throw new RuntimeException (e );
131186 }
132187 }
133188
134- /**
135- * High overhead (lower score) from using NATIVE_DOT_PRODUCT.invoke(nativeBytesA, nativeBytesB).
136- * Both nativeBytesA and nativeBytesB are offHeap MemorySegments created by invoking the method
137- * PanamaVectorUtilSupport.nativeMemorySegment(byte[]) which allocated these segments and copies
138- * bytes from the supplied byte[] to offHeap memory. The benchmark output below shows
139- * significantly more overhead. <b>NOTE:</b> Return type of dots8s() was set to void for the
140- * benchmark run to avoid boxing/unboxing overhead.
141- *
142- * <pre>
143- * Benchmark (size) Mode Cnt Score Error Units
144- * VectorUtilBenchmark.dot8s 768 thrpt 15 36.406 ± 0.496 ops/us
145- * </pre>
146- *
147- * Much lower overhead was observed when preview APIs were used directly in JMH benchmarking code
148- * and exact method invocation was made as shown below <b>return (int)
149- * VectorUtil.NATIVE_DOT_PRODUCT.invokeExact(nativeBytesA, nativeBytesB);</b>
150- *
151- * <pre>
152- * Benchmark (size) Mode Cnt Score Error Units
153- * VectorUtilBenchmark.dot8s 768 thrpt 15 43.662 ± 0.818 ops/us
154- * </pre>
155- */
156189 @ Benchmark
157190 @ Fork (jvmArgsPrepend = {"--add-modules=jdk.incubator.vector" })
158- public void dot8s () {
191+ public int simpleDot8s () {
159192 try {
160- VectorUtil . NATIVE_DOT_PRODUCT . invoke ( nativeBytesA , nativeBytesB );
193+ return ( int ) SIMPLE_NATIVE_DOT_PRODUCT . invokeExact ( offHeapBytesA , offHeapBytesB );
161194 } catch (Throwable e ) {
162195 throw new RuntimeException (e );
163196 }
0 commit comments