Skip to content

Commit d13301c

Browse files
authored
[spark] support distributed execution of vector search on spark (#8108)
Purpose: Currently, vector search operation is executed on a single node within the driver, which may lead to performance bottlenecks when dealing with large amounts of data. This issue aims to implement a distributed execution capability.
1 parent 3b639af commit d13301c

11 files changed

Lines changed: 375 additions & 24 deletions

File tree

docs/generated/core_configuration.html

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1650,6 +1650,12 @@
16501650
<td>String</td>
16511651
<td>Specifies column names that should be stored as vector type. This is used when you want to treat a ARRAY column as a VECTOR.</td>
16521652
</tr>
1653+
<tr>
1654+
<td><h5>vector-search.distribute.enabled</h5></td>
1655+
<td style="word-wrap: break-word;">false</td>
1656+
<td>Boolean</td>
1657+
<td>Whether to process distributed vector search.</td>
1658+
</tr>
16531659
<tr>
16541660
<td><h5>vector.file.format</h5></td>
16551661
<td style="word-wrap: break-word;">(none)</td>

paimon-api/src/main/java/org/apache/paimon/CoreOptions.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2597,6 +2597,12 @@ public InlineElement getDescription() {
25972597
+ " Default is the same as TARGET_FILE_SIZE.")
25982598
.build());
25992599

2600+
public static final ConfigOption<Boolean> VECTOR_SEARCH_DISTRIBUTE_ENABLED =
2601+
key("vector-search.distribute.enabled")
2602+
.booleanType()
2603+
.defaultValue(false)
2604+
.withDescription("Whether to process distributed vector search.");
2605+
26002606
@Immutable
26012607
public static final ConfigOption<Boolean> PK_CLUSTERING_OVERRIDE =
26022608
key("pk-clustering-override")
@@ -4077,6 +4083,10 @@ public long vectorTargetFileSize() {
40774083
.orElse(targetFileSize(false));
40784084
}
40794085

4086+
public boolean vectorSearchDistributeEnabled() {
4087+
return options.get(VECTOR_SEARCH_DISTRIBUTE_ENABLED);
4088+
}
4089+
40804090
/** Specifies the merge engine for table with primary key. */
40814091
public enum MergeEngine implements DescribedEnum {
40824092
DEDUPLICATE("deduplicate", "De-duplicate and keep the last row."),

paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexResultSerializer.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.apache.paimon.io.DataInputView;
2424
import org.apache.paimon.io.DataOutputSerializer;
2525
import org.apache.paimon.io.DataOutputView;
26+
import org.apache.paimon.utils.Preconditions;
2627
import org.apache.paimon.utils.RoaringNavigableMap64;
2728

2829
import java.io.IOException;
@@ -116,4 +117,20 @@ public GlobalIndexResult deserialize(DataInputView dataInput) throws IOException
116117

117118
return ScoredGlobalIndexResult.create(roaringNavigableMap64, scoreMap::get);
118119
}
120+
121+
public byte[] serialize(GlobalIndexResult globalIndexResult) throws IOException {
122+
DataOutputSerializer dataOutputSerializer = new DataOutputSerializer(1024);
123+
serialize(globalIndexResult, dataOutputSerializer);
124+
return dataOutputSerializer.getCopyOfBuffer();
125+
}
126+
127+
public ScoredGlobalIndexResult deserialize(byte[] data) throws IOException {
128+
DataInputDeserializer dataInputDeserializer = new DataInputDeserializer(data);
129+
GlobalIndexResult globalIndexResult = deserialize(dataInputDeserializer);
130+
Preconditions.checkArgument(
131+
globalIndexResult instanceof ScoredGlobalIndexResult,
132+
"Expected ScoredGlobalIndexResult, but got %s",
133+
globalIndexResult == null ? "null" : globalIndexResult.getClass().getName());
134+
return (ScoredGlobalIndexResult) globalIndexResult;
135+
}
119136
}

paimon-common/src/main/java/org/apache/paimon/utils/RoaringNavigableMap64.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,15 @@
2525
import java.io.DataInputStream;
2626
import java.io.DataOutputStream;
2727
import java.io.IOException;
28+
import java.io.Serializable;
2829
import java.util.Iterator;
2930
import java.util.List;
3031
import java.util.Objects;
3132

3233
/** A compressed bitmap for 64-bit integer aggregated by tree. */
33-
public class RoaringNavigableMap64 implements Iterable<Long> {
34+
public class RoaringNavigableMap64 implements Iterable<Long>, Serializable {
35+
36+
private static final long serialVersionUID = 1L;
3437

3538
private final Roaring64NavigableMap roaring64NavigableMap;
3639

paimon-core/src/main/java/org/apache/paimon/table/source/VectorReadImpl.java

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import javax.annotation.Nullable;
4343

4444
import java.io.IOException;
45+
import java.io.Serializable;
4546
import java.util.ArrayList;
4647
import java.util.Comparator;
4748
import java.util.List;
@@ -55,13 +56,15 @@
5556
import static org.apache.paimon.utils.Preconditions.checkNotNull;
5657

5758
/** Implementation for {@link VectorRead}. */
58-
public class VectorReadImpl implements VectorRead {
59+
public class VectorReadImpl implements VectorRead, Serializable {
5960

60-
private final FileStoreTable table;
61+
private static final long serialVersionUID = 1L;
62+
63+
protected final FileStoreTable table;
6164
private final Predicate filter;
62-
private final int limit;
63-
private final DataField vectorColumn;
64-
private final float[] vector;
65+
protected final int limit;
66+
protected final DataField vectorColumn;
67+
protected final float[] vector;
6568

6669
public VectorReadImpl(
6770
FileStoreTable table,
@@ -120,7 +123,7 @@ public GlobalIndexResult read(List<VectorSearchSplit> splits) {
120123
return result.topK(limit);
121124
}
122125

123-
private Optional<RoaringNavigableMap64> preFilter(List<VectorSearchSplit> splits) {
126+
protected Optional<RoaringNavigableMap64> preFilter(List<VectorSearchSplit> splits) {
124127
Set<IndexFileMeta> scalarIndexFiles =
125128
new TreeSet<>(Comparator.comparing(IndexFileMeta::fileName));
126129
for (VectorSearchSplit split : splits) {
@@ -139,7 +142,7 @@ private Optional<RoaringNavigableMap64> preFilter(List<VectorSearchSplit> splits
139142
}
140143
}
141144

142-
private CompletableFuture<Optional<ScoredGlobalIndexResult>> eval(
145+
protected CompletableFuture<Optional<ScoredGlobalIndexResult>> eval(
143146
GlobalIndexer globalIndexer,
144147
IndexPathFactory indexPathFactory,
145148
long rowRangeStart,

paimon-core/src/main/java/org/apache/paimon/table/source/VectorSearchBuilderImpl.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ public class VectorSearchBuilderImpl implements VectorSearchBuilder {
3232

3333
private static final long serialVersionUID = 1L;
3434

35-
private final FileStoreTable table;
35+
protected final FileStoreTable table;
3636

37-
private PartitionPredicate partitionFilter;
38-
private Predicate filter;
39-
private int limit;
40-
private DataField vectorColumn;
41-
private float[] vector;
37+
protected PartitionPredicate partitionFilter;
38+
protected Predicate filter;
39+
protected int limit;
40+
protected DataField vectorColumn;
41+
protected float[] vector;
4242

4343
public VectorSearchBuilderImpl(InnerTable table) {
4444
this.table = (FileStoreTable) table;
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.paimon.spark.read;
20+
21+
import org.apache.paimon.utils.SerializableFunction;
22+
23+
import org.apache.spark.api.java.JavaSparkContext;
24+
import org.apache.spark.broadcast.Broadcast;
25+
import org.apache.spark.sql.SparkSession;
26+
27+
import java.util.Collections;
28+
import java.util.List;
29+
import java.util.stream.Stream;
30+
31+
/**
32+
* Tiny wrapper around the active {@link SparkSession} that exposes RDD style {@code map} / {@code
33+
* flatMap} primitives over a Java {@link List}. Used by Paimon-on-Spark to dispatch
34+
* embarrassingly-parallel work (e.g. per-split vector search) to the cluster without forcing the
35+
* caller to depend on Spark types directly.
36+
*/
37+
public class SparkEngineContext {
38+
39+
private final JavaSparkContext jsc;
40+
41+
public SparkEngineContext() {
42+
this.jsc = JavaSparkContext.fromSparkContext(SparkSession.active().sparkContext());
43+
}
44+
45+
public <T> Broadcast<T> broadcast(T value) {
46+
return jsc.broadcast(value);
47+
}
48+
49+
public <I, O> List<O> map(List<I> data, SerializableFunction<I, O> func, int parallelism) {
50+
if (data.isEmpty()) {
51+
return Collections.emptyList();
52+
}
53+
return jsc.parallelize(data, parallelism).map(func::apply).collect();
54+
}
55+
56+
public <I, O> List<O> flatMap(
57+
List<I> data, SerializableFunction<I, Stream<O>> func, int parallelism) {
58+
if (data.isEmpty()) {
59+
return Collections.emptyList();
60+
}
61+
return jsc.parallelize(data, parallelism).flatMap(x -> func.apply(x).iterator()).collect();
62+
}
63+
}
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.paimon.spark.read;
20+
21+
import org.apache.paimon.globalindex.GlobalIndexReadThreadPool;
22+
import org.apache.paimon.globalindex.GlobalIndexResult;
23+
import org.apache.paimon.globalindex.GlobalIndexResultSerializer;
24+
import org.apache.paimon.globalindex.GlobalIndexer;
25+
import org.apache.paimon.globalindex.GlobalIndexerFactoryUtils;
26+
import org.apache.paimon.globalindex.ScoredGlobalIndexResult;
27+
import org.apache.paimon.index.IndexPathFactory;
28+
import org.apache.paimon.predicate.Predicate;
29+
import org.apache.paimon.table.FileStoreTable;
30+
import org.apache.paimon.table.source.VectorReadImpl;
31+
import org.apache.paimon.table.source.VectorSearchSplit;
32+
import org.apache.paimon.types.DataField;
33+
import org.apache.paimon.utils.InstantiationUtil;
34+
import org.apache.paimon.utils.RoaringNavigableMap64;
35+
import org.apache.paimon.utils.SerializableFunction;
36+
37+
import org.apache.spark.broadcast.Broadcast;
38+
39+
import java.io.IOException;
40+
import java.util.ArrayList;
41+
import java.util.List;
42+
import java.util.Optional;
43+
import java.util.concurrent.CompletableFuture;
44+
import java.util.concurrent.ExecutorService;
45+
46+
import static org.apache.paimon.CoreOptions.GLOBAL_INDEX_THREAD_NUM;
47+
48+
/**
49+
* Spark-aware {@link VectorReadImpl} that distributes grouped vector index evaluation across the
50+
* Spark cluster instead of evaluating them with the local thread pool.
51+
*/
52+
public class SparkVectorReadImpl extends VectorReadImpl {
53+
54+
private static final long serialVersionUID = 1L;
55+
56+
public SparkVectorReadImpl(
57+
FileStoreTable table,
58+
Predicate filter,
59+
int limit,
60+
DataField vectorColumn,
61+
float[] vector) {
62+
super(table, filter, limit, vectorColumn, vector);
63+
}
64+
65+
@Override
66+
public GlobalIndexResult read(List<VectorSearchSplit> splits) {
67+
if (splits.isEmpty()) {
68+
return GlobalIndexResult.createEmpty();
69+
}
70+
71+
int parallelism =
72+
Math.max(1, table.coreOptions().toConfiguration().get(GLOBAL_INDEX_THREAD_NUM));
73+
if (splits.size() < parallelism * 2) {
74+
return super.read(splits);
75+
}
76+
77+
RoaringNavigableMap64 preFilter = preFilter(splits).orElse(null);
78+
String indexType = splits.get(0).vectorIndexFiles().get(0).indexType();
79+
List<byte[]> splitBytes = new ArrayList<>(splits.size());
80+
for (VectorSearchSplit split : splits) {
81+
try {
82+
splitBytes.add(InstantiationUtil.serializeObject(split));
83+
} catch (IOException e) {
84+
throw new RuntimeException("Failed to serialize VectorSearchSplit", e);
85+
}
86+
}
87+
List<List<byte[]>> splitGroups = splitGroups(splitBytes, parallelism);
88+
SparkEngineContext engineContext = new SparkEngineContext();
89+
Broadcast<RoaringNavigableMap64> preFilterBroadcast =
90+
preFilter == null ? null : engineContext.broadcast(preFilter);
91+
92+
SerializableFunction<List<byte[]>, byte[]> task =
93+
group -> {
94+
GlobalIndexer globalIndexer =
95+
GlobalIndexerFactoryUtils.load(indexType)
96+
.create(vectorColumn, table.coreOptions().toConfiguration());
97+
IndexPathFactory indexPathFactory =
98+
table.store().pathFactory().globalIndexFileFactory();
99+
100+
RoaringNavigableMap64 includeRowIds =
101+
preFilterBroadcast == null ? null : preFilterBroadcast.value();
102+
ExecutorService executor =
103+
GlobalIndexReadThreadPool.getExecutorService(
104+
Math.min(parallelism, group.size()));
105+
List<CompletableFuture<Optional<ScoredGlobalIndexResult>>> futures =
106+
new ArrayList<>(group.size());
107+
for (byte[] bytes : group) {
108+
VectorSearchSplit split = deserializeSplit(bytes);
109+
futures.add(
110+
eval(
111+
globalIndexer,
112+
indexPathFactory,
113+
split.rowRangeStart(),
114+
split.rowRangeEnd(),
115+
split.vectorIndexFiles(),
116+
includeRowIds,
117+
executor));
118+
}
119+
CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
120+
ScoredGlobalIndexResult result = ScoredGlobalIndexResult.createEmpty();
121+
for (CompletableFuture<Optional<ScoredGlobalIndexResult>> f : futures) {
122+
Optional<ScoredGlobalIndexResult> next = f.join();
123+
if (next.isPresent()) {
124+
result = result.or(next.get());
125+
}
126+
}
127+
result = result.topK(limit);
128+
if (result.results().isEmpty()) {
129+
return null;
130+
}
131+
try {
132+
return new GlobalIndexResultSerializer().serialize(result);
133+
} catch (IOException e) {
134+
throw new RuntimeException(
135+
"Failed to serialize ScoredGlobalIndexResult", e);
136+
}
137+
};
138+
139+
List<byte[]> remoteResults;
140+
try {
141+
remoteResults = engineContext.map(splitGroups, task, splitGroups.size());
142+
} finally {
143+
if (preFilterBroadcast != null) {
144+
preFilterBroadcast.unpersist(false);
145+
}
146+
}
147+
148+
ScoredGlobalIndexResult result = ScoredGlobalIndexResult.createEmpty();
149+
GlobalIndexResultSerializer serializer = new GlobalIndexResultSerializer();
150+
for (byte[] bytes : remoteResults) {
151+
if (bytes != null) {
152+
try {
153+
result = result.or(serializer.deserialize(bytes));
154+
} catch (IOException e) {
155+
throw new RuntimeException("Failed to deserialize ScoredGlobalIndexResult", e);
156+
}
157+
}
158+
}
159+
return result.topK(limit);
160+
}
161+
162+
private VectorSearchSplit deserializeSplit(byte[] bytes) {
163+
try {
164+
return InstantiationUtil.deserializeObject(
165+
bytes, Thread.currentThread().getContextClassLoader());
166+
} catch (IOException | ClassNotFoundException e) {
167+
throw new RuntimeException("Failed to deserialize VectorSearchSplit", e);
168+
}
169+
}
170+
171+
private List<List<byte[]>> splitGroups(List<byte[]> splitBytes, int parallelism) {
172+
List<List<byte[]>> groups = new ArrayList<>(parallelism);
173+
int groupSize = (splitBytes.size() + parallelism - 1) / parallelism;
174+
for (int start = 0; start < splitBytes.size(); start += groupSize) {
175+
groups.add(
176+
new ArrayList<>(
177+
splitBytes.subList(
178+
start, Math.min(start + groupSize, splitBytes.size()))));
179+
}
180+
return groups;
181+
}
182+
}

0 commit comments

Comments
 (0)