Skip to content

Commit 736b411

Browse files
wenjin272claude
andcommitted
[FLINK-38721] Support vector search for es8 connector.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 18e1cf7 commit 736b411

16 files changed

Lines changed: 1095 additions & 195 deletions

File tree

flink-connector-elasticsearch-base/src/main/java/org/apache/flink/connector/elasticsearch/table/ElasticsearchDynamicSource.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
public class ElasticsearchDynamicSource implements LookupTableSource, SupportsProjectionPushDown {
2727
protected final DecodingFormat<DeserializationSchema<RowData>> format;
2828
protected final ElasticsearchConfiguration config;
29-
private final int lookupMaxRetryTimes;
29+
protected final int maxRetryTimes;
3030
private final LookupCache lookupCache;
3131
private final String docType;
3232
private final String summaryString;
@@ -37,15 +37,15 @@ public ElasticsearchDynamicSource(
3737
DecodingFormat<DeserializationSchema<RowData>> format,
3838
ElasticsearchConfiguration config,
3939
DataType physicalRowDataType,
40-
int lookupMaxRetryTimes,
40+
int maxRetryTimes,
4141
String summaryString,
4242
ElasticsearchApiCallBridge<?> apiCallBridge,
4343
@Nullable LookupCache lookupCache,
4444
@Nullable String docType) {
4545
this.format = format;
4646
this.config = config;
4747
this.physicalRowDataType = physicalRowDataType;
48-
this.lookupMaxRetryTimes = lookupMaxRetryTimes;
48+
this.maxRetryTimes = maxRetryTimes;
4949
this.summaryString = summaryString;
5050
this.apiCallBridge = apiCallBridge;
5151
this.lookupCache = lookupCache;
@@ -68,7 +68,7 @@ public LookupRuntimeProvider getLookupRuntimeProvider(LookupContext context) {
6868
ElasticsearchRowDataLookupFunction<?> lookupFunction =
6969
new ElasticsearchRowDataLookupFunction<>(
7070
this.format.createRuntimeDecoder(context, physicalRowDataType),
71-
lookupMaxRetryTimes,
71+
maxRetryTimes,
7272
config.getIndex(),
7373
docType,
7474
DataType.getFieldNames(physicalRowDataType).toArray(new String[0]),
@@ -123,7 +123,7 @@ public DynamicTableSource copy() {
123123
format,
124124
config,
125125
physicalRowDataType,
126-
lookupMaxRetryTimes,
126+
maxRetryTimes,
127127
summaryString,
128128
apiCallBridge,
129129
lookupCache,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.connector.elasticsearch.table.search;
20+
21+
import org.apache.flink.api.common.serialization.DeserializationSchema;
22+
import org.apache.flink.table.data.GenericRowData;
23+
import org.apache.flink.table.data.RowData;
24+
import org.apache.flink.table.data.utils.JoinedRowData;
25+
import org.apache.flink.table.functions.FunctionContext;
26+
import org.apache.flink.table.functions.VectorSearchFunction;
27+
import org.apache.flink.util.FlinkRuntimeException;
28+
29+
import org.slf4j.Logger;
30+
import org.slf4j.LoggerFactory;
31+
32+
import java.io.IOException;
33+
import java.util.ArrayList;
34+
import java.util.Collection;
35+
import java.util.Collections;
36+
37+
import static org.apache.flink.util.Preconditions.checkNotNull;
38+
39+
/**
40+
* Base {@link VectorSearchFunction} implementation for Elasticsearch. Shared retry loop, result
41+
* decoding and null-source filtering live here; version-specific subclasses only need to provide
42+
* the client initialization and the search call.
43+
*/
44+
public abstract class AbstractElasticsearchVectorSearchFunction extends VectorSearchFunction {
45+
private static final Logger LOG =
46+
LoggerFactory.getLogger(AbstractElasticsearchVectorSearchFunction.class);
47+
private static final long serialVersionUID = 1L;
48+
49+
protected final DeserializationSchema<RowData> deserializationSchema;
50+
protected final String index;
51+
protected final String searchColumn;
52+
protected final String[] producedNames;
53+
protected final int maxRetryTimes;
54+
55+
protected AbstractElasticsearchVectorSearchFunction(
56+
DeserializationSchema<RowData> deserializationSchema,
57+
int maxRetryTimes,
58+
String index,
59+
String searchColumn,
60+
String[] producedNames) {
61+
this.deserializationSchema =
62+
checkNotNull(deserializationSchema, "No DeserializationSchema supplied.");
63+
this.producedNames = checkNotNull(producedNames, "No fieldNames supplied.");
64+
this.maxRetryTimes = maxRetryTimes;
65+
this.index = index;
66+
this.searchColumn = searchColumn;
67+
}
68+
69+
@Override
70+
public void open(FunctionContext context) throws Exception {
71+
doOpen(context);
72+
deserializationSchema.open(null);
73+
}
74+
75+
@Override
76+
public Collection<RowData> vectorSearch(int topK, RowData features) throws IOException {
77+
for (int retry = 0; retry <= maxRetryTimes; retry++) {
78+
try {
79+
SearchResult[] results = doSearch(topK, features);
80+
if (results.length > 0) {
81+
ArrayList<RowData> rows = new ArrayList<>(results.length);
82+
for (SearchResult result : results) {
83+
if (result.source == null) {
84+
continue;
85+
}
86+
RowData row = parseSearchResult(result.source);
87+
if (row == null) {
88+
continue;
89+
}
90+
GenericRowData scoreData = new GenericRowData(1);
91+
scoreData.setField(0, result.score);
92+
rows.add(new JoinedRowData(row, scoreData));
93+
}
94+
rows.trimToSize();
95+
return rows;
96+
}
97+
} catch (IOException e) {
98+
LOG.error(String.format("Elasticsearch search error, retry times = %d", retry), e);
99+
if (retry >= maxRetryTimes) {
100+
throw new FlinkRuntimeException("Execution of Elasticsearch search failed.", e);
101+
}
102+
try {
103+
Thread.sleep(1000L * retry);
104+
} catch (InterruptedException e1) {
105+
LOG.warn(
106+
"Interrupted while waiting to retry failed elasticsearch search, aborting");
107+
throw new FlinkRuntimeException(e1);
108+
}
109+
}
110+
}
111+
return Collections.emptyList();
112+
}
113+
114+
/** Version-specific initialization (e.g., creating the underlying Elasticsearch client). */
115+
protected abstract void doOpen(FunctionContext context) throws Exception;
116+
117+
/** Execute a single vector search call and return raw results, excluding nothing. */
118+
protected abstract SearchResult[] doSearch(int topK, RowData features) throws IOException;
119+
120+
private RowData parseSearchResult(String result) {
121+
try {
122+
return deserializationSchema.deserialize(result.getBytes());
123+
} catch (IOException e) {
124+
LOG.error("Deserialize search hit failed: " + e.getMessage());
125+
return null;
126+
}
127+
}
128+
129+
/** One hit from Elasticsearch — raw JSON source plus score. */
130+
protected static class SearchResult {
131+
final String source;
132+
final Double score;
133+
134+
public SearchResult(String source, Double score) {
135+
this.source = source;
136+
this.score = score;
137+
}
138+
}
139+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package org.apache.flink.connector.elasticsearch.table.search;
2+
3+
/** Metric for vector search. */
4+
public enum SearchMetric {
5+
COSINE_SIMILARITY("cosineSimilarity"),
6+
L1NORM("l1norm"),
7+
L2NORM("l2norm"),
8+
HAMMING("hamming"),
9+
DOT_PRODUCT("dotProduct");
10+
11+
private final String name;
12+
13+
SearchMetric(String name) {
14+
this.name = name;
15+
}
16+
17+
@Override
18+
public String toString() {
19+
return name;
20+
}
21+
}

flink-connector-elasticsearch-base/src/test/java/org/apache/flink/connector/elasticsearch/ElasticsearchUtil.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@
2626

2727
import org.slf4j.Logger;
2828
import org.testcontainers.containers.output.Slf4jLogConsumer;
29+
import org.testcontainers.containers.wait.strategy.Wait;
2930
import org.testcontainers.elasticsearch.ElasticsearchContainer;
3031
import org.testcontainers.utility.DockerImageName;
3132

33+
import java.time.Duration;
3234
import java.util.Optional;
3335

3436
/** Collection of utility methods for Elasticsearch tests. */
@@ -62,10 +64,16 @@ public static ElasticsearchContainer createElasticsearchContainer(
6264
logLevel = "OFF";
6365
}
6466

65-
return new ElasticsearchContainer(DockerImageName.parse(dockerImageVersion))
66-
.withEnv("ES_JAVA_OPTS", "-Xms2g -Xmx2g")
67-
.withEnv("logger.org.elasticsearch", logLevel)
68-
.withLogConsumer(new Slf4jLogConsumer(log));
67+
ElasticsearchContainer container =
68+
new ElasticsearchContainer(DockerImageName.parse(dockerImageVersion))
69+
.withEnv("ES_JAVA_OPTS", "-Xms2g -Xmx2g")
70+
.withEnv("logger.org.elasticsearch", logLevel)
71+
.withLogConsumer(new Slf4jLogConsumer(log));
72+
73+
container.setWaitStrategy(
74+
Wait.defaultWaitStrategy().withStartupTimeout(Duration.ofMinutes(1)));
75+
76+
return container;
6977
}
7078

7179
/** A mock {@link DynamicTableSink.Context} for Elasticsearch tests. */

0 commit comments

Comments
 (0)