Skip to content

Commit ec7c6d0

Browse files
author
Ankur Goel
committed
Allocate offHeap memory in dotProduct(byte[], byte[]) for unit tests if native dot-product is enabled. Simplifyy JMH benchmark code that tests native dot product. Incorporate other review feedback
1 parent 7349d4f commit ec7c6d0

25 files changed

Lines changed: 435 additions & 307 deletions

.github/workflows/run-checks-all.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ jobs:
3030
matrix:
3131
os: [ ubuntu-latest ]
3232
java: [ '21' ]
33+
compiler: [ gcc ]
3334

3435
runs-on: ${{ matrix.os }}
3536

@@ -38,6 +39,8 @@ jobs:
3839
- uses: ./.github/actions/prepare-for-build
3940

4041
- name: Run gradle check (without tests)
42+
env:
43+
CC: ${{ matrix.compiler }}
4144
run: ./gradlew check -x test -Ptask.times=true --max-workers 2
4245

4346

@@ -53,6 +56,7 @@ jobs:
5356
# macos-latest: a tad slower than ubuntu and pretty much the same (?) so leaving out.
5457
os: [ ubuntu-latest ]
5558
java: [ '21' ]
59+
compiler: [ gcc ]
5660

5761
runs-on: ${{ matrix.os }}
5862

@@ -61,6 +65,8 @@ jobs:
6165
- uses: ./.github/actions/prepare-for-build
6266

6367
- name: Run gradle tests
68+
env:
69+
CC: ${{ matrix.compiler }}
6470
run: ./gradlew test "-Ptask.times=true" --max-workers 2
6571

6672
- name: List automatically-initialized gradle.properties

gradle/testing/randomization.gradle

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ allprojects {
112112
[propName: 'tests.forceintegervectors',
113113
value: { -> testsDefaultVectorizationRequested() ? false : (randomVectorSize != 'default') },
114114
description: "Forces use of integer vectors even when slow."],
115+
// test native dot-product when running with Java 21 or greater and 'default' vector size (chosen by randomized testing)
116+
[propName: 'test.native.dotProduct',
117+
value: { -> testsDefaultVectorizationRequested() ? false : (randomVectorSize == 'default' && rootProject.vectorIncubatorJavaVersions.contains(rootProject.runtimeJavaVersion))}],
115118
[propName: 'tests.defaultvectorization', value: false,
116119
description: "Uses defaults for running tests with correct JVM settings to test Panama vectorization (tests.jvmargs, tests.vectorsize, tests.forceintegervectors)."],
117120
]

lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java

Lines changed: 94 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import java.lang.invoke.MethodType;
2222
import java.util.concurrent.ThreadLocalRandom;
2323
import java.util.concurrent.TimeUnit;
24+
import org.apache.lucene.internal.vectorization.VectorUtilSupport;
25+
import org.apache.lucene.internal.vectorization.VectorizationProvider;
2426
import org.apache.lucene.util.VectorUtil;
2527
import org.openjdk.jmh.annotations.*;
2628

@@ -36,6 +38,82 @@
3638
value = 3,
3739
jvmArgsAppend = {"-Xmx2g", "-Xms2g", "-XX:+AlwaysPreTouch"})
3840
public class VectorUtilBenchmark {
41+
42+
/**
43+
* Used to get a MethodHandle of PanamaVectorUtilSupport.dotProduct(MemorySegment a, MemorySegment
44+
* b). The method above will use a native C implementation of dotProduct if it is enabled via
45+
* {@link org.apache.lucene.util.Constants#NATIVE_DOT_PRODUCT_ENABLED} AND both MemorySegment
46+
* arguments are backed by off-heap memory. A reflection based approach is necessary to avoid
47+
* taking a direct dependency on preview APIs in Panama which may be blocked at compile time.
48+
*
49+
* @return MethodHandle PanamaVectorUtilSupport.DotProduct(MemorySegment a, MemorySegment b)
50+
*/
51+
private static MethodHandle nativeDotProductHandle(String methodName) {
52+
if (Runtime.version().feature() < 21) {
53+
return null;
54+
}
55+
try {
56+
final VectorUtilSupport vectorUtilSupport =
57+
VectorizationProvider.getInstance().getVectorUtilSupport();
58+
if (vectorUtilSupport.getClass().getName().endsWith("PanamaVectorUtilSupport")) {
59+
MethodHandles.Lookup lookup = MethodHandles.lookup();
60+
// A method type that computes dot-product between two off-heap vectors
61+
// provided as native MemorySegment and returns an int score.
62+
final var MemorySegment = "java.lang.foreign.MemorySegment";
63+
final var methodType =
64+
MethodType.methodType(
65+
int.class, lookup.findClass(MemorySegment), lookup.findClass(MemorySegment));
66+
var mh = lookup.findStatic(vectorUtilSupport.getClass(), methodName, methodType);
67+
// Erase the type of receiver to Object so that mh.invokeExact(a, b) does not throw
68+
// WrongMethodException.
69+
// Here 'a' and 'b' are off-heap vectors of type MemorySegment constructed via reflection
70+
// API.
71+
// This minimizes the reflection overhead and brings us very close to the performance of
72+
// direct method invocation.
73+
mh = mh.asType(mh.type().changeParameterType(0, Object.class));
74+
mh = mh.asType(mh.type().changeParameterType(1, Object.class));
75+
return mh;
76+
}
77+
} catch (ClassNotFoundException | IllegalAccessException | NoSuchMethodException e) {
78+
throw new RuntimeException(e);
79+
}
80+
return null;
81+
}
82+
83+
/**
84+
* Copy input byte[] to off-heap MemorySegment
85+
*
86+
* @param byteVector to be copied off-heap
87+
* @return Object MemorySegment
88+
*/
89+
private static Object getOffHeapByteVector(byte[] byteVector) {
90+
try {
91+
VectorizationProvider vectorizationProvider = VectorizationProvider.getInstance();
92+
if (vectorizationProvider.getClass().getName().endsWith("PanamaVectorizationProvider")) {
93+
MethodHandles.Lookup lookup = MethodHandles.lookup();
94+
// A method type that copies input byte[] to an off-heap MemorySegment
95+
final var methodType =
96+
MethodType.methodType(
97+
lookup.findClass("java.lang.foreign.MemorySegment"), byte[].class);
98+
// The class is expected to be "PanamaVectorUtilSupport" with a static method
99+
// "MemorySegment offHeapByteVector(byte[] byteVector)" that returns the off-heap vector as
100+
// a
101+
// MemorySegment
102+
Class<?> vectorUtilSupportClass = vectorizationProvider.getVectorUtilSupport().getClass();
103+
final MethodHandle offHeapByteVector =
104+
lookup.findStatic(vectorUtilSupportClass, "offHeapByteVector", methodType);
105+
return offHeapByteVector.invoke(byteVector);
106+
}
107+
} catch (Throwable e) {
108+
throw new RuntimeException(e);
109+
}
110+
return null;
111+
}
112+
113+
private static final MethodHandle NATIVE_DOT_PRODUCT = nativeDotProductHandle("dotProduct");
114+
private static final MethodHandle SIMPLE_NATIVE_DOT_PRODUCT =
115+
nativeDotProductHandle("simpleNativeDotProduct");
116+
39117
static void compressBytes(byte[] raw, byte[] compressed) {
40118
for (int i = 0; i < compressed.length; ++i) {
41119
int v = (raw[i] << 4) | raw[compressed.length + i];
@@ -52,8 +130,8 @@ static void compressBytes(byte[] raw, byte[] compressed) {
52130
private float[] floatsB;
53131
private int expectedhalfByteDotProduct;
54132

55-
private Object nativeBytesA;
56-
private Object nativeBytesB;
133+
private Object offHeapBytesA;
134+
private Object offHeapBytesB;
57135

58136
/** private Object nativeBytesA; private Object nativeBytesB; */
59137
@Param({"1", "128", "207", "256", "300", "512", "702", "1024"})
@@ -94,70 +172,26 @@ public void init() {
94172
// Java 21+ specific initialization
95173
final int runtimeVersion = Runtime.version().feature();
96174
if (runtimeVersion >= 21) {
97-
// Reflection based code to eliminate the use of Preview classes in JMH benchmarks
98-
try {
99-
final Class<?> vectorUtilSupportClass = VectorUtil.getVectorUtilSupportClass();
100-
final var className = "org.apache.lucene.internal.vectorization.PanamaVectorUtilSupport";
101-
if (vectorUtilSupportClass.getName().equals(className) == false) {
102-
nativeBytesA = null;
103-
nativeBytesB = null;
104-
} else {
105-
MethodHandles.Lookup lookup = MethodHandles.lookup();
106-
final var MemorySegment = "java.lang.foreign.MemorySegment";
107-
final var methodType =
108-
MethodType.methodType(lookup.findClass(MemorySegment), byte[].class);
109-
MethodHandle nativeMemorySegment =
110-
lookup.findStatic(vectorUtilSupportClass, "nativeMemorySegment", methodType);
111-
byte[] a = new byte[size];
112-
byte[] b = new byte[size];
113-
for (int i = 0; i < size; ++i) {
114-
a[i] = (byte) random.nextInt(128);
115-
b[i] = (byte) random.nextInt(128);
116-
}
117-
nativeBytesA = nativeMemorySegment.invoke(a);
118-
nativeBytesB = nativeMemorySegment.invoke(b);
119-
}
120-
} catch (Throwable e) {
121-
throw new RuntimeException(e);
122-
}
123-
/*
124-
Arena offHeap = Arena.ofAuto();
125-
nativeBytesA = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment());
126-
nativeBytesB = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment());
127-
for (int i = 0; i < size; ++i) {
128-
nativeBytesA.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128));
129-
nativeBytesB.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128));
130-
}*/
175+
offHeapBytesA = getOffHeapByteVector(bytesA);
176+
offHeapBytesB = getOffHeapByteVector(bytesB);
177+
}
178+
}
179+
180+
@Benchmark
181+
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
182+
public int dot8s() {
183+
try {
184+
return (int) NATIVE_DOT_PRODUCT.invokeExact(offHeapBytesA, offHeapBytesB);
185+
} catch (Throwable e) {
186+
throw new RuntimeException(e);
131187
}
132188
}
133189

134-
/**
135-
* High overhead (lower score) from using NATIVE_DOT_PRODUCT.invoke(nativeBytesA, nativeBytesB).
136-
* Both nativeBytesA and nativeBytesB are offHeap MemorySegments created by invoking the method
137-
* PanamaVectorUtilSupport.nativeMemorySegment(byte[]) which allocated these segments and copies
138-
* bytes from the supplied byte[] to offHeap memory. The benchmark output below shows
139-
* significantly more overhead. <b>NOTE:</b> Return type of dots8s() was set to void for the
140-
* benchmark run to avoid boxing/unboxing overhead.
141-
*
142-
* <pre>
143-
* Benchmark (size) Mode Cnt Score Error Units
144-
* VectorUtilBenchmark.dot8s 768 thrpt 15 36.406 ± 0.496 ops/us
145-
* </pre>
146-
*
147-
* Much lower overhead was observed when preview APIs were used directly in JMH benchmarking code
148-
* and exact method invocation was made as shown below <b>return (int)
149-
* VectorUtil.NATIVE_DOT_PRODUCT.invokeExact(nativeBytesA, nativeBytesB);</b>
150-
*
151-
* <pre>
152-
* Benchmark (size) Mode Cnt Score Error Units
153-
* VectorUtilBenchmark.dot8s 768 thrpt 15 43.662 ± 0.818 ops/us
154-
* </pre>
155-
*/
156190
@Benchmark
157191
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
158-
public void dot8s() {
192+
public int simpleDot8s() {
159193
try {
160-
VectorUtil.NATIVE_DOT_PRODUCT.invoke(nativeBytesA, nativeBytesB);
194+
return (int) SIMPLE_NATIVE_DOT_PRODUCT.invokeExact(offHeapBytesA, offHeapBytesB);
161195
} catch (Throwable e) {
162196
throw new RuntimeException(e);
163197
}

lucene/core/build.gradle

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,9 @@ dependencies {
2525
}
2626

2727
test {
28-
build {
29-
dependsOn ':lucene:native:build'
30-
}
28+
dependsOn ':lucene:misc:dotProductSharedLibrary'
3129
systemProperty(
3230
"java.library.path",
33-
project(":lucene:native").layout.buildDirectory.get().asFile.absolutePath + "/libs/dotProduct/shared"
31+
project(":lucene:misc").layout.buildDirectory.get().asFile.absolutePath + "/libs/dotProduct/shared"
3432
)
3533
}

lucene/core/src/java/module-info.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@
6666

6767
exports org.apache.lucene.util.quantization;
6868
exports org.apache.lucene.codecs.hnsw;
69+
exports org.apache.lucene.internal.vectorization to
70+
org.apache.lucene.benchmark.jmh;
6971

7072
provides org.apache.lucene.analysis.TokenizerFactory with
7173
org.apache.lucene.analysis.standard.StandardTokenizerFactory;

lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
2727
import org.apache.lucene.index.SegmentReadState;
2828
import org.apache.lucene.index.SegmentWriteState;
29+
import org.apache.lucene.util.Constants;
2930

3031
/**
3132
* Format supporting vector quantization, storage, and retrieval
@@ -119,11 +120,16 @@ public Lucene99ScalarQuantizedVectorsFormat(
119120
this.bits = (byte) bits;
120121
this.confidenceInterval = confidenceInterval;
121122
this.compress = compress;
122-
FlatVectorsScorer scorer = FlatVectorScorerUtil.getLucene99FlatVectorsScorer();
123-
if (scorer == DefaultFlatVectorScorer.INSTANCE) {
124-
scorer = new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
123+
if (Constants.NATIVE_DOT_PRODUCT_ENABLED == false) {
124+
this.flatVectorScorer =
125+
new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
126+
} else {
127+
FlatVectorsScorer scorer = FlatVectorScorerUtil.getLucene99FlatVectorsScorer();
128+
if (scorer == DefaultFlatVectorScorer.INSTANCE) {
129+
scorer = new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
130+
}
131+
this.flatVectorScorer = scorer;
125132
}
126-
this.flatVectorScorer = scorer;
127133
}
128134

129135
public static float calculateDefaultConfidenceInterval(int vectorDimension) {

lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,6 @@ public float getScoreCorrectionConstant(int targetOrd) throws IOException {
146146
}
147147
slice.seek(((long) targetOrd * byteSize) + numBytes);
148148
slice.readFloats(scoreCorrectionConstant, 0, 1);
149-
lastOrd = targetOrd;
150149
return scoreCorrectionConstant[0];
151150
}
152151

lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ private static Optional<Module> lookupVectorModule() {
188188
// add all possible callers here as FQCN:
189189
private static final Set<String> VALID_CALLERS =
190190
Set.of(
191+
"org.apache.lucene.benchmark.jmh.VectorUtilBenchmark",
191192
"org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil",
192193
"org.apache.lucene.util.VectorUtil",
193194
"org.apache.lucene.codecs.lucene101.Lucene101PostingsReader",

lucene/core/src/java/org/apache/lucene/util/Constants.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,14 @@ private static boolean is64Bit() {
100100
/** true iff we know VFMA has faster throughput than separate vmul/vadd. */
101101
public static final boolean HAS_FAST_VECTOR_FMA = hasFastVectorFMA();
102102

103-
// TODO: <below condition> && Boolean.parseBoolean(getSysProp("lucene.useNativeDotProduct",
104-
// "False")
105-
public static final boolean NATIVE_DOT_PRODUCT_ENABLED = OS_ARCH.equalsIgnoreCase("aarch64");
103+
public static final boolean NATIVE_DOT_PRODUCT_ENABLED = enableNativeDotProduct();
104+
105+
private static boolean enableNativeDotProduct() {
106+
var armArchitecture = OS_ARCH.equalsIgnoreCase("aarch64");
107+
var enabledExplicitly = Boolean.parseBoolean(getSysProp("lucene.useNativeDotProduct", "false"));
108+
var enabledForTests = Boolean.parseBoolean(getSysProp("test.native.dotProduct", "false"));
109+
return (armArchitecture && enabledExplicitly) || enabledForTests;
110+
}
106111

107112
/** true iff we know FMA has faster throughput than separate mul/add. */
108113
public static final boolean HAS_FAST_SCALAR_FMA = hasFastScalarFMA();

lucene/core/src/java/org/apache/lucene/util/VectorUtil.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public final class VectorUtil {
5050

5151
private static final float EPSILON = 1e-4f;
5252

53-
private static final VectorUtilSupport IMPL =
53+
public static final VectorUtilSupport IMPL =
5454
VectorizationProvider.getInstance().getVectorUtilSupport();
5555

5656
private VectorUtil() {}

0 commit comments

Comments
 (0)