Skip to content

Commit 9819769

Browse files
authored
Fix a bug of Function.multiAnalyzerParams (#1538)
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent f250f86 commit 9819769

3 files changed

Lines changed: 308 additions & 9 deletions

File tree

sdk-core/src/main/java/io/milvus/common/clientenum/FunctionType.java

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,35 @@
2222
import lombok.Getter;
2323

2424
/**
 * Client-side function types.
 *
 * <p>Every constant stores a numeric code and a name. The stored name defaults to the
 * enum constant identifier, but can be overridden when the name used by milvus-proto
 * differs (for example {@code UNKNOWN} is named {@code "Unknown"} there).
 */
public enum FunctionType {
    UNKNOWN("Unknown", 0), // in milvus-proto, the name is "Unknown"
    BM25(1),
    ;

    /** Name used for lookups in {@link #fromName(String)}; may differ from {@link #name()}. */
    private final String name;
    /** Numeric code of this function type. */
    private final int code;

    /** Uses the constant identifier as the name and the ordinal as the code. */
    FunctionType() {
        this.name = this.name();
        this.code = this.ordinal();
    }

    /**
     * Uses the constant identifier as the name.
     *
     * @param code numeric code of the function type
     */
    FunctionType(int code) {
        this.name = this.name();
        this.code = code;
    }

    /**
     * @param name name to match in {@link #fromName(String)}
     * @param code numeric code of the function type
     */
    FunctionType(String name, int code) {
        this.name = name;
        this.code = code;
    }

    /**
     * @return the numeric code of this function type
     */
    public int getCode() {
        return code;
    }

    /**
     * Finds the constant whose stored name equals {@code name}.
     *
     * <p>The stored name is checked rather than only {@link #name()}, so that
     * {@code "Unknown"} resolves to {@link #UNKNOWN}. The enum constant identifier
     * (e.g. {@code "UNKNOWN"}) is accepted as well, which keeps earlier callers working.
     *
     * @param name the name to look up; may be {@code null}
     * @return the matching constant, or {@code null} when nothing matches
     */
    public static FunctionType fromName(String name) {
        for (FunctionType type : FunctionType.values()) {
            if (type.name.equals(name) || type.name().equals(name)) {
                return type;
            }
        }
        return null;
    }
}

sdk-core/src/main/java/io/milvus/v2/utils/SchemaUtils.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ public static CreateCollectionReq.CollectionSchema convertFromGrpcCollectionSche
161161
return collectionSchema;
162162
}
163163

164-
private static CreateCollectionReq.FieldSchema convertFromGrpcFieldSchema(FieldSchema fieldSchema) {
164+
public static CreateCollectionReq.FieldSchema convertFromGrpcFieldSchema(FieldSchema fieldSchema) {
165165
CreateCollectionReq.FieldSchema schema = CreateCollectionReq.FieldSchema.builder()
166166
.name(fieldSchema.getName())
167167
.description(fieldSchema.getDescription())
@@ -191,6 +191,9 @@ private static CreateCollectionReq.FieldSchema convertFromGrpcFieldSchema(FieldS
191191
} else if(keyValuePair.getKey().equals("analyzer_params")){
192192
Map<String, Object> params = JsonUtils.fromJson(keyValuePair.getValue(), new TypeToken<Map<String, Object>>() {}.getType());
193193
schema.setAnalyzerParams(params);
194+
} else if(keyValuePair.getKey().equals("multi_analyzer_params")){
195+
Map<String, Object> params = JsonUtils.fromJson(keyValuePair.getValue(), new TypeToken<Map<String, Object>>() {}.getType());
196+
schema.setMultiAnalyzerParams(params);
194197
}
195198
} catch (Exception e) {
196199
/**
@@ -207,14 +210,13 @@ private static CreateCollectionReq.FieldSchema convertFromGrpcFieldSchema(FieldS
207210
}
208211

209212
public static CreateCollectionReq.Function convertFromGrpcFunction(FunctionSchema functionSchema) {
210-
CreateCollectionReq.Function function = CreateCollectionReq.Function.builder()
213+
CreateCollectionReq.Function.FunctionBuilder builder = CreateCollectionReq.Function.builder()
211214
.name(functionSchema.getName())
212215
.description(functionSchema.getDescription())
213-
.functionType(io.milvus.common.clientenum.FunctionType.valueOf(functionSchema.getType().name()))
216+
.functionType(io.milvus.common.clientenum.FunctionType.fromName(functionSchema.getType().name()))
214217
.inputFieldNames(functionSchema.getInputFieldNamesList().stream().collect(Collectors.toList()))
215-
.outputFieldNames(functionSchema.getOutputFieldNamesList().stream().collect(Collectors.toList()))
216-
.build();
217-
return function;
218+
.outputFieldNames(functionSchema.getOutputFieldNamesList().stream().collect(Collectors.toList()));
219+
return builder.build();
218220
}
219221

220222
public static CreateCollectionReq.FieldSchema convertFieldReqToFieldSchema(AddFieldReq addFieldReq) {
Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package io.milvus.v2.utils;
21+
22+
import io.milvus.grpc.*;
23+
import io.milvus.param.Constant;
24+
import io.milvus.v2.service.collection.request.AddFieldReq;
25+
import io.milvus.v2.service.collection.request.CreateCollectionReq;
26+
import org.junit.jupiter.api.Assertions;
27+
import org.junit.jupiter.api.Test;
28+
29+
import java.util.*;
30+
31+
public class SchemaUtilsTest {
32+
private CreateCollectionReq.CollectionSchema buildSchema() {
33+
CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
34+
.enableDynamicField(true)
35+
.build();
36+
collectionSchema.addField(AddFieldReq.builder()
37+
.fieldName("id")
38+
.dataType(io.milvus.v2.common.DataType.Int64)
39+
.isPrimaryKey(Boolean.TRUE)
40+
.build());
41+
collectionSchema.addField(AddFieldReq.builder()
42+
.fieldName("bool_field")
43+
.dataType(io.milvus.v2.common.DataType.Bool)
44+
.build());
45+
collectionSchema.addField(AddFieldReq.builder()
46+
.fieldName("int8_field")
47+
.dataType(io.milvus.v2.common.DataType.Int8)
48+
.build());
49+
collectionSchema.addField(AddFieldReq.builder()
50+
.fieldName("int16_field")
51+
.dataType(io.milvus.v2.common.DataType.Int16)
52+
.build());
53+
collectionSchema.addField(AddFieldReq.builder()
54+
.fieldName("int32_field")
55+
.dataType(io.milvus.v2.common.DataType.Int32)
56+
.build());
57+
collectionSchema.addField(AddFieldReq.builder()
58+
.fieldName("int64_field")
59+
.dataType(io.milvus.v2.common.DataType.Int64)
60+
.defaultValue(888L)
61+
.build());
62+
collectionSchema.addField(AddFieldReq.builder()
63+
.fieldName("float_field")
64+
.dataType(io.milvus.v2.common.DataType.Float)
65+
.build());
66+
collectionSchema.addField(AddFieldReq.builder()
67+
.fieldName("double_field")
68+
.dataType(io.milvus.v2.common.DataType.Double)
69+
.build());
70+
Map<String, Object> analyzerParams = new HashMap<>();
71+
analyzerParams.put("type", "english");
72+
Map<String, Object> multiAnalyzerParams = new HashMap<>();
73+
multiAnalyzerParams.put("by_field", "language");
74+
collectionSchema.addField(AddFieldReq.builder()
75+
.fieldName("varchar_field")
76+
.dataType(io.milvus.v2.common.DataType.VarChar)
77+
.maxLength(1000)
78+
.enableAnalyzer(true)
79+
.analyzerParams(analyzerParams)
80+
.multiAnalyzerParams(multiAnalyzerParams)
81+
.enableMatch(true)
82+
.build());
83+
collectionSchema.addField(AddFieldReq.builder()
84+
.fieldName("json_field")
85+
.dataType(io.milvus.v2.common.DataType.JSON)
86+
.build());
87+
collectionSchema.addField(AddFieldReq.builder()
88+
.fieldName("arr_int_field")
89+
.dataType(io.milvus.v2.common.DataType.Array)
90+
.maxCapacity(50)
91+
.elementType(io.milvus.v2.common.DataType.Int32)
92+
.build());
93+
collectionSchema.addField(AddFieldReq.builder()
94+
.fieldName("arr_float_field")
95+
.dataType(io.milvus.v2.common.DataType.Array)
96+
.maxCapacity(20)
97+
.elementType(io.milvus.v2.common.DataType.Float)
98+
.build());
99+
collectionSchema.addField(AddFieldReq.builder()
100+
.fieldName("arr_varchar_field")
101+
.dataType(io.milvus.v2.common.DataType.Array)
102+
.maxCapacity(10)
103+
.elementType(io.milvus.v2.common.DataType.VarChar)
104+
.build());
105+
collectionSchema.addField(AddFieldReq.builder()
106+
.fieldName("float_vector_field")
107+
.dataType(io.milvus.v2.common.DataType.FloatVector)
108+
.dimension(128)
109+
.build());
110+
collectionSchema.addField(AddFieldReq.builder()
111+
.fieldName("binary_vector_field")
112+
.dataType(io.milvus.v2.common.DataType.BinaryVector)
113+
.dimension(64)
114+
.build());
115+
collectionSchema.addField(AddFieldReq.builder()
116+
.fieldName("float16_vector_field")
117+
.dataType(io.milvus.v2.common.DataType.Float16Vector)
118+
.dimension(256)
119+
.build());
120+
collectionSchema.addField(AddFieldReq.builder()
121+
.fieldName("bfloat16_vector_field")
122+
.dataType(io.milvus.v2.common.DataType.BFloat16Vector)
123+
.dimension(512)
124+
.build());
125+
collectionSchema.addField(AddFieldReq.builder()
126+
.fieldName("sparse_vector_field")
127+
.dataType(io.milvus.v2.common.DataType.SparseFloatVector)
128+
.build());
129+
130+
collectionSchema.addFunction(CreateCollectionReq.Function.builder()
131+
.functionType(io.milvus.common.clientenum.FunctionType.BM25)
132+
.name("function_bm25")
133+
.inputFieldNames(Collections.singletonList("varchar_field"))
134+
.outputFieldNames(Collections.singletonList("sparse_vector_field"))
135+
.build());
136+
137+
return collectionSchema;
138+
}
139+
140+
@Test
141+
void testConvertFromGrpcFunction() {
142+
for (FunctionType type : FunctionType.values()) {
143+
if (type == FunctionType.UNRECOGNIZED) {
144+
continue;
145+
}
146+
FunctionSchema functionSchema = FunctionSchema.newBuilder()
147+
.setName("abc")
148+
.setDescription("xxx")
149+
.setType(type)
150+
.addInputFieldNames("text")
151+
.addOutputFieldNames("vec")
152+
.build();
153+
154+
CreateCollectionReq.Function func = SchemaUtils.convertFromGrpcFunction(functionSchema);
155+
Assertions.assertEquals(func.getName(), "abc");
156+
Assertions.assertEquals(func.getDescription(), "xxx");
157+
Assertions.assertEquals(func.getFunctionType(), io.milvus.common.clientenum.FunctionType.fromName(type.name()));
158+
Assertions.assertEquals(func.getInputFieldNames().size(), 1);
159+
Assertions.assertEquals(func.getInputFieldNames().get(0), "text");
160+
Assertions.assertEquals(func.getOutputFieldNames().size(), 1);
161+
Assertions.assertEquals(func.getOutputFieldNames().get(0), "vec");
162+
}
163+
}
164+
165+
@Test
166+
void testConvertToGrpcFunction() {
167+
for (io.milvus.common.clientenum.FunctionType type : io.milvus.common.clientenum.FunctionType.values()) {
168+
CreateCollectionReq.Function function = CreateCollectionReq.Function.builder()
169+
.name("abc")
170+
.description("xxx")
171+
.functionType(type)
172+
.inputFieldNames(Collections.singletonList("text"))
173+
.outputFieldNames(Collections.singletonList("vec"))
174+
.build();
175+
176+
FunctionSchema functionSchema = SchemaUtils.convertToGrpcFunction(function);
177+
Assertions.assertEquals(functionSchema.getName(), "abc");
178+
Assertions.assertEquals(functionSchema.getDescription(), "xxx");
179+
Assertions.assertEquals(functionSchema.getType(), FunctionType.forNumber(type.getCode()));
180+
Assertions.assertEquals(functionSchema.getInputFieldNamesCount(), 1);
181+
Assertions.assertEquals(functionSchema.getInputFieldNames(0), "text");
182+
Assertions.assertEquals(functionSchema.getOutputFieldNamesCount(), 1);
183+
Assertions.assertEquals(functionSchema.getOutputFieldNames(0), "vec");
184+
}
185+
}
186+
187+
@Test
188+
void testConvertToGrpcFieldSchema() {
189+
CreateCollectionReq.CollectionSchema collectionSchema = buildSchema();
190+
List<CreateCollectionReq.FieldSchema> fieldSchemaList = collectionSchema.getFieldSchemaList();
191+
for (CreateCollectionReq.FieldSchema fieldSchema : fieldSchemaList) {
192+
FieldSchema rpcSchema = SchemaUtils.convertToGrpcFieldSchema(fieldSchema);
193+
Assertions.assertEquals(rpcSchema.getName(), fieldSchema.getName());
194+
Assertions.assertEquals(rpcSchema.getDescription(), fieldSchema.getDescription());
195+
Assertions.assertEquals(rpcSchema.getDataType(), DataType.valueOf(fieldSchema.getDataType().name()));
196+
if (rpcSchema.getDataType() == DataType.Array) {
197+
Assertions.assertEquals(rpcSchema.getElementType(), DataType.valueOf(fieldSchema.getElementType().name()));
198+
}
199+
for (int i = 0; i < rpcSchema.getTypeParamsCount(); i++) {
200+
KeyValuePair pair = rpcSchema.getTypeParams(i);
201+
if (pair.getKey() == Constant.VECTOR_DIM) {
202+
Assertions.assertEquals(pair.getValue(), fieldSchema.getDimension().toString());
203+
} else if (pair.getKey() == Constant.VARCHAR_MAX_LENGTH) {
204+
Assertions.assertEquals(pair.getValue(), fieldSchema.getMaxLength().toString());
205+
} else if (pair.getKey() == Constant.ARRAY_MAX_CAPACITY) {
206+
Assertions.assertEquals(pair.getValue(), fieldSchema.getMaxCapacity().toString());
207+
}
208+
}
209+
Assertions.assertEquals(rpcSchema.getIsPrimaryKey(), fieldSchema.getIsPrimaryKey());
210+
Assertions.assertEquals(rpcSchema.getAutoID(), fieldSchema.getAutoID());
211+
Assertions.assertEquals(rpcSchema.getIsPartitionKey(), fieldSchema.getIsPartitionKey());
212+
Assertions.assertEquals(rpcSchema.getIsClusteringKey(), fieldSchema.getIsClusteringKey());
213+
Assertions.assertEquals(rpcSchema.getNullable(), fieldSchema.getIsNullable());
214+
215+
if (rpcSchema.getName().equals("int64_field")) {
216+
Assertions.assertEquals(rpcSchema.getDefaultValue().getLongData(), fieldSchema.getDefaultValue());
217+
} else {
218+
Assertions.assertEquals(rpcSchema.getDefaultValue(), io.milvus.grpc.ValueField.getDefaultInstance());
219+
}
220+
221+
if (rpcSchema.getName().equals("varchar_field")) {
222+
List<String> keys = new ArrayList<>();
223+
rpcSchema.getTypeParamsList().forEach((kv)-> keys.add(kv.getKey()));
224+
Assertions.assertTrue(keys.contains("enable_analyzer"));
225+
Assertions.assertTrue(keys.contains("enable_match"));
226+
Assertions.assertTrue(keys.contains("analyzer_params"));
227+
Assertions.assertTrue(keys.contains("multi_analyzer_params"));
228+
}
229+
}
230+
}
231+
232+
@Test
233+
void testConvertFromGrpcFieldSchema() {
234+
CreateCollectionReq.CollectionSchema collectionSchema = buildSchema();
235+
List<CreateCollectionReq.FieldSchema> fieldSchemaList = collectionSchema.getFieldSchemaList();
236+
for (CreateCollectionReq.FieldSchema fieldSchema : fieldSchemaList) {
237+
FieldSchema rpcSchema = SchemaUtils.convertToGrpcFieldSchema(fieldSchema);
238+
239+
CreateCollectionReq.FieldSchema newSchema = SchemaUtils.convertFromGrpcFieldSchema(rpcSchema);
240+
Assertions.assertEquals(newSchema.getName(), fieldSchema.getName());
241+
Assertions.assertEquals(newSchema.getDescription(), fieldSchema.getDescription());
242+
Assertions.assertEquals(newSchema.getDataType(), fieldSchema.getDataType());
243+
if (rpcSchema.getDataType() == DataType.Array) {
244+
Assertions.assertEquals(newSchema.getElementType(), fieldSchema.getElementType());
245+
}
246+
247+
Map<String, String> originParams = fieldSchema.getTypeParams();
248+
if (originParams != null) {
249+
Map<String, String> typeParams = newSchema.getTypeParams();
250+
originParams.forEach((k ,v)->{
251+
Assertions.assertTrue(typeParams.containsKey(k));
252+
Assertions.assertEquals(typeParams.get(k), originParams.get(k));
253+
});
254+
}
255+
256+
Assertions.assertEquals(newSchema.getIsPrimaryKey(), fieldSchema.getIsPrimaryKey());
257+
Assertions.assertEquals(newSchema.getAutoID(), fieldSchema.getAutoID());
258+
Assertions.assertEquals(newSchema.getIsPartitionKey(), fieldSchema.getIsPartitionKey());
259+
Assertions.assertEquals(newSchema.getIsClusteringKey(), fieldSchema.getIsClusteringKey());
260+
Assertions.assertEquals(newSchema.getIsNullable(), fieldSchema.getIsNullable());
261+
262+
if (rpcSchema.getName().equals("int64_field")) {
263+
Assertions.assertEquals(newSchema.getDefaultValue(), fieldSchema.getDefaultValue());
264+
} else {
265+
Assertions.assertNull(newSchema.getDefaultValue());
266+
}
267+
268+
if (rpcSchema.getName().equals("varchar_field")) {
269+
Assertions.assertTrue(newSchema.getEnableAnalyzer());
270+
Assertions.assertTrue(newSchema.getEnableMatch());
271+
Assertions.assertEquals(newSchema.getAnalyzerParams(), fieldSchema.getAnalyzerParams());
272+
Assertions.assertEquals(newSchema.getMultiAnalyzerParams(), fieldSchema.getMultiAnalyzerParams());
273+
}
274+
}
275+
}
276+
}

0 commit comments

Comments
 (0)