Skip to content

Commit f4ec5a1

Browse files
committed
BulkWriter supports Int8Vector (milvus-io#1440)
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent 23c84fe commit f4ec5a1

8 files changed

Lines changed: 55 additions & 11 deletions

File tree

examples/src/main/java/io/milvus/v1/CommonUtils.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,27 @@ public static List<ByteBuffer> generateFloat16Vectors(int dimension, int count,
262262
return vectors;
263263
}
264264

265+
/////////////////////////////////////////////////////////////////////////////////////////////////////
266+
public static ByteBuffer generateInt8Vector(int dimension) {
267+
Random ran = new Random();
268+
int byteCount = dimension;
269+
// binary vector doesn't care endian since each byte is independent
270+
ByteBuffer vector = ByteBuffer.allocate(byteCount);
271+
for (int i = 0; i < byteCount; ++i) {
272+
vector.put((byte) (ran.nextInt(256) - 128));
273+
}
274+
return vector;
275+
}
276+
277+
public static List<ByteBuffer> generateInt8Vectors(int dimension, int count) {
278+
List<ByteBuffer> vectors = new ArrayList<>();
279+
for (int n = 0; n < count; ++n) {
280+
ByteBuffer vector = generateInt8Vector(dimension);
281+
vectors.add(vector);
282+
}
283+
return vectors;
284+
}
285+
265286
/////////////////////////////////////////////////////////////////////////////////////////////////////
266287
public static SortedMap<Long, Float> generateSparseVector() {
267288
Random ran = new Random();

examples/src/main/java/io/milvus/v2/BulkWriterExample.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ private static List<Map<String, Object>> genOriginalData(int count) {
373373
// vector field
374374
row.put("float_vector", CommonUtils.generateFloatVector(DIM));
375375
row.put("binary_vector", CommonUtils.generateBinaryVector(DIM).array());
376-
row.put("float16_vector", CommonUtils.generateFloat16Vector(DIM, false).array());
376+
row.put("int8_vector", CommonUtils.generateInt8Vector(DIM).array());
377377
row.put("sparse_vector", CommonUtils.generateSparseVector());
378378

379379
// array field
@@ -405,7 +405,7 @@ private static List<Map<String, Object>> genOriginalData(int count) {
405405
// vector field
406406
row.put("float_vector", CommonUtils.generateFloatVector(DIM));
407407
row.put("binary_vector", CommonUtils.generateBinaryVector(DIM).array());
408-
row.put("float16_vector", CommonUtils.generateFloat16Vector(DIM, false).array());
408+
row.put("int8_vector", CommonUtils.generateInt8Vector(DIM).array());
409409
row.put("sparse_vector", CommonUtils.generateSparseVector());
410410

411411
// array field
@@ -450,7 +450,7 @@ private static List<JsonObject> genImportData(List<Map<String, Object>> original
450450
// vector field
451451
rowObject.add("float_vector", GSON_INSTANCE.toJsonTree(row.get("float_vector")));
452452
rowObject.add("binary_vector", GSON_INSTANCE.toJsonTree(row.get("binary_vector")));
453-
rowObject.add("float16_vector", GSON_INSTANCE.toJsonTree(row.get("float16_vector")));
453+
rowObject.add("int8_vector", GSON_INSTANCE.toJsonTree(row.get("int8_vector")));
454454
rowObject.add("sparse_vector", GSON_INSTANCE.toJsonTree(row.get("sparse_vector")));
455455

456456
// array field
@@ -791,7 +791,7 @@ private static void verifyImportData(CreateCollectionReq.CollectionSchema collec
791791

792792
comparePrint(collectionSchema, originalEntity, fetchedEntity, "float_vector");
793793
comparePrint(collectionSchema, originalEntity, fetchedEntity, "binary_vector");
794-
comparePrint(collectionSchema, originalEntity, fetchedEntity, "float16_vector");
794+
comparePrint(collectionSchema, originalEntity, fetchedEntity, "int8_vector");
795795
comparePrint(collectionSchema, originalEntity, fetchedEntity, "sparse_vector");
796796

797797
System.out.println(fetchedEntity);
@@ -815,9 +815,9 @@ private static void createIndex() {
815815
.metricType(IndexParam.MetricType.HAMMING)
816816
.build());
817817
indexes.add(IndexParam.builder()
818-
.fieldName("float16_vector")
819-
.indexType(IndexParam.IndexType.FLAT)
820-
.metricType(IndexParam.MetricType.IP)
818+
.fieldName("int8_vector")
819+
.indexType(IndexParam.IndexType.AUTOINDEX)
820+
.metricType(IndexParam.MetricType.L2)
821821
.build());
822822
indexes.add(IndexParam.builder()
823823
.fieldName("sparse_vector")
@@ -992,8 +992,8 @@ private static CreateCollectionReq.CollectionSchema buildAllTypesSchema() {
992992
.dimension(DIM)
993993
.build());
994994
schemaV2.addField(AddFieldReq.builder()
995-
.fieldName("float16_vector")
996-
.dataType(DataType.Float16Vector)
995+
.fieldName("int8_vector")
996+
.dataType(DataType.Int8Vector)
997997
.dimension(DIM)
998998
.build());
999999
schemaV2.addField(AddFieldReq.builder()

sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/BulkWriter.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,8 @@ protected Map<String, Object> verifyRow(JsonObject row) {
289289
case FloatVector:
290290
case Float16Vector:
291291
case BFloat16Vector:
292-
case SparseFloatVector: {
292+
case SparseFloatVector:
293+
case Int8Vector:{
293294
Pair<Object, Integer> objectAndSize = verifyVector(obj, field);
294295
rowValues.put(fieldName, objectAndSize.getLeft());
295296
rowSize += objectAndSize.getRight();
@@ -368,6 +369,7 @@ private Pair<Object, Integer> verifyVector(JsonElement object, CreateCollectionR
368369
case FloatVector:
369370
return Pair.of(vector, ((List<?>) vector).size() * 4);
370371
case BinaryVector:
372+
case Int8Vector:
371373
return Pair.of(vector, ((ByteBuffer)vector).limit());
372374
case Float16Vector:
373375
case BFloat16Vector:

sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/utils/ParquetUtils.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,10 @@ public static MessageType parseCollectionSchema(CreateCollectionReq.CollectionSc
8181
case BinaryVector:
8282
case Float16Vector:
8383
case BFloat16Vector:
84+
case Int8Vector:
85+
boolean isSigned = (field.getDataType() == io.milvus.v2.common.DataType.Int8Vector);
8486
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32,
85-
LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(8, false), field, true);
87+
LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(8, isSigned), field, true);
8688
break;
8789
case Array:
8890
fillArrayType(messageTypeBuilder, field);

sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/writer/ParquetFileWriter.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ private void appendGroup(Group group, String paramName, Object value, CreateColl
144144
case BinaryVector:
145145
case Float16Vector:
146146
case BFloat16Vector:
147+
case Int8Vector:
147148
addBinaryVector(group, paramName, (ByteBuffer) value);
148149
break;
149150
case SparseFloatVector:

sdk-bulkwriter/src/test/java/io/milvus/bulkwriter/BulkWriterTest.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,11 @@ private static CreateCollectionReq.CollectionSchema buildV2Schema(boolean enable
202202
.fieldName("sparse_vector_field")
203203
.dataType(DataType.SparseFloatVector)
204204
.build());
205+
schemaV2.addField(AddFieldReq.builder()
206+
.fieldName("int8_vector_field")
207+
.dataType(DataType.Int8Vector)
208+
.dimension(DIMENSION)
209+
.build());
205210
return schemaV2;
206211
}
207212

@@ -274,6 +279,7 @@ private static List<JsonObject> buildData(int rowCount, boolean isEnableDynamicF
274279
rowObject.add("float_vector_field", JsonUtils.toJsonTree(utils.generateFloatVector()));
275280
rowObject.add("binary_vector_field", JsonUtils.toJsonTree(utils.generateBinaryVector().array()));
276281
rowObject.add("sparse_vector_field", JsonUtils.toJsonTree(utils.generateSparseVector()));
282+
rowObject.add("int8_vector_field", JsonUtils.toJsonTree(utils.generateInt8Vector().array()));
277283

278284
rows.add(rowObject);
279285
}
@@ -368,6 +374,7 @@ void testAppend() {
368374
rowObject.add("float_vector_field", JsonUtils.toJsonTree(utils.generateFloatVector()));
369375
rowObject.add("binary_vector_field", JsonUtils.toJsonTree(utils.generateBinaryVector().array()));
370376
rowObject.add("sparse_vector_field", JsonUtils.toJsonTree(utils.generateSparseVector()));
377+
rowObject.add("int8_vector_field", JsonUtils.toJsonTree(utils.generateInt8Vector().array()));
371378
rowObject.add("arr_int32_field", JsonUtils.toJsonTree(GeneratorUtils.generatorInt32Value(2)));
372379
rowObject.add("arr_float_field", JsonUtils.toJsonTree(GeneratorUtils.generatorFloatValue(3)));
373380
rowObject.add("arr_varchar_field", JsonUtils.toJsonTree(GeneratorUtils.generatorVarcharValue(4, 5)));
@@ -421,6 +428,12 @@ void testAppend() {
421428
// set incorrect type for varchar field, expect throwing an exception
422429
rowObject.addProperty("float_field", 2.5);
423430
rowObject.addProperty("varchar_field", 2.5);
431+
// localBulkWriter.appendRow(rowObject);
432+
Assertions.assertThrows(MilvusException.class, ()->localBulkWriter.appendRow(rowObject));
433+
434+
// set incorrect value type for int8 vector field, expect throwing an exception
435+
rowObject.addProperty("varchar_field", "dummy");
436+
rowObject.addProperty("int8_vector_field", Boolean.TRUE);
424437
// localBulkWriter.appendRow(rowObject);
425438
Assertions.assertThrows(MilvusException.class, ()->localBulkWriter.appendRow(rowObject));
426439
} catch (Exception e) {

sdk-bulkwriter/src/test/java/io/milvus/bulkwriter/TestUtils.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ public List<ByteBuffer> generateBinaryVectors(int count) {
5858

5959
}
6060

61+
public ByteBuffer generateInt8Vector() {
62+
return generateBinaryVector(dimension*8);
63+
}
64+
6165
public ByteBuffer generateFloat16Vector() {
6266
List<Float> vector = generateFloatVector();
6367
return Float16Utils.f32VectorToFp16Buffer(vector);

sdk-core/src/main/java/io/milvus/param/ParamUtils.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ private static HashMap<DataType, String> getTypeErrorMsgForRowInsert() {
9090
typeErrMsg.put(DataType.Float16Vector, "Type mismatch for field '%s': Float16 vector field's value type must be JsonArray of byte[].");
9191
typeErrMsg.put(DataType.BFloat16Vector, "Type mismatch for field '%s': BFloat16 vector field's value type must be JsonArray of byte[].");
9292
typeErrMsg.put(DataType.SparseFloatVector, "Type mismatch for field '%s': SparseFloatVector vector field's value type must be JsonObject of Map<Long, Float>.");
93+
typeErrMsg.put(DataType.Int8Vector, "Type mismatch for field '%s': Int8Vector vector field's value type must be JsonArray of byte[].");
9394
return typeErrMsg;
9495
}
9596

0 commit comments

Comments
 (0)