Skip to content

Commit e2338cd

Browse files
committed
Fix a bug of nullable Array type
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent 0ac1c8e commit e2338cd

4 files changed

Lines changed: 58 additions & 16 deletions

File tree

examples/src/main/java/io/milvus/v2/NullAndDefaultExample.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ private static void queryWithExpr(MilvusClientV2 client, String expr) {
2828
QueryResp queryRet = client.query(QueryReq.builder()
2929
.collectionName(COLLECTION_NAME)
3030
.filter(expr)
31-
.outputFields(Arrays.asList("nullable_test", "default_test", "nullable_default"))
31+
.outputFields(Arrays.asList("nullable_test", "default_test", "nullable_default", "nullable_array"))
3232
.build());
3333
System.out.println("\nQuery with expression: " + expr);
3434
List<QueryResp.QueryResult> records = queryRet.getQueryResults();
@@ -81,6 +81,14 @@ public static void main(String[] args) {
8181
.isNullable(true)
8282
.defaultValue("I am default value")
8383
.build());
84+
collectionSchema.addField(AddFieldReq.builder()
85+
.fieldName("nullable_array")
86+
.dataType(DataType.Array)
87+
.elementType(DataType.VarChar)
88+
.maxCapacity(10)
89+
.maxLength(100)
90+
.isNullable(true)
91+
.build());
8492

8593
List<IndexParam> indexes = new ArrayList<>();
8694
indexes.add(IndexParam.builder()
@@ -111,6 +119,9 @@ public static void main(String[] args) {
111119
row.addProperty("nullable_test", i);
112120
} else {
113121
row.add("nullable_test", JsonNull.INSTANCE);
122+
123+
List<String> arr = Arrays.asList("A", "B", "C");
124+
row.add("nullable_array", gson.toJsonTree(arr));
114125
}
115126

116127
// some values are default value

sdk-core/src/main/java/io/milvus/param/ParamUtils.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1329,40 +1329,40 @@ private static ScalarField genScalarField(DataType dataType, List<?> objects) {
13291329
case UNRECOGNIZED:
13301330
throw new ParamException("Cannot support this dataType:" + dataType);
13311331
case Int64: {
1332-
List<Long> longs = objects.stream().map(p -> (Long) p).collect(Collectors.toList());
1332+
List<Long> longs = objects.stream().map(p -> (p == null) ? null : (Long) p).collect(Collectors.toList());
13331333
LongArray longArray = LongArray.newBuilder().addAllData(longs).build();
13341334
return ScalarField.newBuilder().setLongData(longArray).build();
13351335
}
13361336
case Int32:
13371337
case Int16:
13381338
case Int8: {
1339-
List<Integer> integers = objects.stream().map(p -> p instanceof Short ? ((Short) p).intValue() : (Integer) p).collect(Collectors.toList());
1339+
List<Integer> integers = objects.stream().map(p -> (p == null) ? null : (p instanceof Short ? ((Short) p).intValue() : (Integer) p)).collect(Collectors.toList());
13401340
IntArray intArray = IntArray.newBuilder().addAllData(integers).build();
13411341
return ScalarField.newBuilder().setIntData(intArray).build();
13421342
}
13431343
case Bool: {
1344-
List<Boolean> booleans = objects.stream().map(p -> (Boolean) p).collect(Collectors.toList());
1344+
List<Boolean> booleans = objects.stream().map(p -> (p == null) ? null : (Boolean) p).collect(Collectors.toList());
13451345
BoolArray boolArray = BoolArray.newBuilder().addAllData(booleans).build();
13461346
return ScalarField.newBuilder().setBoolData(boolArray).build();
13471347
}
13481348
case Float: {
1349-
List<Float> floats = objects.stream().map(p -> (Float) p).collect(Collectors.toList());
1349+
List<Float> floats = objects.stream().map(p -> (p == null) ? null : (Float) p).collect(Collectors.toList());
13501350
FloatArray floatArray = FloatArray.newBuilder().addAllData(floats).build();
13511351
return ScalarField.newBuilder().setFloatData(floatArray).build();
13521352
}
13531353
case Double: {
1354-
List<Double> doubles = objects.stream().map(p -> (Double) p).collect(Collectors.toList());
1354+
List<Double> doubles = objects.stream().map(p -> (p == null) ? null : (Double) p).collect(Collectors.toList());
13551355
DoubleArray doubleArray = DoubleArray.newBuilder().addAllData(doubles).build();
13561356
return ScalarField.newBuilder().setDoubleData(doubleArray).build();
13571357
}
13581358
case String:
13591359
case VarChar: {
1360-
List<String> strings = objects.stream().map(p -> (String) p).collect(Collectors.toList());
1360+
List<String> strings = objects.stream().map(p -> (p == null) ? null : (String) p).collect(Collectors.toList());
13611361
StringArray stringArray = StringArray.newBuilder().addAllData(strings).build();
13621362
return ScalarField.newBuilder().setStringData(stringArray).build();
13631363
}
13641364
case JSON: {
1365-
List<ByteString> byteStrings = objects.stream().map(p -> ByteString.copyFromUtf8(p.toString()))
1365+
List<ByteString> byteStrings = objects.stream().map(p -> (p == null) ? null : ByteString.copyFromUtf8(p.toString()))
13661366
.collect(Collectors.toList());
13671367
JSONArray jsonArray = JSONArray.newBuilder().addAllData(byteStrings).build();
13681368
return ScalarField.newBuilder().setJsonData(jsonArray).build();

sdk-core/src/main/java/io/milvus/response/FieldDataWrapper.java

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -251,13 +251,6 @@ private List<?> getFieldDataInternal() throws IllegalResponseException {
251251
return packData;
252252
}
253253
case Array:
254-
List<List<?>> array = new ArrayList<>();
255-
ArrayArray arrArray = fieldData.getScalars().getArrayData();
256-
for (int i = 0; i < arrArray.getDataCount(); i++) {
257-
ScalarField scalar = arrArray.getData(i);
258-
array.add(getScalarData(arrArray.getElementType(), scalar, null));
259-
}
260-
return array;
261254
case Int64:
262255
case Int32:
263256
case Int16:
@@ -308,6 +301,19 @@ private List<?> getScalarData(DataType dt, ScalarField scalar, List<Boolean> val
308301
case JSON:
309302
List<ByteString> dataList = scalar.getJsonData().getDataList();
310303
return dataList.stream().map(ByteString::toStringUtf8).collect(Collectors.toList());
304+
case Array:
305+
List<List<?>> array = new ArrayList<>();
306+
ArrayArray arrArray = fieldData.getScalars().getArrayData();
307+
boolean nullable = validData != null && validData.size() == arrArray.getDataCount();
308+
for (int i = 0; i < arrArray.getDataCount(); i++) {
309+
if (nullable && validData.get(i) == Boolean.FALSE) {
310+
array.add(null);
311+
} else {
312+
ScalarField rowData = arrArray.getData(i);
313+
array.add(getScalarData(arrArray.getElementType(), rowData, null));
314+
}
315+
}
316+
return array;
311317
default:
312318
return new ArrayList<>();
313319
}

sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1890,6 +1890,13 @@ void testNullableAndDefaultValue() {
18901890
.isNullable(Boolean.TRUE)
18911891
.maxLength(100)
18921892
.build());
1893+
collectionSchema.addField(AddFieldReq.builder()
1894+
.fieldName("arr")
1895+
.dataType(DataType.Array)
1896+
.elementType(DataType.Int32)
1897+
.isNullable(Boolean.TRUE)
1898+
.maxCapacity(100)
1899+
.build());
18931900

18941901
List<IndexParam> indexParams = new ArrayList<>();
18951902
indexParams.add(IndexParam.builder()
@@ -1918,7 +1925,11 @@ void testNullableAndDefaultValue() {
19181925
} else {
19191926
// row.add("flag", JsonNull.INSTANCE);
19201927
row.addProperty("desc", "AAA");
1928+
1929+
List<Integer> arr = Arrays.asList(5, 6);
1930+
row.add("arr", JsonUtils.toJsonTree(arr));
19211931
}
1932+
19221933
data.add(row);
19231934
}
19241935

@@ -1932,7 +1943,7 @@ void testNullableAndDefaultValue() {
19321943
QueryResp queryResp = client.query(QueryReq.builder()
19331944
.collectionName(randomCollectionName)
19341945
.filter("id >= 0")
1935-
.outputFields(Arrays.asList("desc", "flag"))
1946+
.outputFields(Arrays.asList("desc", "flag", "arr"))
19361947
.consistencyLevel(ConsistencyLevel.STRONG)
19371948
.build());
19381949
List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
@@ -1944,9 +1955,16 @@ void testNullableAndDefaultValue() {
19441955
if (id%2 == 0) {
19451956
Assertions.assertEquals((int)id, entity.get("flag"));
19461957
Assertions.assertNull(entity.get("desc"));
1958+
Assertions.assertNull(entity.get("arr"));
19471959
} else {
19481960
Assertions.assertEquals(10, entity.get("flag"));
19491961
Assertions.assertEquals("AAA", entity.get("desc"));
1962+
Object obj = entity.get("arr");
1963+
Assertions.assertInstanceOf(List.class, obj);
1964+
List<Integer> arr = (List<Integer>)obj;
1965+
Assertions.assertEquals(2, arr.size());
1966+
Assertions.assertEquals(5, arr.get(0));
1967+
Assertions.assertEquals(6, arr.get(1));
19501968
}
19511969
System.out.println(result);
19521970
}
@@ -1971,9 +1989,16 @@ void testNullableAndDefaultValue() {
19711989
if (id%2 == 0) {
19721990
Assertions.assertEquals((int)id, entity.get("flag"));
19731991
Assertions.assertNull(entity.get("desc"));
1992+
Assertions.assertNull(entity.get("arr"));
19741993
} else {
19751994
Assertions.assertEquals(10, entity.get("flag"));
19761995
Assertions.assertEquals("AAA", entity.get("desc"));
1996+
Object obj = entity.get("arr");
1997+
Assertions.assertInstanceOf(List.class, obj);
1998+
List<Integer> arr = (List<Integer>)obj;
1999+
Assertions.assertEquals(2, arr.size());
2000+
Assertions.assertEquals(5, arr.get(0));
2001+
Assertions.assertEquals(6, arr.get(1));
19772002
}
19782003
System.out.println(result);
19792004
}

0 commit comments

Comments
 (0)