diff --git a/sdk-core/src/main/java/io/milvus/param/ParamUtils.java b/sdk-core/src/main/java/io/milvus/param/ParamUtils.java index 800f1ad69..0cfff8a7e 100644 --- a/sdk-core/src/main/java/io/milvus/param/ParamUtils.java +++ b/sdk-core/src/main/java/io/milvus/param/ParamUtils.java @@ -1138,17 +1138,20 @@ private static long getGuaranteeTimestamp(ConsistencyLevelEnum consistencyLevel, } } - public static boolean isVectorDataType(DataType dataType) { + public static boolean isDenseVectorDataType(DataType dataType) { Set vectorDataType = new HashSet() {{ add(DataType.FloatVector); add(DataType.BinaryVector); add(DataType.Float16Vector); add(DataType.BFloat16Vector); - add(DataType.SparseFloatVector); }}; return vectorDataType.contains(dataType); } + public static boolean isVectorDataType(DataType dataType) { + return isDenseVectorDataType(dataType) || dataType == DataType.SparseFloatVector; + } + public static FieldData genFieldData(FieldType fieldType, List objects) { return genFieldData(fieldType, objects, Boolean.FALSE); } diff --git a/sdk-core/src/main/java/io/milvus/response/FieldDataWrapper.java b/sdk-core/src/main/java/io/milvus/response/FieldDataWrapper.java index 2153249bc..139ce84a6 100644 --- a/sdk-core/src/main/java/io/milvus/response/FieldDataWrapper.java +++ b/sdk-core/src/main/java/io/milvus/response/FieldDataWrapper.java @@ -76,6 +76,8 @@ public int getDim() throws IllegalResponseException { } // this method returns bytes size of each vector according to vector type + // for binary vector, each dimension is one bit, each byte is 8 dim + // for float16 vector, each dimension 2 bytes private int checkDim(DataType dt, ByteString data, int dim) { if (dt == DataType.BinaryVector) { if ((data.size()*8) % dim != 0) { @@ -96,6 +98,21 @@ private int checkDim(DataType dt, ByteString data, int dim) { return 0; } + private ByteString getVectorBytes(FieldData fieldData, DataType dt) { + ByteString data; + if (dt == DataType.BinaryVector) { + data = fieldData.getVectors().getBinaryVector(); + } else if (dt == DataType.Float16Vector) { + data = fieldData.getVectors().getFloat16Vector(); + } else if (dt == DataType.BFloat16Vector) { + data = fieldData.getVectors().getBfloat16Vector(); + } else { + String msg = String.format("Unsupported data type %s returned by FieldData", dt.name()); + throw new IllegalResponseException(msg); + } + return data; + } + /** * Gets the row count of a field. * * Throws {@link IllegalResponseException} if the field type is illegal. @@ -116,20 +133,11 @@ public long getRowCount() throws IllegalResponseException { return data.size()/dim; } - case BinaryVector: { - // for binary vector, each dimension is one bit, each byte is 8 dim - int dim = getDim(); - ByteString data = fieldData.getVectors().getBinaryVector(); - int bytePerVec = checkDim(dt, data, dim); - - return data.size()/bytePerVec; - } + case BinaryVector: case Float16Vector: case BFloat16Vector: { - // for float16 vector, each dimension 2 bytes int dim = getDim(); - ByteString data = (dt == DataType.Float16Vector) ? - fieldData.getVectors().getFloat16Vector() : fieldData.getVectors().getBfloat16Vector(); + ByteString data = getVectorBytes(fieldData, dt); int bytePerVec = checkDim(dt, data, dim); return data.size()/bytePerVec; @@ -213,22 +221,14 @@ private List getFieldDataInternal() throws IllegalResponseException { case Float16Vector: case BFloat16Vector: { int dim = getDim(); - ByteString data = null; - if (dt == DataType.BinaryVector) { - data = fieldData.getVectors().getBinaryVector(); - } else if (dt == DataType.Float16Vector) { - data = fieldData.getVectors().getFloat16Vector(); - } else { - data = fieldData.getVectors().getBfloat16Vector(); - } - + ByteString data = getVectorBytes(fieldData, dt); int bytePerVec = checkDim(dt, data, dim); int count = data.size()/bytePerVec; List packData = new ArrayList<>(); for (int i = 0; i < count; ++i) { ByteBuffer bf = ByteBuffer.allocate(bytePerVec); // binary vector doesn't care endian since each byte is independent - // fp16/bf16 vector is sensetive to endian because each dim occupies 2 bytes, + // fp16/bf16 vector is sensitive to endian because each dim occupies 2 bytes, // milvus server stores fp16/bf16 vector as little endian bf.order(ByteOrder.LITTLE_ENDIAN); bf.put(data.substring(i * bytePerVec, (i + 1) * bytePerVec).toByteArray()); diff --git a/sdk-core/src/main/java/io/milvus/v2/service/collection/request/AddFieldReq.java b/sdk-core/src/main/java/io/milvus/v2/service/collection/request/AddFieldReq.java index 94ff5f08e..59f745908 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/collection/request/AddFieldReq.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/collection/request/AddFieldReq.java @@ -44,7 +44,7 @@ public class AddFieldReq { @Builder.Default private Boolean autoID = Boolean.FALSE; private Integer dimension; - private io.milvus.v2.common.DataType elementType; + private DataType elementType; private Integer maxCapacity; @Builder.Default private Boolean isNullable = Boolean.FALSE; // only for scalar fields(not include Array fields) diff --git a/sdk-core/src/main/java/io/milvus/v2/service/collection/request/CreateCollectionReq.java b/sdk-core/src/main/java/io/milvus/v2/service/collection/request/CreateCollectionReq.java index 05a2dd62c..ea3d7b58b 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/collection/request/CreateCollectionReq.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/collection/request/CreateCollectionReq.java @@ -20,6 +20,7 @@ package io.milvus.v2.service.collection.request; import io.milvus.common.clientenum.FunctionType; +import io.milvus.param.ParamUtils; import io.milvus.v2.common.ConsistencyLevel; import io.milvus.v2.common.DataType; import io.milvus.v2.common.IndexParam; @@ -166,8 +167,7 @@ public CollectionSchema addField(AddFieldReq addFieldReq) { fieldSchema.setMaxCapacity(addFieldReq.getMaxCapacity()); } else if (addFieldReq.getDataType().equals(DataType.VarChar)) { fieldSchema.setMaxLength(addFieldReq.getMaxLength()); - } else if (addFieldReq.getDataType().equals(DataType.FloatVector) || addFieldReq.getDataType().equals(DataType.BinaryVector) || - addFieldReq.getDataType().equals(DataType.Float16Vector) || addFieldReq.getDataType().equals(DataType.BFloat16Vector)) { + } else if (ParamUtils.isDenseVectorDataType(io.milvus.grpc.DataType.valueOf(addFieldReq.getDataType().name()))) { if (addFieldReq.getDimension() == null) { throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Dimension is required for vector field"); } diff --git a/sdk-core/src/test/java/io/milvus/TestUtils.java b/sdk-core/src/test/java/io/milvus/TestUtils.java index d1cccfd8b..1948eec41 100644 --- a/sdk-core/src/test/java/io/milvus/TestUtils.java +++ b/sdk-core/src/test/java/io/milvus/TestUtils.java @@ -11,6 +11,8 @@ public class TestUtils { private int dimension = 256; private static final Random RANDOM = new Random(); + public static final String MilvusDockerImageID = "milvusdb/milvus:v2.5.11"; + public TestUtils(int dimension) { this.dimension = dimension; } diff --git a/sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java b/sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java index b2710385a..48809212e 100644 --- a/sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java +++ b/sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java @@ -75,7 +75,7 @@ class MilvusClientDockerTest { private static final TestUtils utils = new TestUtils(DIMENSION); @Container - private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:v2.5.11"); + private static final MilvusContainer milvus = new MilvusContainer(TestUtils.MilvusDockerImageID); @BeforeAll public static void setUp() { diff --git a/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java b/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java index 5cc644170..2fa0775cb 100644 --- a/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java +++ b/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java @@ -81,7 +81,7 @@ class MilvusClientV2DockerTest { private static final TestUtils utils = new TestUtils(DIMENSION); @Container - private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:v2.5.11"); + private static final MilvusContainer milvus = new MilvusContainer(TestUtils.MilvusDockerImageID); @BeforeAll public static void setUp() {