-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Expand file tree
/
Copy pathPatchTableEmbeddingIT.java
More file actions
254 lines (219 loc) · 11 KB
/
PatchTableEmbeddingIT.java
File metadata and controls
254 lines (219 loc) · 11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
package org.openmetadata.it.tests;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import es.co.elastic.clients.transport.rest5_client.low_level.Request;
import es.co.elastic.clients.transport.rest5_client.low_level.Response;
import es.co.elastic.clients.transport.rest5_client.low_level.Rest5Client;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;
import org.openmetadata.it.bootstrap.TestSuiteBootstrap;
import org.openmetadata.it.factories.DatabaseSchemaTestFactory;
import org.openmetadata.it.factories.DatabaseServiceTestFactory;
import org.openmetadata.it.util.SdkClients;
import org.openmetadata.it.util.TestNamespace;
import org.openmetadata.it.util.TestNamespaceExtension;
import org.openmetadata.schema.api.data.CreateTable;
import org.openmetadata.schema.entity.data.DatabaseSchema;
import org.openmetadata.schema.entity.data.Table;
import org.openmetadata.schema.entity.services.DatabaseService;
import org.openmetadata.schema.utils.JsonUtils;
import org.openmetadata.sdk.client.OpenMetadataClient;
import org.openmetadata.sdk.fluent.builders.ColumnBuilder;
import org.openmetadata.search.IndexMapping;
import org.openmetadata.service.Entity;
import org.openmetadata.service.events.lifecycle.EntityLifecycleEventDispatcher;
import org.openmetadata.service.search.SearchRepository;
import org.openmetadata.service.search.indexes.SearchIndex;
import org.openmetadata.service.search.vector.VectorIndexService;
import org.openmetadata.service.search.vector.client.EmbeddingClient;
/**
 * Integration test checking that updating a table's description and re-generating its embedding
 * changes the stored fingerprint / LLM context, and that the new embedding is usable in a KNN
 * query.
 */
@Execution(ExecutionMode.CONCURRENT)
@ExtendWith(TestNamespaceExtension.class)
public class PatchTableEmbeddingIT {
  /**
   * Prefix of the throwaway knn_vector index. The table id is appended per run so that a stale
   * index left behind by an aborted run (cleanup is best-effort) or a concurrent execution can
   * neither make index creation fail nor be deleted by this run's cleanup.
   */
  private static final String KNN_TEST_INDEX_PREFIX = "test_knn_embedding_index";

  private static final ObjectMapper MAPPER = new ObjectMapper();

  /**
   * Runs only when the {@code searchType} system property is {@code opensearch} and vector
   * embedding initializes successfully; otherwise the test is skipped via an assumption.
   */
  @Test
  void testPatchTableDescriptionUpdatesEmbeddingForSemanticSearch(TestNamespace ns)
      throws Exception {
    Assumptions.assumeTrue(
        "opensearch".equalsIgnoreCase(System.getProperty("searchType", "elasticsearch")),
        "Vector embedding tests require OpenSearch");

    SearchRepository searchRepo = Entity.getSearchRepository();
    TestSuiteBootstrap.withNaturalLanguageSearch(searchRepo.getSearchConfiguration());
    // The try must begin right after the shared search configuration is mutated. The assumption
    // below aborts the test by throwing, and initialization may fail. In both cases the
    // natural-language-search setting still has to be reset so it does not leak into other tests.
    try {
      // NOTE(review): presumably unregistered so the async handler does not race with the
      // synchronous embedding updates performed below — confirm.
      EntityLifecycleEventDispatcher.getInstance().unregisterHandler("VectorEmbeddingHandler");
      searchRepo.initializeVectorSearchService();
      Assumptions.assumeTrue(
          searchRepo.isVectorEmbeddingEnabled(), "Vector embedding could not be initialized");
      runEmbeddingTest(ns, searchRepo);
    } finally {
      searchRepo.getSearchConfiguration().setNaturalLanguageSearch(null);
    }
  }

  /**
   * Creates a table, indexes it and its embedding synchronously, updates the description, then
   * asserts that the fingerprint changed, the LLM context reflects the new text, and the new
   * embedding can be found through a KNN query.
   */
  private void runEmbeddingTest(TestNamespace ns, SearchRepository searchRepo) throws Exception {
    OpenMetadataClient client = SdkClients.adminClient();
    VectorIndexService vectorService = searchRepo.getVectorIndexService();
    String entityIndexName = resolveTableIndexName(searchRepo);

    DatabaseService service = DatabaseServiceTestFactory.createPostgres(ns);
    DatabaseSchema schema = DatabaseSchemaTestFactory.createSimple(ns, service);

    CreateTable createRequest = new CreateTable();
    createRequest.setName(ns.prefix("emb_patch"));
    createRequest.setDatabaseSchema(schema.getFullyQualifiedName());
    createRequest.setDescription("Initial description about sales data processing");
    createRequest.setColumns(
        List.of(
            ColumnBuilder.of("id", "BIGINT").primaryKey().notNull().build(),
            ColumnBuilder.of("amount", "DOUBLE").build()));
    Table table = client.tables().create(createRequest);
    String tableId = table.getId().toString();

    try (Rest5Client searchClient = TestSuiteBootstrap.createSearchClient()) {
      // Index the entity doc synchronously. The async SearchIndexHandler may be
      // suspended during reindex, so we bypass it and create the doc directly.
      indexEntityDoc(searchRepo, table, entityIndexName);

      // Generate initial embedding synchronously — no polling needed
      vectorService.updateEntityEmbedding(table, entityIndexName);
      String initialFingerprint =
          getFieldFromDoc(searchClient, entityIndexName, tableId, "fingerprint");
      assertNotNull(initialFingerprint, "Initial fingerprint should exist after sync embedding");

      // Patch description and re-generate embedding synchronously
      table.setDescription("Revenue metrics for quarterly financial reporting analysis");
      Table updated = client.tables().update(tableId, table);
      vectorService.updateEntityEmbedding(updated, entityIndexName);

      String updatedFingerprint =
          getFieldFromDoc(searchClient, entityIndexName, tableId, "fingerprint");
      assertNotNull(updatedFingerprint, "Updated fingerprint should exist");
      assertNotEquals(
          initialFingerprint,
          updatedFingerprint,
          "Fingerprint should change after description update");

      String textToLLMContext =
          getFieldFromDoc(searchClient, entityIndexName, tableId, "textToLLMContext");
      // Fail with a readable assertion (not an NPE) when the field is absent from the doc.
      assertNotNull(textToLLMContext, "textToLLMContext should exist after PATCH");
      assertTrue(
          textToLLMContext.contains("Revenue metrics"),
          "textToLLMContext should reflect the patched description");

      String embeddingJson = getFieldFromDoc(searchClient, entityIndexName, tableId, "embedding");
      assertNotNull(embeddingJson, "Embedding vector should exist after PATCH");

      // The dedicated index holds only this table's document, so this mainly verifies that the
      // stored embedding is a valid vector of the configured dimension and is KNN-queryable.
      List<String> knnResults =
          verifyKnnSearchWithDedicatedIndex(searchClient, tableId, embeddingJson);
      assertTrue(
          knnResults.contains(tableId),
          "Patched table should be found via KNN search for its new description");
    }
  }

  /** Returns the table entity index name, qualified with the repository's cluster alias. */
  private String resolveTableIndexName(SearchRepository searchRepo) {
    IndexMapping mapping = searchRepo.getIndexMapping(Entity.TABLE);
    return mapping.getIndexName(searchRepo.getClusterAlias());
  }

  /**
   * Creates a temporary knn_vector index, indexes the entity's embedding, runs a KNN query against
   * it, and cleans up. This avoids modifying the shared table search index while still validating
   * that the generated embedding produces correct KNN search results.
   *
   * @return entity ids of the KNN hits, in response order
   */
  private List<String> verifyKnnSearchWithDedicatedIndex(
      Rest5Client searchClient, String tableId, String embeddingJson) throws Exception {
    int dimension = Entity.getSearchRepository().getEmbeddingClient().getDimension();
    // Unique per run so leftovers/concurrent runs cannot collide with this index.
    String knnIndex = KNN_TEST_INDEX_PREFIX + "_" + tableId;
    try {
      createKnnIndex(searchClient, knnIndex, dimension);
      indexEmbeddingDocument(searchClient, knnIndex, tableId, embeddingJson);
      refreshKnnIndex(searchClient, knnIndex);
      return executeKnnSearch(searchClient, knnIndex, 10);
    } finally {
      deleteKnnIndex(searchClient, knnIndex);
    }
  }

  /** Creates a single-shard, zero-replica index with an HNSW/lucene cosine knn_vector field. */
  private void createKnnIndex(Rest5Client searchClient, String knnIndex, int dimension)
      throws Exception {
    String mapping =
        String.format(
            "{\"settings\":{\"index\":{\"knn\":true,\"number_of_shards\":1,"
                + "\"number_of_replicas\":0}},"
                + "\"mappings\":{\"properties\":{"
                + "\"embedding\":{\"type\":\"knn_vector\",\"dimension\":%d,"
                + "\"method\":{\"name\":\"hnsw\",\"engine\":\"lucene\","
                + "\"space_type\":\"cosinesimil\"}},"
                + "\"entityId\":{\"type\":\"keyword\"},"
                + "\"deleted\":{\"type\":\"boolean\"}}}}",
            dimension);
    Request request = new Request("PUT", "/" + knnIndex);
    request.setJsonEntity(mapping);
    searchClient.performRequest(request);
  }

  /** Indexes one document (id = table id) holding the raw embedding JSON array. */
  private void indexEmbeddingDocument(
      Rest5Client searchClient, String knnIndex, String tableId, String embeddingJson)
      throws Exception {
    String doc =
        String.format(
            "{\"embedding\":%s,\"entityId\":\"%s\",\"deleted\":false}", embeddingJson, tableId);
    Request request = new Request("PUT", "/" + knnIndex + "/_doc/" + tableId);
    request.setJsonEntity(doc);
    searchClient.performRequest(request);
  }

  /** Refreshes the index so the just-indexed document is visible to search. */
  private void refreshKnnIndex(Rest5Client searchClient, String knnIndex) throws Exception {
    searchClient.performRequest(new Request("POST", "/" + knnIndex + "/_refresh"));
  }

  /**
   * Embeds a fixed query text and runs a KNN search (k = {@code size}) restricted to non-deleted
   * documents.
   *
   * @return the {@code entityId} of each hit that carries one, in response order
   */
  private List<String> executeKnnSearch(Rest5Client searchClient, String knnIndex, int size)
      throws Exception {
    EmbeddingClient embeddingClient = Entity.getSearchRepository().getEmbeddingClient();
    float[] queryVector = embeddingClient.embed("quarterly financial revenue reporting");
    // Arrays.toString renders "[a, b, ...]", which is a valid JSON array literal.
    String vectorStr = Arrays.toString(queryVector);

    String knnQuery =
        String.format(
            "{\"size\":%d,\"_source\":[\"entityId\"],"
                + "\"query\":{\"knn\":{\"embedding\":{\"vector\":%s,\"k\":%d,"
                + "\"filter\":{\"bool\":{\"must\":[{\"term\":{\"deleted\":false}}]}}}}}}",
            size, vectorStr, size);

    Request request = new Request("POST", "/" + knnIndex + "/_search");
    request.setJsonEntity(knnQuery);
    JsonNode root = MAPPER.readTree(readBody(searchClient.performRequest(request)));

    List<String> resultIds = new ArrayList<>();
    for (JsonNode hit : root.path("hits").path("hits")) {
      String entityId = hit.path("_source").path("entityId").asText(null);
      if (entityId != null) {
        resultIds.add(entityId);
      }
    }
    return resultIds;
  }

  /** Deletes the temporary KNN index; failures are ignored because cleanup is best-effort. */
  private void deleteKnnIndex(Rest5Client searchClient, String knnIndex) {
    try {
      searchClient.performRequest(new Request("DELETE", "/" + knnIndex));
    } catch (Exception ignored) {
      // Best-effort cleanup: the per-run index name means a leftover cannot break later runs.
    }
  }

  /** Builds the table's search document and writes it directly into {@code indexName}. */
  private void indexEntityDoc(SearchRepository searchRepo, Table table, String indexName)
      throws Exception {
    SearchIndex index = searchRepo.getSearchIndexFactory().buildIndex(Entity.TABLE, table);
    String doc = JsonUtils.pojoToJson(index.buildSearchIndexDoc());
    searchRepo.getSearchClient().createEntity(indexName, table.getId().toString(), doc);
  }

  /**
   * Uses GET _doc API which reads from the translog and is always real-time.
   *
   * @return the field as text (non-textual values are returned as their JSON form), or {@code
   *     null} when the document is reported as not found or the field is missing/null. NOTE(review):
   *     the low-level client may throw on a 404 response instead of returning {@code found=false}
   *     — confirm.
   */
  private String getFieldFromDoc(
      Rest5Client searchClient, String indexName, String entityId, String field) throws Exception {
    Request request =
        new Request(
            "GET", String.format("/%s/_doc/%s?_source_includes=%s", indexName, entityId, field));
    JsonNode root = MAPPER.readTree(readBody(searchClient.performRequest(request)));
    if (!root.path("found").asBoolean(false)) {
      return null;
    }
    JsonNode fieldValue = root.path("_source").path(field);
    if (fieldValue.isMissingNode() || fieldValue.isNull()) {
      return null;
    }
    return fieldValue.isTextual() ? fieldValue.asText() : fieldValue.toString();
  }

  /** Reads the whole response entity as a UTF-8 string. */
  private static String readBody(Response response) throws Exception {
    return new String(response.getEntity().getContent().readAllBytes(), StandardCharsets.UTF_8);
  }
}