Skip to content

Commit 2e07edc

Browse files
authored
Fix a bug of FunctionType.TEXTEMBEDDING (#1535)
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent 4326fd0 commit 2e07edc

3 files changed

Lines changed: 321 additions & 10 deletions

File tree

sdk-core/src/main/java/io/milvus/common/clientenum/FunctionType.java

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,38 @@
2222
import lombok.Getter;
2323

2424
/**
 * Client-side function types, each paired with the numeric code and the
 * name used for the same type in milvus-proto.
 *
 * <p>The proto name can differ from the Java constant name (for example
 * {@code "TextEmbedding"} vs. {@code TEXTEMBEDDING}), so lookups by proto name
 * must go through {@link #fromName(String)} rather than {@code valueOf}.
 */
public enum FunctionType {
    UNKNOWN("Unknown", 0), // in milvus-proto, the name is "Unknown"
    BM25(1),
    TEXTEMBEDDING("TextEmbedding", 2), // in milvus-proto, the name is "TextEmbedding"
    RERANK(3),
    ;

    /** Name of this type as spelled in milvus-proto; defaults to the constant name. */
    private final String name;

    /** Numeric code of this type. */
    private final int code;

    /**
     * Creates a type whose proto name is the same as the Java constant name.
     *
     * @param code numeric code of the type
     */
    FunctionType(int code) {
        this.name = this.name();
        this.code = code;
    }

    /**
     * Creates a type whose proto name differs from the Java constant name.
     *
     * @param name name of the type as spelled in milvus-proto
     * @param code numeric code of the type
     */
    FunctionType(String name, int code) {
        this.name = name;
        this.code = code;
    }

    /**
     * Returns the numeric code of this type.
     *
     * @return the numeric code
     */
    public int getCode() {
        return code;
    }

    /**
     * Looks up a type by name.
     *
     * <p>The stored proto name is checked first. The Java constant name is then
     * accepted case-insensitively, so that both {@code "TEXTEMBEDDING"} and
     * mixed-case spellings of constants without an explicit proto name resolve.
     *
     * @param name the name to look up; may be {@code null}
     * @return the matching type, or {@code null} if nothing matches
     */
    public static FunctionType fromName(String name) {
        for (FunctionType type : FunctionType.values()) {
            // Compare against the 'name' field, not only Enum.name(): the latter is
            // the constant identifier and never equals a proto spelling such as
            // "TextEmbedding".
            if (type.name.equals(name) || type.name().equalsIgnoreCase(name)) {
                return type;
            }
        }
        return null;
    }
}

sdk-core/src/main/java/io/milvus/v2/utils/SchemaUtils.java

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ public static CreateCollectionReq.CollectionSchema convertFromGrpcCollectionSche
166166
return collectionSchema;
167167
}
168168

169-
private static CreateCollectionReq.FieldSchema convertFromGrpcFieldSchema(FieldSchema fieldSchema) {
169+
public static CreateCollectionReq.FieldSchema convertFromGrpcFieldSchema(FieldSchema fieldSchema) {
170170
CreateCollectionReq.FieldSchema schema = CreateCollectionReq.FieldSchema.builder()
171171
.name(fieldSchema.getName())
172172
.description(fieldSchema.getDescription())
@@ -196,6 +196,9 @@ private static CreateCollectionReq.FieldSchema convertFromGrpcFieldSchema(FieldS
196196
} else if(keyValuePair.getKey().equals("analyzer_params")){
197197
Map<String, Object> params = JsonUtils.fromJson(keyValuePair.getValue(), new TypeToken<Map<String, Object>>() {}.getType());
198198
schema.setAnalyzerParams(params);
199+
} else if(keyValuePair.getKey().equals("multi_analyzer_params")){
200+
Map<String, Object> params = JsonUtils.fromJson(keyValuePair.getValue(), new TypeToken<Map<String, Object>>() {}.getType());
201+
schema.setMultiAnalyzerParams(params);
199202
}
200203
} catch (Exception e) {
201204
/**
@@ -212,14 +215,15 @@ private static CreateCollectionReq.FieldSchema convertFromGrpcFieldSchema(FieldS
212215
}
213216

214217
public static CreateCollectionReq.Function convertFromGrpcFunction(FunctionSchema functionSchema) {
215-
CreateCollectionReq.Function function = CreateCollectionReq.Function.builder()
218+
CreateCollectionReq.Function.FunctionBuilder builder = CreateCollectionReq.Function.builder()
216219
.name(functionSchema.getName())
217220
.description(functionSchema.getDescription())
218-
.functionType(io.milvus.common.clientenum.FunctionType.valueOf(functionSchema.getType().name()))
221+
.functionType(io.milvus.common.clientenum.FunctionType.fromName(functionSchema.getType().name()))
219222
.inputFieldNames(functionSchema.getInputFieldNamesList().stream().collect(Collectors.toList()))
220-
.outputFieldNames(functionSchema.getOutputFieldNamesList().stream().collect(Collectors.toList()))
221-
.build();
222-
return function;
223+
.outputFieldNames(functionSchema.getOutputFieldNamesList().stream().collect(Collectors.toList()));
224+
List<KeyValuePair> pairs = functionSchema.getParamsList();
225+
pairs.forEach((kv)->builder.param(kv.getKey(), kv.getValue()));
226+
return builder.build();
223227
}
224228

225229
public static CreateCollectionReq.FieldSchema convertFieldReqToFieldSchema(AddFieldReq addFieldReq) {
Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package io.milvus.v2.utils;
21+
22+
import io.milvus.grpc.*;
23+
import io.milvus.param.Constant;
24+
import io.milvus.v2.service.collection.request.AddFieldReq;
25+
import io.milvus.v2.service.collection.request.CreateCollectionReq;
26+
import org.junit.jupiter.api.Assertions;
27+
import org.junit.jupiter.api.Test;
28+
29+
import java.util.*;
30+
31+
public class SchemaUtilsTest {
32+
private CreateCollectionReq.CollectionSchema buildSchema() {
33+
CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
34+
.enableDynamicField(true)
35+
.build();
36+
collectionSchema.addField(AddFieldReq.builder()
37+
.fieldName("id")
38+
.dataType(io.milvus.v2.common.DataType.Int64)
39+
.isPrimaryKey(Boolean.TRUE)
40+
.build());
41+
collectionSchema.addField(AddFieldReq.builder()
42+
.fieldName("bool_field")
43+
.dataType(io.milvus.v2.common.DataType.Bool)
44+
.build());
45+
collectionSchema.addField(AddFieldReq.builder()
46+
.fieldName("int8_field")
47+
.dataType(io.milvus.v2.common.DataType.Int8)
48+
.build());
49+
collectionSchema.addField(AddFieldReq.builder()
50+
.fieldName("int16_field")
51+
.dataType(io.milvus.v2.common.DataType.Int16)
52+
.build());
53+
collectionSchema.addField(AddFieldReq.builder()
54+
.fieldName("int32_field")
55+
.dataType(io.milvus.v2.common.DataType.Int32)
56+
.build());
57+
collectionSchema.addField(AddFieldReq.builder()
58+
.fieldName("int64_field")
59+
.dataType(io.milvus.v2.common.DataType.Int64)
60+
.defaultValue(888L)
61+
.build());
62+
collectionSchema.addField(AddFieldReq.builder()
63+
.fieldName("float_field")
64+
.dataType(io.milvus.v2.common.DataType.Float)
65+
.build());
66+
collectionSchema.addField(AddFieldReq.builder()
67+
.fieldName("double_field")
68+
.dataType(io.milvus.v2.common.DataType.Double)
69+
.build());
70+
Map<String, Object> analyzerParams = new HashMap<>();
71+
analyzerParams.put("type", "english");
72+
Map<String, Object> multiAnalyzerParams = new HashMap<>();
73+
multiAnalyzerParams.put("by_field", "language");
74+
collectionSchema.addField(AddFieldReq.builder()
75+
.fieldName("varchar_field")
76+
.dataType(io.milvus.v2.common.DataType.VarChar)
77+
.maxLength(1000)
78+
.enableAnalyzer(true)
79+
.analyzerParams(analyzerParams)
80+
.multiAnalyzerParams(multiAnalyzerParams)
81+
.enableMatch(true)
82+
.build());
83+
collectionSchema.addField(AddFieldReq.builder()
84+
.fieldName("json_field")
85+
.dataType(io.milvus.v2.common.DataType.JSON)
86+
.build());
87+
collectionSchema.addField(AddFieldReq.builder()
88+
.fieldName("arr_int_field")
89+
.dataType(io.milvus.v2.common.DataType.Array)
90+
.maxCapacity(50)
91+
.elementType(io.milvus.v2.common.DataType.Int32)
92+
.build());
93+
collectionSchema.addField(AddFieldReq.builder()
94+
.fieldName("arr_float_field")
95+
.dataType(io.milvus.v2.common.DataType.Array)
96+
.maxCapacity(20)
97+
.elementType(io.milvus.v2.common.DataType.Float)
98+
.build());
99+
collectionSchema.addField(AddFieldReq.builder()
100+
.fieldName("arr_varchar_field")
101+
.dataType(io.milvus.v2.common.DataType.Array)
102+
.maxCapacity(10)
103+
.elementType(io.milvus.v2.common.DataType.VarChar)
104+
.build());
105+
collectionSchema.addField(AddFieldReq.builder()
106+
.fieldName("float_vector_field")
107+
.dataType(io.milvus.v2.common.DataType.FloatVector)
108+
.dimension(128)
109+
.build());
110+
collectionSchema.addField(AddFieldReq.builder()
111+
.fieldName("binary_vector_field")
112+
.dataType(io.milvus.v2.common.DataType.BinaryVector)
113+
.dimension(64)
114+
.build());
115+
collectionSchema.addField(AddFieldReq.builder()
116+
.fieldName("float16_vector_field")
117+
.dataType(io.milvus.v2.common.DataType.Float16Vector)
118+
.dimension(256)
119+
.build());
120+
collectionSchema.addField(AddFieldReq.builder()
121+
.fieldName("bfloat16_vector_field")
122+
.dataType(io.milvus.v2.common.DataType.BFloat16Vector)
123+
.dimension(512)
124+
.build());
125+
collectionSchema.addField(AddFieldReq.builder()
126+
.fieldName("sparse_vector_field")
127+
.dataType(io.milvus.v2.common.DataType.SparseFloatVector)
128+
.build());
129+
130+
collectionSchema.addFunction(CreateCollectionReq.Function.builder()
131+
.functionType(io.milvus.common.clientenum.FunctionType.BM25)
132+
.name("function_bm25")
133+
.inputFieldNames(Collections.singletonList("varchar_field"))
134+
.outputFieldNames(Collections.singletonList("sparse_vector_field"))
135+
.build());
136+
137+
return collectionSchema;
138+
}
139+
140+
@Test
141+
void testConvertFromGrpcFunction() {
142+
for (FunctionType type : FunctionType.values()) {
143+
if (type == FunctionType.UNRECOGNIZED) {
144+
continue;
145+
}
146+
FunctionSchema functionSchema = FunctionSchema.newBuilder()
147+
.setName("abc")
148+
.setDescription("xxx")
149+
.setType(type)
150+
.addInputFieldNames("text")
151+
.addOutputFieldNames("vec")
152+
.addParams(KeyValuePair.newBuilder().setKey("provider").setValue("openai").build())
153+
.build();
154+
155+
CreateCollectionReq.Function func = SchemaUtils.convertFromGrpcFunction(functionSchema);
156+
Assertions.assertEquals(func.getName(), "abc");
157+
Assertions.assertEquals(func.getDescription(), "xxx");
158+
Assertions.assertEquals(func.getFunctionType(), io.milvus.common.clientenum.FunctionType.fromName(type.name()));
159+
Assertions.assertEquals(func.getInputFieldNames().size(), 1);
160+
Assertions.assertEquals(func.getInputFieldNames().get(0), "text");
161+
Assertions.assertEquals(func.getOutputFieldNames().size(), 1);
162+
Assertions.assertEquals(func.getOutputFieldNames().get(0), "vec");
163+
Map<String, String> params = func.getParams();
164+
Assertions.assertTrue(params.containsKey("provider"));
165+
Assertions.assertEquals(params.get("provider"), "openai");
166+
}
167+
}
168+
169+
@Test
170+
void testConvertToGrpcFunction() {
171+
for (io.milvus.common.clientenum.FunctionType type : io.milvus.common.clientenum.FunctionType.values()) {
172+
CreateCollectionReq.Function function = CreateCollectionReq.Function.builder()
173+
.name("abc")
174+
.description("xxx")
175+
.functionType(type)
176+
.inputFieldNames(Collections.singletonList("text"))
177+
.outputFieldNames(Collections.singletonList("vec"))
178+
.param("provider", "openai")
179+
.build();
180+
181+
FunctionSchema functionSchema = SchemaUtils.convertToGrpcFunction(function);
182+
Assertions.assertEquals(functionSchema.getName(), "abc");
183+
Assertions.assertEquals(functionSchema.getDescription(), "xxx");
184+
Assertions.assertEquals(functionSchema.getType(), FunctionType.forNumber(type.getCode()));
185+
Assertions.assertEquals(functionSchema.getInputFieldNamesCount(), 1);
186+
Assertions.assertEquals(functionSchema.getInputFieldNames(0), "text");
187+
Assertions.assertEquals(functionSchema.getOutputFieldNamesCount(), 1);
188+
Assertions.assertEquals(functionSchema.getOutputFieldNames(0), "vec");
189+
List<KeyValuePair> pairs = functionSchema.getParamsList();
190+
Assertions.assertEquals(pairs.size(), 1);
191+
Assertions.assertEquals(pairs.get(0).getKey(), "provider");
192+
Assertions.assertEquals(pairs.get(0).getValue(), "openai");
193+
}
194+
}
195+
196+
@Test
197+
void testConvertToGrpcFieldSchema() {
198+
CreateCollectionReq.CollectionSchema collectionSchema = buildSchema();
199+
List<CreateCollectionReq.FieldSchema> fieldSchemaList = collectionSchema.getFieldSchemaList();
200+
for (CreateCollectionReq.FieldSchema fieldSchema : fieldSchemaList) {
201+
FieldSchema rpcSchema = SchemaUtils.convertToGrpcFieldSchema(fieldSchema);
202+
Assertions.assertEquals(rpcSchema.getName(), fieldSchema.getName());
203+
Assertions.assertEquals(rpcSchema.getDescription(), fieldSchema.getDescription());
204+
Assertions.assertEquals(rpcSchema.getDataType(), DataType.valueOf(fieldSchema.getDataType().name()));
205+
if (rpcSchema.getDataType() == DataType.Array) {
206+
Assertions.assertEquals(rpcSchema.getElementType(), DataType.valueOf(fieldSchema.getElementType().name()));
207+
}
208+
for (int i = 0; i < rpcSchema.getTypeParamsCount(); i++) {
209+
KeyValuePair pair = rpcSchema.getTypeParams(i);
210+
if (pair.getKey() == Constant.VECTOR_DIM) {
211+
Assertions.assertEquals(pair.getValue(), fieldSchema.getDimension().toString());
212+
} else if (pair.getKey() == Constant.VARCHAR_MAX_LENGTH) {
213+
Assertions.assertEquals(pair.getValue(), fieldSchema.getMaxLength().toString());
214+
} else if (pair.getKey() == Constant.ARRAY_MAX_CAPACITY) {
215+
Assertions.assertEquals(pair.getValue(), fieldSchema.getMaxCapacity().toString());
216+
}
217+
}
218+
Assertions.assertEquals(rpcSchema.getIsPrimaryKey(), fieldSchema.getIsPrimaryKey());
219+
Assertions.assertEquals(rpcSchema.getAutoID(), fieldSchema.getAutoID());
220+
Assertions.assertEquals(rpcSchema.getIsPartitionKey(), fieldSchema.getIsPartitionKey());
221+
Assertions.assertEquals(rpcSchema.getIsClusteringKey(), fieldSchema.getIsClusteringKey());
222+
Assertions.assertEquals(rpcSchema.getNullable(), fieldSchema.getIsNullable());
223+
224+
if (rpcSchema.getName().equals("int64_field")) {
225+
Assertions.assertEquals(rpcSchema.getDefaultValue().getLongData(), fieldSchema.getDefaultValue());
226+
} else {
227+
Assertions.assertEquals(rpcSchema.getDefaultValue(), io.milvus.grpc.ValueField.getDefaultInstance());
228+
}
229+
230+
if (rpcSchema.getName().equals("varchar_field")) {
231+
List<String> keys = new ArrayList<>();
232+
rpcSchema.getTypeParamsList().forEach((kv)-> keys.add(kv.getKey()));
233+
Assertions.assertTrue(keys.contains("enable_analyzer"));
234+
Assertions.assertTrue(keys.contains("enable_match"));
235+
Assertions.assertTrue(keys.contains("analyzer_params"));
236+
Assertions.assertTrue(keys.contains("multi_analyzer_params"));
237+
}
238+
}
239+
}
240+
241+
@Test
242+
void testConvertFromGrpcFieldSchema() {
243+
CreateCollectionReq.CollectionSchema collectionSchema = buildSchema();
244+
List<CreateCollectionReq.FieldSchema> fieldSchemaList = collectionSchema.getFieldSchemaList();
245+
for (CreateCollectionReq.FieldSchema fieldSchema : fieldSchemaList) {
246+
FieldSchema rpcSchema = SchemaUtils.convertToGrpcFieldSchema(fieldSchema);
247+
248+
CreateCollectionReq.FieldSchema newSchema = SchemaUtils.convertFromGrpcFieldSchema(rpcSchema);
249+
Assertions.assertEquals(newSchema.getName(), fieldSchema.getName());
250+
Assertions.assertEquals(newSchema.getDescription(), fieldSchema.getDescription());
251+
Assertions.assertEquals(newSchema.getDataType(), fieldSchema.getDataType());
252+
if (rpcSchema.getDataType() == DataType.Array) {
253+
Assertions.assertEquals(newSchema.getElementType(), fieldSchema.getElementType());
254+
}
255+
256+
Map<String, String> originParams = fieldSchema.getTypeParams();
257+
if (originParams != null) {
258+
Map<String, String> typeParams = newSchema.getTypeParams();
259+
originParams.forEach((k ,v)->{
260+
Assertions.assertTrue(typeParams.containsKey(k));
261+
Assertions.assertEquals(typeParams.get(k), originParams.get(k));
262+
});
263+
}
264+
265+
Assertions.assertEquals(newSchema.getIsPrimaryKey(), fieldSchema.getIsPrimaryKey());
266+
Assertions.assertEquals(newSchema.getAutoID(), fieldSchema.getAutoID());
267+
Assertions.assertEquals(newSchema.getIsPartitionKey(), fieldSchema.getIsPartitionKey());
268+
Assertions.assertEquals(newSchema.getIsClusteringKey(), fieldSchema.getIsClusteringKey());
269+
Assertions.assertEquals(newSchema.getIsNullable(), fieldSchema.getIsNullable());
270+
271+
if (rpcSchema.getName().equals("int64_field")) {
272+
Assertions.assertEquals(newSchema.getDefaultValue(), fieldSchema.getDefaultValue());
273+
} else {
274+
Assertions.assertNull(newSchema.getDefaultValue());
275+
}
276+
277+
if (rpcSchema.getName().equals("varchar_field")) {
278+
Assertions.assertTrue(newSchema.getEnableAnalyzer());
279+
Assertions.assertTrue(newSchema.getEnableMatch());
280+
Assertions.assertEquals(newSchema.getAnalyzerParams(), fieldSchema.getAnalyzerParams());
281+
Assertions.assertEquals(newSchema.getMultiAnalyzerParams(), fieldSchema.getMultiAnalyzerParams());
282+
}
283+
}
284+
}
285+
}

0 commit comments

Comments
 (0)