1+ /*
2+ * Licensed to the Apache Software Foundation (ASF) under one
3+ * or more contributor license agreements. See the NOTICE file
4+ * distributed with this work for additional information
5+ * regarding copyright ownership. The ASF licenses this file
6+ * to you under the Apache License, Version 2.0 (the
7+ * "License"); you may not use this file except in compliance
8+ * with the License. You may obtain a copy of the License at
9+ *
10+ * http://www.apache.org/licenses/LICENSE-2.0
11+ *
12+ * Unless required by applicable law or agreed to in writing,
13+ * software distributed under the License is distributed on an
14+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+ * KIND, either express or implied. See the License for the
16+ * specific language governing permissions and limitations
17+ * under the License.
18+ */
19+
20+ package io .milvus .v2 .utils ;
21+
22+ import io .milvus .grpc .*;
23+ import io .milvus .param .Constant ;
24+ import io .milvus .v2 .service .collection .request .AddFieldReq ;
25+ import io .milvus .v2 .service .collection .request .CreateCollectionReq ;
26+ import org .junit .jupiter .api .Assertions ;
27+ import org .junit .jupiter .api .Test ;
28+
29+ import java .util .*;
30+
31+ public class SchemaUtilsTest {
32+ private CreateCollectionReq .CollectionSchema buildSchema () {
33+ CreateCollectionReq .CollectionSchema collectionSchema = CreateCollectionReq .CollectionSchema .builder ()
34+ .enableDynamicField (true )
35+ .build ();
36+ collectionSchema .addField (AddFieldReq .builder ()
37+ .fieldName ("id" )
38+ .dataType (io .milvus .v2 .common .DataType .Int64 )
39+ .isPrimaryKey (Boolean .TRUE )
40+ .build ());
41+ collectionSchema .addField (AddFieldReq .builder ()
42+ .fieldName ("bool_field" )
43+ .dataType (io .milvus .v2 .common .DataType .Bool )
44+ .build ());
45+ collectionSchema .addField (AddFieldReq .builder ()
46+ .fieldName ("int8_field" )
47+ .dataType (io .milvus .v2 .common .DataType .Int8 )
48+ .build ());
49+ collectionSchema .addField (AddFieldReq .builder ()
50+ .fieldName ("int16_field" )
51+ .dataType (io .milvus .v2 .common .DataType .Int16 )
52+ .build ());
53+ collectionSchema .addField (AddFieldReq .builder ()
54+ .fieldName ("int32_field" )
55+ .dataType (io .milvus .v2 .common .DataType .Int32 )
56+ .build ());
57+ collectionSchema .addField (AddFieldReq .builder ()
58+ .fieldName ("int64_field" )
59+ .dataType (io .milvus .v2 .common .DataType .Int64 )
60+ .defaultValue (888L )
61+ .build ());
62+ collectionSchema .addField (AddFieldReq .builder ()
63+ .fieldName ("float_field" )
64+ .dataType (io .milvus .v2 .common .DataType .Float )
65+ .build ());
66+ collectionSchema .addField (AddFieldReq .builder ()
67+ .fieldName ("double_field" )
68+ .dataType (io .milvus .v2 .common .DataType .Double )
69+ .build ());
70+ Map <String , Object > analyzerParams = new HashMap <>();
71+ analyzerParams .put ("type" , "english" );
72+ Map <String , Object > multiAnalyzerParams = new HashMap <>();
73+ multiAnalyzerParams .put ("by_field" , "language" );
74+ collectionSchema .addField (AddFieldReq .builder ()
75+ .fieldName ("varchar_field" )
76+ .dataType (io .milvus .v2 .common .DataType .VarChar )
77+ .maxLength (1000 )
78+ .enableAnalyzer (true )
79+ .analyzerParams (analyzerParams )
80+ .multiAnalyzerParams (multiAnalyzerParams )
81+ .enableMatch (true )
82+ .build ());
83+ collectionSchema .addField (AddFieldReq .builder ()
84+ .fieldName ("json_field" )
85+ .dataType (io .milvus .v2 .common .DataType .JSON )
86+ .build ());
87+ collectionSchema .addField (AddFieldReq .builder ()
88+ .fieldName ("arr_int_field" )
89+ .dataType (io .milvus .v2 .common .DataType .Array )
90+ .maxCapacity (50 )
91+ .elementType (io .milvus .v2 .common .DataType .Int32 )
92+ .build ());
93+ collectionSchema .addField (AddFieldReq .builder ()
94+ .fieldName ("arr_float_field" )
95+ .dataType (io .milvus .v2 .common .DataType .Array )
96+ .maxCapacity (20 )
97+ .elementType (io .milvus .v2 .common .DataType .Float )
98+ .build ());
99+ collectionSchema .addField (AddFieldReq .builder ()
100+ .fieldName ("arr_varchar_field" )
101+ .dataType (io .milvus .v2 .common .DataType .Array )
102+ .maxCapacity (10 )
103+ .elementType (io .milvus .v2 .common .DataType .VarChar )
104+ .build ());
105+ collectionSchema .addField (AddFieldReq .builder ()
106+ .fieldName ("float_vector_field" )
107+ .dataType (io .milvus .v2 .common .DataType .FloatVector )
108+ .dimension (128 )
109+ .build ());
110+ collectionSchema .addField (AddFieldReq .builder ()
111+ .fieldName ("binary_vector_field" )
112+ .dataType (io .milvus .v2 .common .DataType .BinaryVector )
113+ .dimension (64 )
114+ .build ());
115+ collectionSchema .addField (AddFieldReq .builder ()
116+ .fieldName ("float16_vector_field" )
117+ .dataType (io .milvus .v2 .common .DataType .Float16Vector )
118+ .dimension (256 )
119+ .build ());
120+ collectionSchema .addField (AddFieldReq .builder ()
121+ .fieldName ("bfloat16_vector_field" )
122+ .dataType (io .milvus .v2 .common .DataType .BFloat16Vector )
123+ .dimension (512 )
124+ .build ());
125+ collectionSchema .addField (AddFieldReq .builder ()
126+ .fieldName ("sparse_vector_field" )
127+ .dataType (io .milvus .v2 .common .DataType .SparseFloatVector )
128+ .build ());
129+
130+ collectionSchema .addFunction (CreateCollectionReq .Function .builder ()
131+ .functionType (io .milvus .common .clientenum .FunctionType .BM25 )
132+ .name ("function_bm25" )
133+ .inputFieldNames (Collections .singletonList ("varchar_field" ))
134+ .outputFieldNames (Collections .singletonList ("sparse_vector_field" ))
135+ .build ());
136+
137+ return collectionSchema ;
138+ }
139+
140+ @ Test
141+ void testConvertFromGrpcFunction () {
142+ for (FunctionType type : FunctionType .values ()) {
143+ if (type == FunctionType .UNRECOGNIZED ) {
144+ continue ;
145+ }
146+ FunctionSchema functionSchema = FunctionSchema .newBuilder ()
147+ .setName ("abc" )
148+ .setDescription ("xxx" )
149+ .setType (type )
150+ .addInputFieldNames ("text" )
151+ .addOutputFieldNames ("vec" )
152+ .build ();
153+
154+ CreateCollectionReq .Function func = SchemaUtils .convertFromGrpcFunction (functionSchema );
155+ Assertions .assertEquals (func .getName (), "abc" );
156+ Assertions .assertEquals (func .getDescription (), "xxx" );
157+ Assertions .assertEquals (func .getFunctionType (), io .milvus .common .clientenum .FunctionType .fromName (type .name ()));
158+ Assertions .assertEquals (func .getInputFieldNames ().size (), 1 );
159+ Assertions .assertEquals (func .getInputFieldNames ().get (0 ), "text" );
160+ Assertions .assertEquals (func .getOutputFieldNames ().size (), 1 );
161+ Assertions .assertEquals (func .getOutputFieldNames ().get (0 ), "vec" );
162+ }
163+ }
164+
165+ @ Test
166+ void testConvertToGrpcFunction () {
167+ for (io .milvus .common .clientenum .FunctionType type : io .milvus .common .clientenum .FunctionType .values ()) {
168+ CreateCollectionReq .Function function = CreateCollectionReq .Function .builder ()
169+ .name ("abc" )
170+ .description ("xxx" )
171+ .functionType (type )
172+ .inputFieldNames (Collections .singletonList ("text" ))
173+ .outputFieldNames (Collections .singletonList ("vec" ))
174+ .build ();
175+
176+ FunctionSchema functionSchema = SchemaUtils .convertToGrpcFunction (function );
177+ Assertions .assertEquals (functionSchema .getName (), "abc" );
178+ Assertions .assertEquals (functionSchema .getDescription (), "xxx" );
179+ Assertions .assertEquals (functionSchema .getType (), FunctionType .forNumber (type .getCode ()));
180+ Assertions .assertEquals (functionSchema .getInputFieldNamesCount (), 1 );
181+ Assertions .assertEquals (functionSchema .getInputFieldNames (0 ), "text" );
182+ Assertions .assertEquals (functionSchema .getOutputFieldNamesCount (), 1 );
183+ Assertions .assertEquals (functionSchema .getOutputFieldNames (0 ), "vec" );
184+ }
185+ }
186+
187+ @ Test
188+ void testConvertToGrpcFieldSchema () {
189+ CreateCollectionReq .CollectionSchema collectionSchema = buildSchema ();
190+ List <CreateCollectionReq .FieldSchema > fieldSchemaList = collectionSchema .getFieldSchemaList ();
191+ for (CreateCollectionReq .FieldSchema fieldSchema : fieldSchemaList ) {
192+ FieldSchema rpcSchema = SchemaUtils .convertToGrpcFieldSchema (fieldSchema );
193+ Assertions .assertEquals (rpcSchema .getName (), fieldSchema .getName ());
194+ Assertions .assertEquals (rpcSchema .getDescription (), fieldSchema .getDescription ());
195+ Assertions .assertEquals (rpcSchema .getDataType (), DataType .valueOf (fieldSchema .getDataType ().name ()));
196+ if (rpcSchema .getDataType () == DataType .Array ) {
197+ Assertions .assertEquals (rpcSchema .getElementType (), DataType .valueOf (fieldSchema .getElementType ().name ()));
198+ }
199+ for (int i = 0 ; i < rpcSchema .getTypeParamsCount (); i ++) {
200+ KeyValuePair pair = rpcSchema .getTypeParams (i );
201+ if (pair .getKey () == Constant .VECTOR_DIM ) {
202+ Assertions .assertEquals (pair .getValue (), fieldSchema .getDimension ().toString ());
203+ } else if (pair .getKey () == Constant .VARCHAR_MAX_LENGTH ) {
204+ Assertions .assertEquals (pair .getValue (), fieldSchema .getMaxLength ().toString ());
205+ } else if (pair .getKey () == Constant .ARRAY_MAX_CAPACITY ) {
206+ Assertions .assertEquals (pair .getValue (), fieldSchema .getMaxCapacity ().toString ());
207+ }
208+ }
209+ Assertions .assertEquals (rpcSchema .getIsPrimaryKey (), fieldSchema .getIsPrimaryKey ());
210+ Assertions .assertEquals (rpcSchema .getAutoID (), fieldSchema .getAutoID ());
211+ Assertions .assertEquals (rpcSchema .getIsPartitionKey (), fieldSchema .getIsPartitionKey ());
212+ Assertions .assertEquals (rpcSchema .getIsClusteringKey (), fieldSchema .getIsClusteringKey ());
213+ Assertions .assertEquals (rpcSchema .getNullable (), fieldSchema .getIsNullable ());
214+
215+ if (rpcSchema .getName ().equals ("int64_field" )) {
216+ Assertions .assertEquals (rpcSchema .getDefaultValue ().getLongData (), fieldSchema .getDefaultValue ());
217+ } else {
218+ Assertions .assertEquals (rpcSchema .getDefaultValue (), io .milvus .grpc .ValueField .getDefaultInstance ());
219+ }
220+
221+ if (rpcSchema .getName ().equals ("varchar_field" )) {
222+ List <String > keys = new ArrayList <>();
223+ rpcSchema .getTypeParamsList ().forEach ((kv )-> keys .add (kv .getKey ()));
224+ Assertions .assertTrue (keys .contains ("enable_analyzer" ));
225+ Assertions .assertTrue (keys .contains ("enable_match" ));
226+ Assertions .assertTrue (keys .contains ("analyzer_params" ));
227+ Assertions .assertTrue (keys .contains ("multi_analyzer_params" ));
228+ }
229+ }
230+ }
231+
232+ @ Test
233+ void testConvertFromGrpcFieldSchema () {
234+ CreateCollectionReq .CollectionSchema collectionSchema = buildSchema ();
235+ List <CreateCollectionReq .FieldSchema > fieldSchemaList = collectionSchema .getFieldSchemaList ();
236+ for (CreateCollectionReq .FieldSchema fieldSchema : fieldSchemaList ) {
237+ FieldSchema rpcSchema = SchemaUtils .convertToGrpcFieldSchema (fieldSchema );
238+
239+ CreateCollectionReq .FieldSchema newSchema = SchemaUtils .convertFromGrpcFieldSchema (rpcSchema );
240+ Assertions .assertEquals (newSchema .getName (), fieldSchema .getName ());
241+ Assertions .assertEquals (newSchema .getDescription (), fieldSchema .getDescription ());
242+ Assertions .assertEquals (newSchema .getDataType (), fieldSchema .getDataType ());
243+ if (rpcSchema .getDataType () == DataType .Array ) {
244+ Assertions .assertEquals (newSchema .getElementType (), fieldSchema .getElementType ());
245+ }
246+
247+ Map <String , String > originParams = fieldSchema .getTypeParams ();
248+ if (originParams != null ) {
249+ Map <String , String > typeParams = newSchema .getTypeParams ();
250+ originParams .forEach ((k ,v )->{
251+ Assertions .assertTrue (typeParams .containsKey (k ));
252+ Assertions .assertEquals (typeParams .get (k ), originParams .get (k ));
253+ });
254+ }
255+
256+ Assertions .assertEquals (newSchema .getIsPrimaryKey (), fieldSchema .getIsPrimaryKey ());
257+ Assertions .assertEquals (newSchema .getAutoID (), fieldSchema .getAutoID ());
258+ Assertions .assertEquals (newSchema .getIsPartitionKey (), fieldSchema .getIsPartitionKey ());
259+ Assertions .assertEquals (newSchema .getIsClusteringKey (), fieldSchema .getIsClusteringKey ());
260+ Assertions .assertEquals (newSchema .getIsNullable (), fieldSchema .getIsNullable ());
261+
262+ if (rpcSchema .getName ().equals ("int64_field" )) {
263+ Assertions .assertEquals (newSchema .getDefaultValue (), fieldSchema .getDefaultValue ());
264+ } else {
265+ Assertions .assertNull (newSchema .getDefaultValue ());
266+ }
267+
268+ if (rpcSchema .getName ().equals ("varchar_field" )) {
269+ Assertions .assertTrue (newSchema .getEnableAnalyzer ());
270+ Assertions .assertTrue (newSchema .getEnableMatch ());
271+ Assertions .assertEquals (newSchema .getAnalyzerParams (), fieldSchema .getAnalyzerParams ());
272+ Assertions .assertEquals (newSchema .getMultiAnalyzerParams (), fieldSchema .getMultiAnalyzerParams ());
273+ }
274+ }
275+ }
276+ }
0 commit comments