|
70 | 70 | import io.milvus.v2.service.utility.response.GetQuerySegmentInfoResp; |
71 | 71 | import io.milvus.v2.service.vector.request.*; |
72 | 72 | import io.milvus.v2.service.vector.request.data.*; |
73 | | -import io.milvus.v2.service.vector.request.ranker.BoostRanker; |
74 | 73 | import io.milvus.v2.service.vector.request.ranker.RRFRanker; |
75 | 74 | import io.milvus.v2.service.vector.request.ranker.WeightedRanker; |
76 | 75 | import io.milvus.v2.service.vector.response.*; |
@@ -313,18 +312,6 @@ private long getRowCount(String dbName, String collectionName) { |
313 | 312 | return (long) queryResults.get(0).getEntity().get("count(*)"); |
314 | 313 | } |
315 | 314 |
|
316 | | - @Test |
317 | | - void testAAAA() { |
318 | | - CreateCollectionReq.Function func = CreateCollectionReq.Function.builder() |
319 | | - .name("XXX") |
320 | | - .build(); |
321 | | - BoostRanker ranker = BoostRanker.builder() |
322 | | - .name("AAA") |
323 | | - .weight(2.5f) |
324 | | - .build(); |
325 | | - } |
326 | | - |
327 | | - |
328 | 315 | @Test |
329 | 316 | void testFloatVectors() { |
330 | 317 | CheckHealthResp healthy = client.checkHealth(); |
@@ -1051,6 +1038,104 @@ void testInt8Vectors() { |
1051 | 1038 | } |
1052 | 1039 | } |
1053 | 1040 |
|
| 1041 | + @Test |
| 1042 | + void testArray() { |
| 1043 | + String randomCollectionName = generator.generate(10); |
| 1044 | + String pkField = "key"; |
| 1045 | + String vectorField = "vector"; |
| 1046 | + String arrayField = "array"; |
| 1047 | + int capacity = 10; |
| 1048 | + int varcharLength = 88; |
| 1049 | + CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder() |
| 1050 | + .build(); |
| 1051 | + collectionSchema.addField(AddFieldReq.builder() |
| 1052 | + .fieldName(pkField) |
| 1053 | + .dataType(DataType.Int64) |
| 1054 | + .isPrimaryKey(true) |
| 1055 | + .autoID(true) |
| 1056 | + .build()); |
| 1057 | + collectionSchema.addField(AddFieldReq.builder() |
| 1058 | + .fieldName(vectorField) |
| 1059 | + .dataType(DataType.FloatVector) |
| 1060 | + .dimension(DIMENSION) |
| 1061 | + .build()); |
| 1062 | + collectionSchema.addField(AddFieldReq.builder() |
| 1063 | + .fieldName(arrayField) |
| 1064 | + .description("dummy") |
| 1065 | + .dataType(DataType.Array) |
| 1066 | + .elementType(DataType.VarChar) |
| 1067 | + .maxCapacity(capacity) |
| 1068 | + .maxLength(varcharLength) |
| 1069 | + .build()); |
| 1070 | + |
| 1071 | + List<IndexParam> indexParams = new ArrayList<>(); |
| 1072 | + indexParams.add(IndexParam.builder() |
| 1073 | + .fieldName(vectorField) |
| 1074 | + .indexType(IndexParam.IndexType.HNSW) |
| 1075 | + .metricType(IndexParam.MetricType.COSINE) |
| 1076 | + .build()); |
| 1077 | + |
| 1078 | + client.dropCollection(DropCollectionReq.builder() |
| 1079 | + .collectionName(randomCollectionName) |
| 1080 | + .build()); |
| 1081 | + |
| 1082 | + CreateCollectionReq requestCreate = CreateCollectionReq.builder() |
| 1083 | + .collectionName(randomCollectionName) |
| 1084 | + .collectionSchema(collectionSchema) |
| 1085 | + .indexParams(indexParams) |
| 1086 | + .build(); |
| 1087 | + client.createCollection(requestCreate); |
| 1088 | + |
| 1089 | + // describe |
| 1090 | + DescribeCollectionResp descResp = client.describeCollection(DescribeCollectionReq.builder() |
| 1091 | + .collectionName(randomCollectionName) |
| 1092 | + .build()); |
| 1093 | + CreateCollectionReq.CollectionSchema descSchema = descResp.getCollectionSchema(); |
| 1094 | + Assertions.assertEquals(3, descSchema.getFieldSchemaList().size()); |
| 1095 | + CreateCollectionReq.FieldSchema arraySchema = descSchema.getFieldSchemaList().get(2); |
| 1096 | + Assertions.assertEquals(arrayField, arraySchema.getName()); |
| 1097 | + Assertions.assertEquals("dummy", arraySchema.getDescription()); |
| 1098 | + Assertions.assertEquals(DataType.Array, arraySchema.getDataType()); |
| 1099 | + Assertions.assertEquals(DataType.VarChar, arraySchema.getElementType()); |
| 1100 | + Assertions.assertEquals(capacity, arraySchema.getMaxCapacity()); |
| 1101 | + Assertions.assertEquals(varcharLength, arraySchema.getMaxLength()); |
| 1102 | + |
| 1103 | + // insert |
| 1104 | + List<JsonObject> rows = new ArrayList<>(); |
| 1105 | + int count = 20; |
| 1106 | + for (int i = 0; i < count; i++) { |
| 1107 | + JsonObject row = new JsonObject(); |
| 1108 | + row.add(vectorField, JsonUtils.toJsonTree(utils.generateFloatVector())); |
| 1109 | + List<String> strArray = new ArrayList<>(); |
| 1110 | + for (int k = i; k < capacity; k++) { |
| 1111 | + strArray.add(String.format("string-%d-%d", i, k)); |
| 1112 | + } |
| 1113 | + row.add(arrayField, JsonUtils.toJsonTree(strArray).getAsJsonArray()); |
| 1114 | + rows.add(row); |
| 1115 | + } |
| 1116 | + |
| 1117 | + InsertResp insertResp = client.insert(InsertReq.builder() |
| 1118 | + .collectionName(randomCollectionName) |
| 1119 | + .data(rows) |
| 1120 | + .build()); |
| 1121 | + Assertions.assertEquals(count, insertResp.getInsertCnt()); |
| 1122 | + |
| 1123 | + // query |
| 1124 | + QueryResp queryResp = client.query(QueryReq.builder() |
| 1125 | + .collectionName(randomCollectionName) |
| 1126 | + .filter(String.format("ARRAY_CONTAINS(%s, \"string-0-9\")", arrayField)) |
| 1127 | + .limit(5) |
| 1128 | + .consistencyLevel(ConsistencyLevel.STRONG) |
| 1129 | + .outputFields(Collections.singletonList(arrayField)) |
| 1130 | + .build()); |
| 1131 | + List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults(); |
| 1132 | + Assertions.assertEquals(1, queryResults.size()); |
| 1133 | + Assertions.assertTrue(queryResults.get(0).getEntity().containsKey(arrayField)); |
| 1134 | + Assertions.assertInstanceOf(List.class, queryResults.get(0).getEntity().get(arrayField)); |
| 1135 | + List<String> arr = (List<String>) queryResults.get(0).getEntity().get(arrayField); |
| 1136 | + Assertions.assertEquals(capacity, arr.size()); |
| 1137 | + } |
| 1138 | + |
1054 | 1139 | @Test |
1055 | 1140 | void testStruct() { |
1056 | 1141 | String randomCollectionName = generator.generate(10); |
|
0 commit comments