Skip to content

Commit cdc8203

Browse files
committed
Fix a bug of nullable Array type
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent 0ac1c8e commit cdc8203

4 files changed

Lines changed: 66 additions & 31 deletions

File tree

examples/src/main/java/io/milvus/v2/NullAndDefaultExample.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ private static void queryWithExpr(MilvusClientV2 client, String expr) {
2828
QueryResp queryRet = client.query(QueryReq.builder()
2929
.collectionName(COLLECTION_NAME)
3030
.filter(expr)
31-
.outputFields(Arrays.asList("nullable_test", "default_test", "nullable_default"))
31+
.outputFields(Arrays.asList("nullable_test", "default_test", "nullable_default", "nullable_array"))
3232
.build());
3333
System.out.println("\nQuery with expression: " + expr);
3434
List<QueryResp.QueryResult> records = queryRet.getQueryResults();
@@ -81,6 +81,14 @@ public static void main(String[] args) {
8181
.isNullable(true)
8282
.defaultValue("I am default value")
8383
.build());
84+
collectionSchema.addField(AddFieldReq.builder()
85+
.fieldName("nullable_array")
86+
.dataType(DataType.Array)
87+
.elementType(DataType.VarChar)
88+
.maxCapacity(10)
89+
.maxLength(100)
90+
.isNullable(true)
91+
.build());
8492

8593
List<IndexParam> indexes = new ArrayList<>();
8694
indexes.add(IndexParam.builder()
@@ -111,6 +119,9 @@ public static void main(String[] args) {
111119
row.addProperty("nullable_test", i);
112120
} else {
113121
row.add("nullable_test", JsonNull.INSTANCE);
122+
123+
List<String> arr = Arrays.asList("A", "B", "C");
124+
row.add("nullable_array", gson.toJsonTree(arr));
114125
}
115126

116127
// some values are default value

sdk-core/src/main/java/io/milvus/param/ParamUtils.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1329,40 +1329,40 @@ private static ScalarField genScalarField(DataType dataType, List<?> objects) {
13291329
case UNRECOGNIZED:
13301330
throw new ParamException("Cannot support this dataType:" + dataType);
13311331
case Int64: {
1332-
List<Long> longs = objects.stream().map(p -> (Long) p).collect(Collectors.toList());
1332+
List<Long> longs = objects.stream().map(p -> (p == null) ? null : (Long) p).collect(Collectors.toList());
13331333
LongArray longArray = LongArray.newBuilder().addAllData(longs).build();
13341334
return ScalarField.newBuilder().setLongData(longArray).build();
13351335
}
13361336
case Int32:
13371337
case Int16:
13381338
case Int8: {
1339-
List<Integer> integers = objects.stream().map(p -> p instanceof Short ? ((Short) p).intValue() : (Integer) p).collect(Collectors.toList());
1339+
List<Integer> integers = objects.stream().map(p -> (p == null) ? null : (p instanceof Short ? ((Short) p).intValue() : (Integer) p)).collect(Collectors.toList());
13401340
IntArray intArray = IntArray.newBuilder().addAllData(integers).build();
13411341
return ScalarField.newBuilder().setIntData(intArray).build();
13421342
}
13431343
case Bool: {
1344-
List<Boolean> booleans = objects.stream().map(p -> (Boolean) p).collect(Collectors.toList());
1344+
List<Boolean> booleans = objects.stream().map(p -> (p == null) ? null : (Boolean) p).collect(Collectors.toList());
13451345
BoolArray boolArray = BoolArray.newBuilder().addAllData(booleans).build();
13461346
return ScalarField.newBuilder().setBoolData(boolArray).build();
13471347
}
13481348
case Float: {
1349-
List<Float> floats = objects.stream().map(p -> (Float) p).collect(Collectors.toList());
1349+
List<Float> floats = objects.stream().map(p -> (p == null) ? null : (Float) p).collect(Collectors.toList());
13501350
FloatArray floatArray = FloatArray.newBuilder().addAllData(floats).build();
13511351
return ScalarField.newBuilder().setFloatData(floatArray).build();
13521352
}
13531353
case Double: {
1354-
List<Double> doubles = objects.stream().map(p -> (Double) p).collect(Collectors.toList());
1354+
List<Double> doubles = objects.stream().map(p -> (p == null) ? null : (Double) p).collect(Collectors.toList());
13551355
DoubleArray doubleArray = DoubleArray.newBuilder().addAllData(doubles).build();
13561356
return ScalarField.newBuilder().setDoubleData(doubleArray).build();
13571357
}
13581358
case String:
13591359
case VarChar: {
1360-
List<String> strings = objects.stream().map(p -> (String) p).collect(Collectors.toList());
1360+
List<String> strings = objects.stream().map(p -> (p == null) ? null : (String) p).collect(Collectors.toList());
13611361
StringArray stringArray = StringArray.newBuilder().addAllData(strings).build();
13621362
return ScalarField.newBuilder().setStringData(stringArray).build();
13631363
}
13641364
case JSON: {
1365-
List<ByteString> byteStrings = objects.stream().map(p -> ByteString.copyFromUtf8(p.toString()))
1365+
List<ByteString> byteStrings = objects.stream().map(p -> (p == null) ? null : ByteString.copyFromUtf8(p.toString()))
13661366
.collect(Collectors.toList());
13671367
JSONArray jsonArray = JSONArray.newBuilder().addAllData(byteStrings).build();
13681368
return ScalarField.newBuilder().setJsonData(jsonArray).build();

sdk-core/src/main/java/io/milvus/response/FieldDataWrapper.java

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -251,13 +251,6 @@ private List<?> getFieldDataInternal() throws IllegalResponseException {
251251
return packData;
252252
}
253253
case Array:
254-
List<List<?>> array = new ArrayList<>();
255-
ArrayArray arrArray = fieldData.getScalars().getArrayData();
256-
for (int i = 0; i < arrArray.getDataCount(); i++) {
257-
ScalarField scalar = arrArray.getData(i);
258-
array.add(getScalarData(arrArray.getElementType(), scalar, null));
259-
}
260-
return array;
261254
case Int64:
262255
case Int32:
263256
case Int16:
@@ -308,6 +301,19 @@ private List<?> getScalarData(DataType dt, ScalarField scalar, List<Boolean> val
308301
case JSON:
309302
List<ByteString> dataList = scalar.getJsonData().getDataList();
310303
return dataList.stream().map(ByteString::toStringUtf8).collect(Collectors.toList());
304+
case Array:
305+
List<List<?>> array = new ArrayList<>();
306+
ArrayArray arrArray = fieldData.getScalars().getArrayData();
307+
boolean nullable = validData != null && validData.size() == arrArray.getDataCount();
308+
for (int i = 0; i < arrArray.getDataCount(); i++) {
309+
if (nullable && validData.get(i) == Boolean.FALSE) {
310+
array.add(null);
311+
} else {
312+
ScalarField rowData = arrArray.getData(i);
313+
array.add(getScalarData(arrArray.getElementType(), rowData, null));
314+
}
315+
}
316+
return array;
311317
default:
312318
return new ArrayList<>();
313319
}

sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070

7171
import java.nio.ByteBuffer;
7272
import java.util.*;
73+
import java.util.function.Function;
7374

7475
@Testcontainers(disabledWithoutDocker = true)
7576
class MilvusClientV2DockerTest {
@@ -1890,6 +1891,13 @@ void testNullableAndDefaultValue() {
18901891
.isNullable(Boolean.TRUE)
18911892
.maxLength(100)
18921893
.build());
1894+
collectionSchema.addField(AddFieldReq.builder()
1895+
.fieldName("arr")
1896+
.dataType(DataType.Array)
1897+
.elementType(DataType.Int32)
1898+
.isNullable(Boolean.TRUE)
1899+
.maxCapacity(100)
1900+
.build());
18931901

18941902
List<IndexParam> indexParams = new ArrayList<>();
18951903
indexParams.add(IndexParam.builder()
@@ -1918,7 +1926,11 @@ void testNullableAndDefaultValue() {
19181926
} else {
19191927
// row.add("flag", JsonNull.INSTANCE);
19201928
row.addProperty("desc", "AAA");
1929+
1930+
List<Integer> arr = Arrays.asList(5, 6);
1931+
row.add("arr", JsonUtils.toJsonTree(arr));
19211932
}
1933+
19221934
data.add(row);
19231935
}
19241936

@@ -1928,26 +1940,38 @@ void testNullableAndDefaultValue() {
19281940
.build());
19291941
Assertions.assertEquals(10, insertResp.getInsertCnt());
19301942

1943+
Function<Map<String, Object>, Void> checkFunc =
1944+
entity -> {
1945+
long id = (long)entity.get("id");
1946+
if (id%2 == 0) {
1947+
Assertions.assertEquals((int)id, entity.get("flag"));
1948+
Assertions.assertNull(entity.get("desc"));
1949+
Assertions.assertNull(entity.get("arr"));
1950+
} else {
1951+
Assertions.assertEquals(10, entity.get("flag"));
1952+
Assertions.assertEquals("AAA", entity.get("desc"));
1953+
Object obj = entity.get("arr");
1954+
Assertions.assertInstanceOf(List.class, obj);
1955+
List<Integer> arr = (List<Integer>)obj;
1956+
Assertions.assertEquals(2, arr.size());
1957+
Assertions.assertEquals(5, arr.get(0));
1958+
Assertions.assertEquals(6, arr.get(1));
1959+
}
1960+
return null;
1961+
};
19311962
// query
19321963
QueryResp queryResp = client.query(QueryReq.builder()
19331964
.collectionName(randomCollectionName)
19341965
.filter("id >= 0")
1935-
.outputFields(Arrays.asList("desc", "flag"))
1966+
.outputFields(Arrays.asList("desc", "flag", "arr"))
19361967
.consistencyLevel(ConsistencyLevel.STRONG)
19371968
.build());
19381969
List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
19391970
Assertions.assertEquals(10, queryResults.size());
19401971
System.out.println("Query results:");
19411972
for (QueryResp.QueryResult result : queryResults) {
19421973
Map<String, Object> entity = result.getEntity();
1943-
long id = (long)entity.get("id");
1944-
if (id%2 == 0) {
1945-
Assertions.assertEquals((int)id, entity.get("flag"));
1946-
Assertions.assertNull(entity.get("desc"));
1947-
} else {
1948-
Assertions.assertEquals(10, entity.get("flag"));
1949-
Assertions.assertEquals("AAA", entity.get("desc"));
1950-
}
1974+
checkFunc.apply(entity);
19511975
System.out.println(result);
19521976
}
19531977

@@ -1968,13 +1992,7 @@ void testNullableAndDefaultValue() {
19681992
for (SearchResp.SearchResult result : firstResults) {
19691993
long id = (long)result.getId();
19701994
Map<String, Object> entity = result.getEntity();
1971-
if (id%2 == 0) {
1972-
Assertions.assertEquals((int)id, entity.get("flag"));
1973-
Assertions.assertNull(entity.get("desc"));
1974-
} else {
1975-
Assertions.assertEquals(10, entity.get("flag"));
1976-
Assertions.assertEquals("AAA", entity.get("desc"));
1977-
}
1995+
checkFunc.apply(entity);
19781996
System.out.println(result);
19791997
}
19801998
}

0 commit comments

Comments
 (0)