Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions sdk-core/src/main/java/io/milvus/param/ParamUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -1138,17 +1138,20 @@ private static long getGuaranteeTimestamp(ConsistencyLevelEnum consistencyLevel,
}
}

public static boolean isVectorDataType(DataType dataType) {
public static boolean isDenseVectorDataType(DataType dataType) {
Set<DataType> vectorDataType = new HashSet<DataType>() {{
add(DataType.FloatVector);
add(DataType.BinaryVector);
add(DataType.Float16Vector);
add(DataType.BFloat16Vector);
add(DataType.SparseFloatVector);
}};
return vectorDataType.contains(dataType);
}

public static boolean isVectorDataType(DataType dataType) {
return isDenseVectorDataType(dataType) || dataType == DataType.SparseFloatVector;
}

public static FieldData genFieldData(FieldType fieldType, List<?> objects) {
return genFieldData(fieldType, objects, Boolean.FALSE);
}
Expand Down
42 changes: 21 additions & 21 deletions sdk-core/src/main/java/io/milvus/response/FieldDataWrapper.java
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ public int getDim() throws IllegalResponseException {
}

// this method returns bytes size of each vector according to vector type
// for binary vector, each dimension is one bit, each byte is 8 dim
// for float16 vector, each dimension 2 bytes
private int checkDim(DataType dt, ByteString data, int dim) {
if (dt == DataType.BinaryVector) {
if ((data.size()*8) % dim != 0) {
Expand All @@ -96,6 +98,21 @@ private int checkDim(DataType dt, ByteString data, int dim) {
return 0;
}

private ByteString getVectorBytes(FieldData fieldData, DataType dt) {
ByteString data;
if (dt == DataType.BinaryVector) {
data = fieldData.getVectors().getBinaryVector();
} else if (dt == DataType.Float16Vector) {
data = fieldData.getVectors().getFloat16Vector();
} else if (dt == DataType.BFloat16Vector) {
data = fieldData.getVectors().getBfloat16Vector();
} else {
String msg = String.format("Unsupported data type %s returned by FieldData", dt.name());
throw new IllegalResponseException(msg);
}
return data;
}

/**
* Gets the row count of a field.
* * Throws {@link IllegalResponseException} if the field type is illegal.
Expand All @@ -116,20 +133,11 @@ public long getRowCount() throws IllegalResponseException {

return data.size()/dim;
}
case BinaryVector: {
// for binary vector, each dimension is one bit, each byte is 8 dim
int dim = getDim();
ByteString data = fieldData.getVectors().getBinaryVector();
int bytePerVec = checkDim(dt, data, dim);

return data.size()/bytePerVec;
}
case BinaryVector:
case Float16Vector:
case BFloat16Vector: {
// for float16 vector, each dimension 2 bytes
int dim = getDim();
ByteString data = (dt == DataType.Float16Vector) ?
fieldData.getVectors().getFloat16Vector() : fieldData.getVectors().getBfloat16Vector();
ByteString data = getVectorBytes(fieldData, dt);
int bytePerVec = checkDim(dt, data, dim);

return data.size()/bytePerVec;
Expand Down Expand Up @@ -213,22 +221,14 @@ private List<?> getFieldDataInternal() throws IllegalResponseException {
case Float16Vector:
case BFloat16Vector: {
int dim = getDim();
ByteString data = null;
if (dt == DataType.BinaryVector) {
data = fieldData.getVectors().getBinaryVector();
} else if (dt == DataType.Float16Vector) {
data = fieldData.getVectors().getFloat16Vector();
} else {
data = fieldData.getVectors().getBfloat16Vector();
}

ByteString data = getVectorBytes(fieldData, dt);
int bytePerVec = checkDim(dt, data, dim);
int count = data.size()/bytePerVec;
List<ByteBuffer> packData = new ArrayList<>();
for (int i = 0; i < count; ++i) {
ByteBuffer bf = ByteBuffer.allocate(bytePerVec);
// binary vector doesn't care endian since each byte is independent
// fp16/bf16 vector is sensetive to endian because each dim occupies 2 bytes,
// fp16/bf16 vector is sensitive to endian because each dim occupies 2 bytes,
// milvus server stores fp16/bf16 vector as little endian
bf.order(ByteOrder.LITTLE_ENDIAN);
bf.put(data.substring(i * bytePerVec, (i + 1) * bytePerVec).toByteArray());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public class AddFieldReq {
@Builder.Default
private Boolean autoID = Boolean.FALSE;
private Integer dimension;
private io.milvus.v2.common.DataType elementType;
private DataType elementType;
private Integer maxCapacity;
@Builder.Default
private Boolean isNullable = Boolean.FALSE; // only for scalar fields(not include Array fields)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package io.milvus.v2.service.collection.request;

import io.milvus.common.clientenum.FunctionType;
import io.milvus.param.ParamUtils;
import io.milvus.v2.common.ConsistencyLevel;
import io.milvus.v2.common.DataType;
import io.milvus.v2.common.IndexParam;
Expand Down Expand Up @@ -166,8 +167,7 @@ public CollectionSchema addField(AddFieldReq addFieldReq) {
fieldSchema.setMaxCapacity(addFieldReq.getMaxCapacity());
} else if (addFieldReq.getDataType().equals(DataType.VarChar)) {
fieldSchema.setMaxLength(addFieldReq.getMaxLength());
} else if (addFieldReq.getDataType().equals(DataType.FloatVector) || addFieldReq.getDataType().equals(DataType.BinaryVector) ||
addFieldReq.getDataType().equals(DataType.Float16Vector) || addFieldReq.getDataType().equals(DataType.BFloat16Vector)) {
} else if (ParamUtils.isDenseVectorDataType(io.milvus.grpc.DataType.valueOf(addFieldReq.getDataType().name()))) {
if (addFieldReq.getDimension() == null) {
throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Dimension is required for vector field");
}
Expand Down
2 changes: 2 additions & 0 deletions sdk-core/src/test/java/io/milvus/TestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ public class TestUtils {
private int dimension = 256;
private static final Random RANDOM = new Random();

public static final String MilvusDockerImageID = "milvusdb/milvus:v2.5.11";

public TestUtils(int dimension) {
this.dimension = dimension;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class MilvusClientDockerTest {
private static final TestUtils utils = new TestUtils(DIMENSION);

@Container
private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:v2.5.11");
private static final MilvusContainer milvus = new MilvusContainer(TestUtils.MilvusDockerImageID);

@BeforeAll
public static void setUp() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class MilvusClientV2DockerTest {
private static final TestUtils utils = new TestUtils(DIMENSION);

@Container
private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:v2.5.11");
private static final MilvusContainer milvus = new MilvusContainer(TestUtils.MilvusDockerImageID);

@BeforeAll
public static void setUp() {
Expand Down
Loading