Skip to content

Commit 39368cb

Browse files
wenjin272 and claude committed
[FLINK-38721] Extract AbstractElasticsearchVectorSearchFunction base class.
Address PR #137 review feedback from xishuaidelin: - Introduce AbstractElasticsearchVectorSearchFunction in the base module so ES7 / ES8 share the retry loop, result decoding and SearchResult type; each subclass now only supplies client initialization and the version-specific search call. - Filter hits whose source is null to avoid NPE when deserializing. - Drop redundant checkNotNull on the primitive maxRetryTimes parameter. Also apply spotless formatting drift picked up on adjacent files. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 8eb8f84 commit 39368cb

8 files changed

Lines changed: 314 additions & 275 deletions

File tree

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.connector.elasticsearch.table.search;
20+
21+
import org.apache.flink.api.common.serialization.DeserializationSchema;
22+
import org.apache.flink.table.data.GenericRowData;
23+
import org.apache.flink.table.data.RowData;
24+
import org.apache.flink.table.data.utils.JoinedRowData;
25+
import org.apache.flink.table.functions.FunctionContext;
26+
import org.apache.flink.table.functions.VectorSearchFunction;
27+
import org.apache.flink.util.FlinkRuntimeException;
28+
29+
import org.slf4j.Logger;
30+
import org.slf4j.LoggerFactory;
31+
32+
import java.io.IOException;
33+
import java.util.ArrayList;
34+
import java.util.Collection;
35+
import java.util.Collections;
36+
37+
import static org.apache.flink.util.Preconditions.checkNotNull;
38+
39+
/**
40+
* Base {@link VectorSearchFunction} implementation for Elasticsearch. Shared retry loop, result
41+
* decoding and null-source filtering live here; version-specific subclasses only need to provide
42+
* the client initialization and the search call.
43+
*/
44+
public abstract class AbstractElasticsearchVectorSearchFunction extends VectorSearchFunction {
45+
private static final Logger LOG =
46+
LoggerFactory.getLogger(AbstractElasticsearchVectorSearchFunction.class);
47+
private static final long serialVersionUID = 1L;
48+
49+
protected final DeserializationSchema<RowData> deserializationSchema;
50+
protected final String index;
51+
protected final String searchColumn;
52+
protected final String[] producedNames;
53+
protected final int maxRetryTimes;
54+
55+
protected AbstractElasticsearchVectorSearchFunction(
56+
DeserializationSchema<RowData> deserializationSchema,
57+
int maxRetryTimes,
58+
String index,
59+
String searchColumn,
60+
String[] producedNames) {
61+
this.deserializationSchema =
62+
checkNotNull(deserializationSchema, "No DeserializationSchema supplied.");
63+
this.producedNames = checkNotNull(producedNames, "No fieldNames supplied.");
64+
this.maxRetryTimes = maxRetryTimes;
65+
this.index = index;
66+
this.searchColumn = searchColumn;
67+
}
68+
69+
@Override
70+
public void open(FunctionContext context) throws Exception {
71+
doOpen(context);
72+
deserializationSchema.open(null);
73+
}
74+
75+
@Override
76+
public Collection<RowData> vectorSearch(int topK, RowData features) throws IOException {
77+
for (int retry = 0; retry <= maxRetryTimes; retry++) {
78+
try {
79+
SearchResult[] results = doSearch(topK, features);
80+
if (results.length > 0) {
81+
ArrayList<RowData> rows = new ArrayList<>(results.length);
82+
for (SearchResult result : results) {
83+
if (result.source == null) {
84+
continue;
85+
}
86+
RowData row = parseSearchResult(result.source);
87+
if (row == null) {
88+
continue;
89+
}
90+
GenericRowData scoreData = new GenericRowData(1);
91+
scoreData.setField(0, result.score);
92+
rows.add(new JoinedRowData(row, scoreData));
93+
}
94+
rows.trimToSize();
95+
return rows;
96+
}
97+
} catch (IOException e) {
98+
LOG.error(String.format("Elasticsearch search error, retry times = %d", retry), e);
99+
if (retry >= maxRetryTimes) {
100+
throw new FlinkRuntimeException("Execution of Elasticsearch search failed.", e);
101+
}
102+
try {
103+
Thread.sleep(1000L * retry);
104+
} catch (InterruptedException e1) {
105+
LOG.warn(
106+
"Interrupted while waiting to retry failed elasticsearch search, aborting");
107+
throw new FlinkRuntimeException(e1);
108+
}
109+
}
110+
}
111+
return Collections.emptyList();
112+
}
113+
114+
/** Version-specific initialization (e.g., creating the underlying Elasticsearch client). */
115+
protected abstract void doOpen(FunctionContext context) throws Exception;
116+
117+
/** Execute a single vector search call and return raw results, excluding nothing. */
118+
protected abstract SearchResult[] doSearch(int topK, RowData features) throws IOException;
119+
120+
private RowData parseSearchResult(String result) {
121+
try {
122+
return deserializationSchema.deserialize(result.getBytes());
123+
} catch (IOException e) {
124+
LOG.error("Deserialize search hit failed: " + e.getMessage());
125+
return null;
126+
}
127+
}
128+
129+
/** One hit from Elasticsearch — raw JSON source plus score. */
130+
protected static class SearchResult {
131+
final String source;
132+
final Double score;
133+
134+
public SearchResult(String source, Double score) {
135+
this.source = source;
136+
this.score = score;
137+
}
138+
}
139+
}

flink-connector-elasticsearch-base/src/main/java/org/apache/flink/connector/elasticsearch/table/search/SearchMetric.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package org.apache.flink.connector.elasticsearch.table.search;
22

3-
/**Metric for vector search.*/
3+
/** Metric for vector search. */
44
public enum SearchMetric {
55
COSINE_SIMILARITY("cosineSimilarity"),
66
L1NORM("l1norm"),

flink-connector-elasticsearch-base/src/test/java/org/apache/flink/connector/elasticsearch/ElasticsearchUtil.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,11 @@ public static ElasticsearchContainer createElasticsearchContainer(
6464
logLevel = "OFF";
6565
}
6666

67-
ElasticsearchContainer container = new ElasticsearchContainer(DockerImageName.parse(dockerImageVersion))
68-
.withEnv("ES_JAVA_OPTS", "-Xms2g -Xmx2g")
69-
.withEnv("logger.org.elasticsearch", logLevel)
70-
.withLogConsumer(new Slf4jLogConsumer(log));
67+
ElasticsearchContainer container =
68+
new ElasticsearchContainer(DockerImageName.parse(dockerImageVersion))
69+
.withEnv("ES_JAVA_OPTS", "-Xms2g -Xmx2g")
70+
.withEnv("logger.org.elasticsearch", logLevel)
71+
.withLogConsumer(new Slf4jLogConsumer(log));
7172

7273
container.setWaitStrategy(
7374
Wait.defaultWaitStrategy().withStartupTimeout(Duration.ofMinutes(1)));
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,11 @@
11
package org.apache.flink.connector.elasticsearch.table.search;
22

33
import org.apache.flink.api.common.serialization.DeserializationSchema;
4-
import org.apache.flink.api.java.tuple.Tuple2;
54
import org.apache.flink.connector.elasticsearch.ElasticsearchApiCallBridge;
65
import org.apache.flink.connector.elasticsearch.NetworkClientConfig;
7-
import org.apache.flink.table.data.GenericRowData;
86
import org.apache.flink.table.data.RowData;
9-
import org.apache.flink.table.data.utils.JoinedRowData;
107
import org.apache.flink.table.functions.FunctionContext;
118
import org.apache.flink.table.functions.VectorSearchFunction;
12-
import org.apache.flink.util.FlinkRuntimeException;
139

1410
import org.apache.http.HttpHost;
1511
import org.elasticsearch.action.search.SearchRequest;
@@ -22,42 +18,29 @@
2218
import org.elasticsearch.script.ScriptType;
2319
import org.elasticsearch.search.SearchHit;
2420
import org.elasticsearch.search.builder.SearchSourceBuilder;
25-
import org.slf4j.Logger;
26-
import org.slf4j.LoggerFactory;
2721

2822
import java.io.IOException;
29-
import java.util.ArrayList;
30-
import java.util.Collection;
3123
import java.util.Collections;
3224
import java.util.List;
3325
import java.util.Map;
3426
import java.util.stream.Stream;
3527

3628
import static org.apache.flink.util.Preconditions.checkNotNull;
3729

38-
/** The {@link VectorSearchFunction} implementation for Elasticsearch. */
39-
public class ElasticsearchRowDataVectorSearchFunction extends VectorSearchFunction {
40-
private static final Logger LOG =
41-
LoggerFactory.getLogger(ElasticsearchRowDataVectorSearchFunction.class);
30+
/** The {@link VectorSearchFunction} implementation for Elasticsearch 7. */
31+
public class ElasticsearchRowDataVectorSearchFunction
32+
extends AbstractElasticsearchVectorSearchFunction {
4233
private static final long serialVersionUID = 1L;
4334
private static final String QUERY_VECTOR = "query_vector";
4435

45-
private final DeserializationSchema<RowData> deserializationSchema;
46-
47-
private final String index;
48-
49-
private final String[] producedNames;
50-
private final int maxRetryTimes;
51-
private final SearchMetric searchMetric;
52-
private SearchRequest searchRequest;
53-
private SearchSourceBuilder searchSourceBuilder;
54-
5536
private final ElasticsearchApiCallBridge<RestHighLevelClient> callBridge;
5637
private final NetworkClientConfig networkClientConfig;
5738
private final List<HttpHost> hosts;
5839
private final String scriptScore;
5940

6041
private transient RestHighLevelClient client;
42+
private transient SearchRequest searchRequest;
43+
private transient SearchSourceBuilder searchSourceBuilder;
6144

6245
public ElasticsearchRowDataVectorSearchFunction(
6346
DeserializationSchema<RowData> deserializationSchema,
@@ -69,121 +52,47 @@ public ElasticsearchRowDataVectorSearchFunction(
6952
List<HttpHost> hosts,
7053
NetworkClientConfig networkClientConfig,
7154
ElasticsearchApiCallBridge<RestHighLevelClient> callBridge) {
72-
73-
checkNotNull(deserializationSchema, "No DeserializationSchema supplied.");
74-
checkNotNull(maxRetryTimes, "No maxRetryTimes supplied.");
75-
checkNotNull(producedNames, "No fieldNames supplied.");
76-
checkNotNull(hosts, "No hosts supplied.");
77-
checkNotNull(networkClientConfig, "No networkClientConfig supplied.");
78-
checkNotNull(callBridge, "No ElasticsearchApiCallBridge supplied.");
79-
80-
this.deserializationSchema = deserializationSchema;
81-
this.maxRetryTimes = maxRetryTimes;
82-
this.searchMetric = searchMetric;
83-
this.index = index;
84-
this.producedNames = producedNames;
85-
86-
this.networkClientConfig = networkClientConfig;
87-
this.hosts = hosts;
88-
this.callBridge = callBridge;
55+
super(deserializationSchema, maxRetryTimes, index, searchColumn, producedNames);
56+
this.networkClientConfig =
57+
checkNotNull(networkClientConfig, "No networkClientConfig supplied.");
58+
this.hosts = checkNotNull(hosts, "No hosts supplied.");
59+
this.callBridge = checkNotNull(callBridge, "No ElasticsearchApiCallBridge supplied.");
8960
this.scriptScore =
9061
String.format(
9162
"%s(params.%s, '%s') + 1.0",
9263
searchMetric.toString(), QUERY_VECTOR, searchColumn);
9364
}
9465

9566
@Override
96-
public void open(FunctionContext context) throws Exception {
67+
protected void doOpen(FunctionContext context) {
9768
this.client = callBridge.createClient(networkClientConfig, hosts);
9869

99-
// Set searchRequest in open method in case of amount of calling in eval method when every
100-
// record comes.
70+
// Reuse searchRequest / searchSourceBuilder across invocations to avoid rebuilding them
71+
// per record.
10172
this.searchRequest = new SearchRequest(index);
102-
searchSourceBuilder = new SearchSourceBuilder();
103-
searchSourceBuilder.fetchSource(producedNames, null);
104-
deserializationSchema.open(null);
73+
this.searchSourceBuilder = new SearchSourceBuilder();
74+
this.searchSourceBuilder.fetchSource(producedNames, null);
10575
}
10676

10777
@Override
108-
public Collection<RowData> vectorSearch(int topK, RowData features) throws IOException {
78+
protected SearchResult[] doSearch(int topK, RowData features) throws IOException {
10979
// Elasticsearch 7.x doesn't support ANN, we use script score to achieve exact matching.
11080
Map<String, Object> params =
11181
Collections.singletonMap(QUERY_VECTOR, features.getArray(0).toFloatArray());
11282

11383
Script script = new Script(ScriptType.INLINE, "painless", scriptScore, params);
114-
11584
ScriptScoreQueryBuilder scriptScoreQuery =
11685
new ScriptScoreQueryBuilder(new MatchAllQueryBuilder(), script);
11786

11887
searchSourceBuilder.query(scriptScoreQuery).size(topK);
119-
12088
searchRequest.source(searchSourceBuilder);
12189

122-
for (int retry = 0; retry <= maxRetryTimes; retry++) {
123-
try {
124-
ArrayList<RowData> rows = new ArrayList<>();
125-
Tuple2<String, SearchResult[]> searchResponse = search(client, searchRequest);
126-
127-
if (searchResponse.f1.length > 0) {
128-
for (SearchResult result : searchResponse.f1) {
129-
String source = result.source;
130-
RowData row = parseSearchResult(source);
131-
GenericRowData scoreData = new GenericRowData(1);
132-
scoreData.setField(0, Double.valueOf(result.score));
133-
if (row != null) {
134-
rows.add(new JoinedRowData(row, scoreData));
135-
}
136-
}
137-
rows.trimToSize();
138-
return rows;
139-
}
140-
} catch (IOException e) {
141-
LOG.error(String.format("Elasticsearch search error, retry times = %d", retry), e);
142-
if (retry >= maxRetryTimes) {
143-
throw new FlinkRuntimeException("Execution of Elasticsearch search failed.", e);
144-
}
145-
try {
146-
Thread.sleep(1000L * retry);
147-
} catch (InterruptedException e1) {
148-
LOG.warn(
149-
"Interrupted while waiting to retry failed elasticsearch search, aborting");
150-
throw new FlinkRuntimeException(e1);
151-
}
152-
}
153-
}
154-
return Collections.emptyList();
155-
}
156-
157-
private RowData parseSearchResult(String result) {
158-
RowData row = null;
159-
try {
160-
row = deserializationSchema.deserialize(result.getBytes());
161-
} catch (IOException e) {
162-
LOG.error("Deserialize search hit failed: " + e.getMessage());
163-
}
164-
165-
return row;
166-
}
167-
168-
private Tuple2<String, SearchResult[]> search(
169-
RestHighLevelClient client, SearchRequest searchRequest) throws IOException {
17090
SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT);
17191
SearchHit[] searchHits = searchResponse.getHits().getHits();
17292

173-
return new Tuple2<>(
174-
searchResponse.getScrollId(),
175-
Stream.of(searchHits)
176-
.map(hit -> new SearchResult(hit.getSourceAsString(), hit.getScore()))
177-
.toArray(SearchResult[]::new));
178-
}
179-
180-
private static class SearchResult {
181-
private final String source;
182-
private final Float score;
183-
184-
public SearchResult(String source, Float score) {
185-
this.source = source;
186-
this.score = score;
187-
}
93+
return Stream.of(searchHits)
94+
.filter(hit -> hit.getSourceAsString() != null)
95+
.map(hit -> new SearchResult(hit.getSourceAsString(), (double) hit.getScore()))
96+
.toArray(SearchResult[]::new);
18897
}
18998
}

flink-connector-elasticsearch8/src/main/java/org/apache/flink/connector/elasticsearch/table/Elasticsearch8DynamicSource.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,8 +327,7 @@ private SerializableSupplier<SSLContext> buildSslContextSupplier() {
327327

328328
@Override
329329
public DynamicTableSource copy() {
330-
return new Elasticsearch8DynamicSource(
331-
format, config, physicalRowDataType, summaryString);
330+
return new Elasticsearch8DynamicSource(format, config, physicalRowDataType, summaryString);
332331
}
333332

334333
@Override

flink-connector-elasticsearch8/src/main/java/org/apache/flink/connector/elasticsearch/table/Elasticsearch8DynamicTableFactory.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,7 @@ public DynamicTableSource createDynamicTableSource(Context context) {
9292
validateConfiguration(config);
9393

9494
return new Elasticsearch8DynamicSource(
95-
format,
96-
config,
97-
context.getPhysicalRowDataType(),
98-
"Elasticsearch-8");
95+
format, config, context.getPhysicalRowDataType(), "Elasticsearch-8");
9996
}
10097

10198
Elasticsearch8Configuration getConfiguration(FactoryUtil.TableFactoryHelper helper) {

0 commit comments

Comments
 (0)