Skip to content

Commit 06eb1dc

Browse files
authored
Refine code (#1409)
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent 709c2df commit 06eb1dc

7 files changed

Lines changed: 33 additions & 28 deletions

File tree

sdk-core/src/main/java/io/milvus/param/ParamUtils.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,17 +1138,20 @@ private static long getGuaranteeTimestamp(ConsistencyLevelEnum consistencyLevel,
11381138
}
11391139
}
11401140

1141-
public static boolean isVectorDataType(DataType dataType) {
1141+
public static boolean isDenseVectorDataType(DataType dataType) {
11421142
Set<DataType> vectorDataType = new HashSet<DataType>() {{
11431143
add(DataType.FloatVector);
11441144
add(DataType.BinaryVector);
11451145
add(DataType.Float16Vector);
11461146
add(DataType.BFloat16Vector);
1147-
add(DataType.SparseFloatVector);
11481147
}};
11491148
return vectorDataType.contains(dataType);
11501149
}
11511150

1151+
public static boolean isVectorDataType(DataType dataType) {
1152+
return isDenseVectorDataType(dataType) || dataType == DataType.SparseFloatVector;
1153+
}
1154+
11521155
public static FieldData genFieldData(FieldType fieldType, List<?> objects) {
11531156
return genFieldData(fieldType, objects, Boolean.FALSE);
11541157
}

sdk-core/src/main/java/io/milvus/response/FieldDataWrapper.java

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ public int getDim() throws IllegalResponseException {
7676
}
7777

7878
// this method returns bytes size of each vector according to vector type
79+
// for binary vector, each dimension is one bit, each byte is 8 dim
80+
// for float16 vector, each dimension 2 bytes
7981
private int checkDim(DataType dt, ByteString data, int dim) {
8082
if (dt == DataType.BinaryVector) {
8183
if ((data.size()*8) % dim != 0) {
@@ -96,6 +98,21 @@ private int checkDim(DataType dt, ByteString data, int dim) {
9698
return 0;
9799
}
98100

101+
private ByteString getVectorBytes(FieldData fieldData, DataType dt) {
102+
ByteString data;
103+
if (dt == DataType.BinaryVector) {
104+
data = fieldData.getVectors().getBinaryVector();
105+
} else if (dt == DataType.Float16Vector) {
106+
data = fieldData.getVectors().getFloat16Vector();
107+
} else if (dt == DataType.BFloat16Vector) {
108+
data = fieldData.getVectors().getBfloat16Vector();
109+
} else {
110+
String msg = String.format("Unsupported data type %s returned by FieldData", dt.name());
111+
throw new IllegalResponseException(msg);
112+
}
113+
return data;
114+
}
115+
99116
/**
100117
* Gets the row count of a field.
101118
* * Throws {@link IllegalResponseException} if the field type is illegal.
@@ -116,20 +133,11 @@ public long getRowCount() throws IllegalResponseException {
116133

117134
return data.size()/dim;
118135
}
119-
case BinaryVector: {
120-
// for binary vector, each dimension is one bit, each byte is 8 dim
121-
int dim = getDim();
122-
ByteString data = fieldData.getVectors().getBinaryVector();
123-
int bytePerVec = checkDim(dt, data, dim);
124-
125-
return data.size()/bytePerVec;
126-
}
136+
case BinaryVector:
127137
case Float16Vector:
128138
case BFloat16Vector: {
129-
// for float16 vector, each dimension 2 bytes
130139
int dim = getDim();
131-
ByteString data = (dt == DataType.Float16Vector) ?
132-
fieldData.getVectors().getFloat16Vector() : fieldData.getVectors().getBfloat16Vector();
140+
ByteString data = getVectorBytes(fieldData, dt);
133141
int bytePerVec = checkDim(dt, data, dim);
134142

135143
return data.size()/bytePerVec;
@@ -213,22 +221,14 @@ private List<?> getFieldDataInternal() throws IllegalResponseException {
213221
case Float16Vector:
214222
case BFloat16Vector: {
215223
int dim = getDim();
216-
ByteString data = null;
217-
if (dt == DataType.BinaryVector) {
218-
data = fieldData.getVectors().getBinaryVector();
219-
} else if (dt == DataType.Float16Vector) {
220-
data = fieldData.getVectors().getFloat16Vector();
221-
} else {
222-
data = fieldData.getVectors().getBfloat16Vector();
223-
}
224-
224+
ByteString data = getVectorBytes(fieldData, dt);
225225
int bytePerVec = checkDim(dt, data, dim);
226226
int count = data.size()/bytePerVec;
227227
List<ByteBuffer> packData = new ArrayList<>();
228228
for (int i = 0; i < count; ++i) {
229229
ByteBuffer bf = ByteBuffer.allocate(bytePerVec);
230230
// binary vector doesn't care endian since each byte is independent
231-
// fp16/bf16 vector is sensetive to endian because each dim occupies 2 bytes,
231+
// fp16/bf16 vector is sensitive to endian because each dim occupies 2 bytes,
232232
// milvus server stores fp16/bf16 vector as little endian
233233
bf.order(ByteOrder.LITTLE_ENDIAN);
234234
bf.put(data.substring(i * bytePerVec, (i + 1) * bytePerVec).toByteArray());

sdk-core/src/main/java/io/milvus/v2/service/collection/request/AddFieldReq.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public class AddFieldReq {
4444
@Builder.Default
4545
private Boolean autoID = Boolean.FALSE;
4646
private Integer dimension;
47-
private io.milvus.v2.common.DataType elementType;
47+
private DataType elementType;
4848
private Integer maxCapacity;
4949
@Builder.Default
5050
private Boolean isNullable = Boolean.FALSE; // only for scalar fields(not include Array fields)

sdk-core/src/main/java/io/milvus/v2/service/collection/request/CreateCollectionReq.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package io.milvus.v2.service.collection.request;
2121

2222
import io.milvus.common.clientenum.FunctionType;
23+
import io.milvus.param.ParamUtils;
2324
import io.milvus.v2.common.ConsistencyLevel;
2425
import io.milvus.v2.common.DataType;
2526
import io.milvus.v2.common.IndexParam;
@@ -166,8 +167,7 @@ public CollectionSchema addField(AddFieldReq addFieldReq) {
166167
fieldSchema.setMaxCapacity(addFieldReq.getMaxCapacity());
167168
} else if (addFieldReq.getDataType().equals(DataType.VarChar)) {
168169
fieldSchema.setMaxLength(addFieldReq.getMaxLength());
169-
} else if (addFieldReq.getDataType().equals(DataType.FloatVector) || addFieldReq.getDataType().equals(DataType.BinaryVector) ||
170-
addFieldReq.getDataType().equals(DataType.Float16Vector) || addFieldReq.getDataType().equals(DataType.BFloat16Vector)) {
170+
} else if (ParamUtils.isDenseVectorDataType(io.milvus.grpc.DataType.valueOf(addFieldReq.getDataType().name()))) {
171171
if (addFieldReq.getDimension() == null) {
172172
throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Dimension is required for vector field");
173173
}

sdk-core/src/test/java/io/milvus/TestUtils.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ public class TestUtils {
1111
private int dimension = 256;
1212
private static final Random RANDOM = new Random();
1313

14+
public static final String MilvusDockerImageID = "milvusdb/milvus:v2.5.11";
15+
1416
public TestUtils(int dimension) {
1517
this.dimension = dimension;
1618
}

sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class MilvusClientDockerTest {
7575
private static final TestUtils utils = new TestUtils(DIMENSION);
7676

7777
@Container
78-
private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:v2.5.11");
78+
private static final MilvusContainer milvus = new MilvusContainer(TestUtils.MilvusDockerImageID);
7979

8080
@BeforeAll
8181
public static void setUp() {

sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class MilvusClientV2DockerTest {
8181
private static final TestUtils utils = new TestUtils(DIMENSION);
8282

8383
@Container
84-
private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:v2.5.11");
84+
private static final MilvusContainer milvus = new MilvusContainer(TestUtils.MilvusDockerImageID);
8585

8686
@BeforeAll
8787
public static void setUp() {

0 commit comments

Comments
 (0)