Skip to content

Commit 65abbab

Browse files
authored
Support Int8Vector(feature of v2.6) (#1408)
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent 69d2eb7 commit 65abbab

11 files changed

Lines changed: 425 additions & 38 deletions

File tree

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
package io.milvus.v2;
2+
3+
import com.google.gson.Gson;
4+
import com.google.gson.JsonObject;
5+
import io.milvus.v1.CommonUtils;
6+
import io.milvus.v2.client.ConnectConfig;
7+
import io.milvus.v2.client.MilvusClientV2;
8+
import io.milvus.v2.common.ConsistencyLevel;
9+
import io.milvus.v2.common.DataType;
10+
import io.milvus.v2.common.IndexParam;
11+
import io.milvus.v2.service.collection.request.AddFieldReq;
12+
import io.milvus.v2.service.collection.request.CreateCollectionReq;
13+
import io.milvus.v2.service.collection.request.DropCollectionReq;
14+
import io.milvus.v2.service.vector.request.InsertReq;
15+
import io.milvus.v2.service.vector.request.QueryReq;
16+
import io.milvus.v2.service.vector.request.SearchReq;
17+
import io.milvus.v2.service.vector.request.data.BinaryVec;
18+
import io.milvus.v2.service.vector.request.data.Int8Vec;
19+
import io.milvus.v2.service.vector.response.QueryResp;
20+
import io.milvus.v2.service.vector.response.SearchResp;
21+
22+
import java.nio.ByteBuffer;
23+
import java.util.*;
24+
25+
public class Int8VectorExample {
26+
private static final String COLLECTION_NAME = "java_sdk_example_int8_vector_v2";
27+
private static final String ID_FIELD = "id";
28+
private static final String VECTOR_FIELD = "vector";
29+
30+
private static final Integer VECTOR_DIM = 128;
31+
32+
private static List<ByteBuffer> generateInt8Vectors(int count) {
33+
Random RANDOM = new Random();
34+
List<ByteBuffer> vectors = new ArrayList<>();
35+
for (int i = 0; i < count; i++) {
36+
ByteBuffer vector = ByteBuffer.allocate(VECTOR_DIM);
37+
for (int k = 0; k < VECTOR_DIM; ++k) {
38+
vector.put((byte) (RANDOM.nextInt(256) - 128));
39+
}
40+
vectors.add(vector);
41+
}
42+
43+
return vectors;
44+
}
45+
46+
47+
public static void main(String[] args) {
48+
ConnectConfig config = ConnectConfig.builder()
49+
.uri("http://localhost:19530")
50+
.build();
51+
MilvusClientV2 client = new MilvusClientV2(config);
52+
53+
// Drop collection if exists
54+
client.dropCollection(DropCollectionReq.builder()
55+
.collectionName(COLLECTION_NAME)
56+
.build());
57+
58+
// Create collection
59+
CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
60+
.build();
61+
collectionSchema.addField(AddFieldReq.builder()
62+
.fieldName(ID_FIELD)
63+
.dataType(DataType.Int64)
64+
.isPrimaryKey(Boolean.TRUE)
65+
.build());
66+
collectionSchema.addField(AddFieldReq.builder()
67+
.fieldName(VECTOR_FIELD)
68+
.dataType(DataType.Int8Vector)
69+
.dimension(VECTOR_DIM)
70+
.build());
71+
72+
List<IndexParam> indexes = new ArrayList<>();
73+
Map<String,Object> extraParams = new HashMap<>();
74+
extraParams.put("M", 64);
75+
extraParams.put("efConstruction", 200);
76+
indexes.add(IndexParam.builder()
77+
.fieldName(VECTOR_FIELD)
78+
.indexType(IndexParam.IndexType.HNSW)
79+
.metricType(IndexParam.MetricType.L2)
80+
.extraParams(extraParams)
81+
.build());
82+
83+
CreateCollectionReq requestCreate = CreateCollectionReq.builder()
84+
.collectionName(COLLECTION_NAME)
85+
.collectionSchema(collectionSchema)
86+
.indexParams(indexes)
87+
.consistencyLevel(ConsistencyLevel.BOUNDED)
88+
.build();
89+
client.createCollection(requestCreate);
90+
System.out.println("Collection created");
91+
92+
// Insert entities by rows
93+
int rowCount = 10000;
94+
List<ByteBuffer> vectors = generateInt8Vectors(rowCount);
95+
List<JsonObject> rows = new ArrayList<>();
96+
Gson gson = new Gson();
97+
for (long i = 0L; i < rowCount; ++i) {
98+
JsonObject row = new JsonObject();
99+
row.addProperty(ID_FIELD, i);
100+
ByteBuffer vector = vectors.get((int)i);
101+
row.add(VECTOR_FIELD, gson.toJsonTree(vector.array()));
102+
rows.add(row);
103+
}
104+
105+
client.insert(InsertReq.builder()
106+
.collectionName(COLLECTION_NAME)
107+
.data(rows)
108+
.build());
109+
110+
// Get row count, set ConsistencyLevel.STRONG to sync the data to query node so that data is visible
111+
QueryResp countR = client.query(QueryReq.builder()
112+
.collectionName(COLLECTION_NAME)
113+
.filter("")
114+
.outputFields(Collections.singletonList("count(*)"))
115+
.consistencyLevel(ConsistencyLevel.STRONG)
116+
.build());
117+
System.out.printf("%d rows persisted\n", (long)countR.getQueryResults().get(0).getEntity().get("count(*)"));
118+
119+
// Pick some vectors from the inserted vectors to search
120+
// Ensure the returned top1 item's ID should be equal to target vector's ID
121+
for (int i = 0; i < 10; i++) {
122+
Random ran = new Random();
123+
int k = ran.nextInt(rowCount);
124+
ByteBuffer targetVector = vectors.get(k);
125+
SearchResp searchResp = client.search(SearchReq.builder()
126+
.collectionName(COLLECTION_NAME)
127+
.data(Collections.singletonList(new Int8Vec(targetVector)))
128+
.annsField(VECTOR_FIELD)
129+
.outputFields(Collections.singletonList(VECTOR_FIELD))
130+
.topK(3)
131+
.build());
132+
133+
// The search() allows multiple target vectors to search in a batch.
134+
// Here we only input one vector to search, get the result of No.0 vector to check
135+
List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
136+
List<SearchResp.SearchResult> results = searchResults.get(0);
137+
System.out.printf("\nThe result of No.%d vector %s:\n", k, Arrays.toString(targetVector.array()));
138+
for (SearchResp.SearchResult result : results) {
139+
System.out.printf("ID: %d, Score: %f, Vector: ", (long)result.getId(), result.getScore());
140+
ByteBuffer vector = (ByteBuffer) result.getEntity().get(VECTOR_FIELD);
141+
System.out.print(Arrays.toString(vector.array()));
142+
System.out.println();
143+
}
144+
145+
SearchResp.SearchResult firstResult = results.get(0);
146+
if ((long)firstResult.getId() != k) {
147+
throw new RuntimeException(String.format("The top1 ID %d is not equal to target vector's ID %d",
148+
(long)firstResult.getId(), k));
149+
}
150+
}
151+
System.out.println("Search result is correct");
152+
153+
// Retrieve some data
154+
int n = 99;
155+
QueryResp queryResp = client.query(QueryReq.builder()
156+
.collectionName(COLLECTION_NAME)
157+
.filter(String.format("id == %d", n))
158+
.outputFields(Collections.singletonList(VECTOR_FIELD))
159+
.build());
160+
161+
List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
162+
if (queryResults.isEmpty()) {
163+
throw new RuntimeException("The query result is empty");
164+
} else {
165+
ByteBuffer vector = (ByteBuffer) queryResults.get(0).getEntity().get(VECTOR_FIELD);
166+
if (vector.compareTo(vectors.get(n)) != 0) {
167+
throw new RuntimeException("The query result is incorrect");
168+
}
169+
}
170+
System.out.println("Query result is correct");
171+
172+
173+
// Drop the collection if you don't need the collection anymore
174+
client.dropCollection(DropCollectionReq.builder()
175+
.collectionName(COLLECTION_NAME)
176+
.build());
177+
178+
client.close();
179+
}
180+
}

sdk-core/src/main/java/io/milvus/param/ParamUtils.java

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ public static void checkFieldData(FieldType fieldSchema, InsertParam.Field field
101101
private static int calculateBinVectorDim(DataType dataType, int byteCount) {
102102
if (dataType == DataType.BinaryVector) {
103103
return byteCount*8; // for BinaryVector, each byte is 8 dimensions
104+
} else if (dataType == DataType.Int8Vector) {
105+
return byteCount; // for Int8Vector, each byte is one dimension
104106
} else {
105107
if (byteCount%2 != 0) {
106108
String msg = "Incorrect byte count for %s type field, byte count is %d, cannot be evenly divided by 2";
@@ -358,7 +360,8 @@ public static Object checkFieldValue(FieldType fieldSchema, JsonElement value) {
358360
}
359361
case BinaryVector:
360362
case Float16Vector:
361-
case BFloat16Vector: {
363+
case BFloat16Vector:
364+
case Int8Vector: {
362365
if (!(value.isJsonArray())) {
363366
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
364367
}
@@ -726,7 +729,7 @@ public static ByteString convertPlaceholder(List<?> vectors, PlaceholderType pla
726729
ByteString bs = ByteString.copyFrom(array);
727730
byteStrings.add(bs);
728731
} else if (vector instanceof ByteBuffer) {
729-
// for fp16/bf16 vector, each vector is a ByteBuffer with little endian
732+
// for fp16/bf16/int8 vector, each vector is a ByteBuffer with little endian
730733
// for binary vector, each vector is a ByteBuffer no matter which endian
731734
// the endian of each ByteBuffer is already specified by the caller
732735
plType = PlaceholderType.BinaryVector;
@@ -1138,17 +1141,21 @@ private static long getGuaranteeTimestamp(ConsistencyLevelEnum consistencyLevel,
11381141
}
11391142
}
11401143

1141-
public static boolean isVectorDataType(DataType dataType) {
1144+
public static boolean isDenseVectorDataType(DataType dataType) {
11421145
Set<DataType> vectorDataType = new HashSet<DataType>() {{
11431146
add(DataType.FloatVector);
11441147
add(DataType.BinaryVector);
11451148
add(DataType.Float16Vector);
11461149
add(DataType.BFloat16Vector);
1147-
add(DataType.SparseFloatVector);
1150+
add(DataType.Int8Vector);
11481151
}};
11491152
return vectorDataType.contains(dataType);
11501153
}
11511154

1155+
public static boolean isVectorDataType(DataType dataType) {
1156+
return isDenseVectorDataType(dataType) || dataType == DataType.SparseFloatVector;
1157+
}
1158+
11521159
public static FieldData genFieldData(FieldType fieldType, List<?> objects) {
11531160
return genFieldData(fieldType, objects, Boolean.FALSE);
11541161
}
@@ -1203,10 +1210,11 @@ private static VectorField genVectorField(DataType dataType, List<?> objects) {
12031210
return VectorField.newBuilder().setDim(dim).setFloatVector(floatArray).build();
12041211
} else if (dataType == DataType.BinaryVector ||
12051212
dataType == DataType.Float16Vector ||
1206-
dataType == DataType.BFloat16Vector) {
1213+
dataType == DataType.BFloat16Vector ||
1214+
dataType == DataType.Int8Vector) {
12071215
ByteBuffer totalBuf = null;
12081216
int dim = 0;
1209-
// for fp16/bf16 vector, each vector is a ByteBuffer with little endian
1217+
// for fp16/bf16/int8 vector, each vector is a ByteBuffer with little endian
12101218
// for binary vector, each vector is a ByteBuffer no matter which endian
12111219
// no need to set totalBuf endian since it is treated as byte array
12121220
for (Object object : objects) {
@@ -1226,8 +1234,10 @@ private static VectorField genVectorField(DataType dataType, List<?> objects) {
12261234
return VectorField.newBuilder().setDim(dim).setBinaryVector(byteString).build();
12271235
} else if (dataType == DataType.Float16Vector) {
12281236
return VectorField.newBuilder().setDim(dim).setFloat16Vector(byteString).build();
1229-
} else {
1237+
} else if (dataType == DataType.BFloat16Vector) {
12301238
return VectorField.newBuilder().setDim(dim).setBfloat16Vector(byteString).build();
1239+
} else {
1240+
return VectorField.newBuilder().setDim(dim).setInt8Vector(byteString).build();
12311241
}
12321242
} else if (dataType == DataType.SparseFloatVector) {
12331243
SparseFloatArray sparseArray = genSparseFloatArray(objects);

sdk-core/src/main/java/io/milvus/response/FieldDataWrapper.java

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ public int getDim() throws IllegalResponseException {
7676
}
7777

7878
// this method returns bytes size of each vector according to vector type
79+
// for binary vector, each dimension is one bit, each byte is 8 dim
80+
// for int8 vector, each dimension is one byte, each byte is one dim
81+
// for float16 vector, each dimension 2 bytes
7982
private int checkDim(DataType dt, ByteString data, int dim) {
8083
if (dt == DataType.BinaryVector) {
8184
if ((data.size()*8) % dim != 0) {
@@ -91,11 +94,35 @@ private int checkDim(DataType dt, ByteString data, int dim) {
9194
throw new IllegalResponseException(msg);
9295
}
9396
return dim*2;
97+
} else if (dt == DataType.Int8Vector) {
98+
if (data.size() % dim != 0) {
99+
String msg = String.format("Returned int8 vector data array size %d doesn't match dimension %d",
100+
data.size(), dim);
101+
throw new IllegalResponseException(msg);
102+
}
103+
return dim;
94104
}
95105

96106
return 0;
97107
}
98108

109+
private ByteString getVectorBytes(FieldData fieldData, DataType dt) {
110+
ByteString data;
111+
if (dt == DataType.BinaryVector) {
112+
data = fieldData.getVectors().getBinaryVector();
113+
} else if (dt == DataType.Float16Vector) {
114+
data = fieldData.getVectors().getFloat16Vector();
115+
} else if (dt == DataType.BFloat16Vector) {
116+
data = fieldData.getVectors().getBfloat16Vector();
117+
} else if (dt == DataType.Int8Vector) {
118+
data = fieldData.getVectors().getInt8Vector();
119+
} else {
120+
String msg = String.format("Unsupported data type %s returned by FieldData", dt.name());
121+
throw new IllegalResponseException(msg);
122+
}
123+
return data;
124+
}
125+
99126
/**
100127
* Gets the row count of a field.
101128
* * Throws {@link IllegalResponseException} if the field type is illegal.
@@ -116,20 +143,12 @@ public long getRowCount() throws IllegalResponseException {
116143

117144
return data.size()/dim;
118145
}
119-
case BinaryVector: {
120-
// for binary vector, each dimension is one bit, each byte is 8 dim
121-
int dim = getDim();
122-
ByteString data = fieldData.getVectors().getBinaryVector();
123-
int bytePerVec = checkDim(dt, data, dim);
124-
125-
return data.size()/bytePerVec;
126-
}
146+
case BinaryVector:
127147
case Float16Vector:
128-
case BFloat16Vector: {
129-
// for float16 vector, each dimension 2 bytes
148+
case BFloat16Vector:
149+
case Int8Vector: {
130150
int dim = getDim();
131-
ByteString data = (dt == DataType.Float16Vector) ?
132-
fieldData.getVectors().getFloat16Vector() : fieldData.getVectors().getBfloat16Vector();
151+
ByteString data = getVectorBytes(fieldData, dt);
133152
int bytePerVec = checkDim(dt, data, dim);
134153

135154
return data.size()/bytePerVec;
@@ -211,25 +230,18 @@ private List<?> getFieldDataInternal() throws IllegalResponseException {
211230
}
212231
case BinaryVector:
213232
case Float16Vector:
214-
case BFloat16Vector: {
233+
case BFloat16Vector:
234+
case Int8Vector: {
215235
int dim = getDim();
216-
ByteString data = null;
217-
if (dt == DataType.BinaryVector) {
218-
data = fieldData.getVectors().getBinaryVector();
219-
} else if (dt == DataType.Float16Vector) {
220-
data = fieldData.getVectors().getFloat16Vector();
221-
} else {
222-
data = fieldData.getVectors().getBfloat16Vector();
223-
}
224-
236+
ByteString data = getVectorBytes(fieldData, dt);
225237
int bytePerVec = checkDim(dt, data, dim);
226238
int count = data.size()/bytePerVec;
227239
List<ByteBuffer> packData = new ArrayList<>();
228240
for (int i = 0; i < count; ++i) {
229241
ByteBuffer bf = ByteBuffer.allocate(bytePerVec);
230242
// binary vector doesn't care endian since each byte is independent
231-
// fp16/bf16 vector is sensetive to endian because each dim occupies 2 bytes,
232-
// milvus server stores fp16/bf16 vector as little endian
243+
// fp16/bf16 vector is sensitive to endian because each dim occupies 2 bytes (int8 uses one byte per dim, so byte order does not affect it),
244+
// milvus server stores fp16/bf16/int8 vector as little endian
233245
bf.order(ByteOrder.LITTLE_ENDIAN);
234246
bf.put(data.substring(i * bytePerVec, (i + 1) * bytePerVec).toByteArray());
235247
packData.add(bf);

sdk-core/src/main/java/io/milvus/v2/common/DataType.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ public enum DataType {
4444
FloatVector(101),
4545
Float16Vector(102),
4646
BFloat16Vector(103),
47-
SparseFloatVector(104);
47+
SparseFloatVector(104),
48+
Int8Vector(105);
4849

4950
private final int code;
5051
DataType(int code) {

sdk-core/src/main/java/io/milvus/v2/service/collection/request/AddFieldReq.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public class AddFieldReq {
4444
@Builder.Default
4545
private Boolean autoID = Boolean.FALSE;
4646
private Integer dimension;
47-
private io.milvus.v2.common.DataType elementType;
47+
private DataType elementType;
4848
private Integer maxCapacity;
4949
@Builder.Default
5050
private Boolean isNullable = Boolean.FALSE; // only for scalar fields(not include Array fields)

sdk-core/src/main/java/io/milvus/v2/service/collection/request/CreateCollectionReq.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package io.milvus.v2.service.collection.request;
2121

2222
import io.milvus.common.clientenum.FunctionType;
23+
import io.milvus.param.ParamUtils;
2324
import io.milvus.v2.common.ConsistencyLevel;
2425
import io.milvus.v2.common.DataType;
2526
import io.milvus.v2.common.IndexParam;
@@ -166,8 +167,7 @@ public CollectionSchema addField(AddFieldReq addFieldReq) {
166167
fieldSchema.setMaxCapacity(addFieldReq.getMaxCapacity());
167168
} else if (addFieldReq.getDataType().equals(DataType.VarChar)) {
168169
fieldSchema.setMaxLength(addFieldReq.getMaxLength());
169-
} else if (addFieldReq.getDataType().equals(DataType.FloatVector) || addFieldReq.getDataType().equals(DataType.BinaryVector) ||
170-
addFieldReq.getDataType().equals(DataType.Float16Vector) || addFieldReq.getDataType().equals(DataType.BFloat16Vector)) {
170+
} else if (ParamUtils.isDenseVectorDataType(io.milvus.grpc.DataType.valueOf(addFieldReq.getDataType().name()))) {
171171
if (addFieldReq.getDimension() == null) {
172172
throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Dimension is required for vector field");
173173
}

0 commit comments

Comments
 (0)