Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@ public class CommonData {
public static String fieldSparseVector = "fieldSparseVector";
public static int addMaxLength = 99;

// Struct related fields
public static String defaultStructCollection = "StructCollection";
public static String fieldStruct = "fieldStruct";
public static String structFieldInt32 = "structInt32";
public static String structFieldVarchar = "structVarchar";
public static String structFieldFloatVector1 = "structFloatVector1";
public static String structFieldFloatVector2 = "structFloatVector2";
public static int structVectorDim = 128;
public static int structMaxCapacity = 100;


public static String partitionName = "partitionName";
// 快速创建时候的默认向量filed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.milvus.v2.common.ConsistencyLevel;
import io.milvus.v2.common.DataType;
import io.milvus.v2.common.IndexParam;
import io.milvus.v2.service.collection.request.AddFieldReq;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import io.milvus.v2.service.collection.request.DescribeCollectionReq;
import io.milvus.v2.service.collection.request.LoadCollectionReq;
Expand Down Expand Up @@ -275,6 +276,110 @@ public static String createNewCollection(int dim, String collectionName, DataTyp
log.info("create collection:" + collectionName);
return collectionName;
}

/**
* Create a new collection with Varchar primary key
*
* @param dim dimension of the vector field
* @param collectionName collection name
* @param vectorType vector data type
* @return collection name
*/
public static String createNewCollectionWithVarcharPK(int dim, String collectionName, DataType vectorType) {
if (collectionName == null || collectionName.equals("")) {
collectionName = "Collection_" + GenerateUtil.getRandomString(10);
}
// Use Varchar as primary key
CreateCollectionReq.FieldSchema fieldVarcharPK = CreateCollectionReq.FieldSchema.builder()
.autoID(false)
.dataType(DataType.VarChar)
.isPrimaryKey(true)
.name(CommonData.fieldVarchar)
.maxLength(100)
.build();
CreateCollectionReq.FieldSchema fieldInt64 = CreateCollectionReq.FieldSchema.builder()
.dataType(io.milvus.v2.common.DataType.Int64)
.isPrimaryKey(false)
.name(CommonData.fieldInt64)
.build();
CreateCollectionReq.FieldSchema fieldInt32 = CreateCollectionReq.FieldSchema.builder()
.dataType(DataType.Int32)
.name(CommonData.fieldInt32)
.isPrimaryKey(false)
.build();
CreateCollectionReq.FieldSchema fieldInt8 = CreateCollectionReq.FieldSchema.builder()
.dataType(DataType.Int8)
.name(CommonData.fieldInt8)
.isPrimaryKey(false)
.build();
CreateCollectionReq.FieldSchema fieldFloat = CreateCollectionReq.FieldSchema.builder()
.dataType(DataType.Float)
.name(CommonData.fieldFloat)
.isPrimaryKey(false)
.build();
CreateCollectionReq.FieldSchema fieldDouble = CreateCollectionReq.FieldSchema.builder()
.dataType(DataType.Double)
.name(CommonData.fieldDouble)
.isPrimaryKey(false)
.build();
CreateCollectionReq.FieldSchema fieldBool = CreateCollectionReq.FieldSchema.builder()
.dataType(DataType.Bool)
.name(CommonData.fieldBool)
.isPrimaryKey(false)
.build();
CreateCollectionReq.FieldSchema fieldJson = CreateCollectionReq.FieldSchema.builder()
.dataType(DataType.JSON)
.name(CommonData.fieldJson)
.isPrimaryKey(false)
.build();
CreateCollectionReq.FieldSchema fieldVector = CreateCollectionReq.FieldSchema.builder()
.dataType(vectorType)
.isPrimaryKey(false)
.build();
if (vectorType == DataType.FloatVector) {
fieldVector.setDimension(dim);
fieldVector.setName(CommonData.fieldFloatVector);
}
if (vectorType == DataType.BinaryVector) {
fieldVector.setDimension(dim);
fieldVector.setName(CommonData.fieldBinaryVector);
}
if (vectorType == DataType.Float16Vector) {
fieldVector.setDimension(dim);
fieldVector.setName(CommonData.fieldFloat16Vector);
}
if (vectorType == DataType.BFloat16Vector) {
fieldVector.setDimension(dim);
fieldVector.setName(CommonData.fieldBF16Vector);
}
if (vectorType == DataType.SparseFloatVector) {
fieldVector.setName(CommonData.fieldSparseVector);
}
List<CreateCollectionReq.FieldSchema> fieldSchemaList = new ArrayList<>();
fieldSchemaList.add(fieldVarcharPK);
fieldSchemaList.add(fieldInt64);
fieldSchemaList.add(fieldInt32);
fieldSchemaList.add(fieldInt8);
fieldSchemaList.add(fieldFloat);
fieldSchemaList.add(fieldDouble);
fieldSchemaList.add(fieldBool);
fieldSchemaList.add(fieldJson);
fieldSchemaList.add(fieldVector);
CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
.fieldSchemaList(fieldSchemaList)
.build();
CreateCollectionReq createCollectionReq = CreateCollectionReq.builder()
.collectionSchema(collectionSchema)
.collectionName(collectionName)
.enableDynamicField(false)
.description("collection with varchar primary key")
.numShards(1)
.build();
milvusClientV2.createCollection(createCollectionReq);
log.info("create collection with varchar pk:" + collectionName);
return collectionName;
}

public static String createNewCollectionWithDatabase(int dim, String collectionName, DataType vectorType,String databaseName) {
if (collectionName == null || collectionName.equals("")) {
collectionName = "Collection_" + GenerateUtil.getRandomString(10);
Expand Down Expand Up @@ -929,6 +1034,60 @@ public static List<JsonObject> generateDefaultData(long startId, long num, int d
return jsonList;
}

/**
* Generate data with varchar primary key for collection created by createNewCollectionWithVarcharPK
*
* @param startId start id
* @param num number of entities to generate
* @param dim dimension of vector
* @param vectorType vector data type
* @return List of JsonObject representing the data rows
*/
public static List<JsonObject> generateDataWithVarcharPK(long startId, long num, int dim, DataType vectorType) {
List<JsonObject> jsonList = new ArrayList<>();
Random ran = new Random();
Gson gson = new Gson();
for (long i = startId; i < (num + startId); i++) {
JsonObject row = new JsonObject();
// Use varchar as primary key
row.addProperty(CommonData.fieldVarchar, "Str" + i);
row.addProperty(CommonData.fieldInt64, i);
row.addProperty(CommonData.fieldInt32, (int) i % 32767);
row.addProperty(CommonData.fieldInt8, (short) i % 127);
row.addProperty(CommonData.fieldDouble, (double) i);
row.addProperty(CommonData.fieldBool, i % 2 == 0);
row.addProperty(CommonData.fieldFloat, (float) i);
// Generate vector based on type
if (vectorType == DataType.FloatVector) {
List<Float> vector = new ArrayList<>();
for (int k = 0; k < dim; ++k) {
vector.add(ran.nextFloat());
}
row.add(CommonData.fieldFloatVector, gson.toJsonTree(vector));
}
if (vectorType == DataType.BinaryVector) {
row.add(CommonData.fieldBinaryVector, gson.toJsonTree(generateBinaryVector(dim).array()));
}
if (vectorType == DataType.Float16Vector) {
row.add(CommonData.fieldFloat16Vector, gson.toJsonTree(generateFloat16Vector(dim).array()));
}
if (vectorType == DataType.BFloat16Vector) {
row.add(CommonData.fieldBF16Vector, gson.toJsonTree(generateBF16Vector(dim).array()));
}
if (vectorType == DataType.SparseFloatVector) {
row.add(CommonData.fieldSparseVector, gson.toJsonTree(generateSparseVector(dim)));
}
JsonObject json = new JsonObject();
json.addProperty(CommonData.fieldInt64, (int) i % 32767);
json.addProperty(CommonData.fieldInt32, (int) i % 32767);
json.addProperty(CommonData.fieldDouble, (double) i);
json.addProperty(CommonData.fieldFloat, (float) i);
row.add(CommonData.fieldJson, json);
jsonList.add(row);
}
return jsonList;
}

public static List<JsonObject> generateDefaultDataWithDynamic(long startId, long num, int dim, DataType vectorType) {
List<JsonObject> jsonList = new ArrayList<>();
Random ran = new Random();
Expand Down Expand Up @@ -1975,6 +2134,192 @@ public static void multiFilesUpload(String path, List<List<String>> batchFiles)

}

// ==================== Struct Array Related Methods ====================

/**
* Create a collection schema with Struct field containing vectors
*
* @param collectionName collection name
* @param dim vector dimension
* @return collection name
*/
public static String createStructCollection(String collectionName, int dim) {
if (collectionName == null || collectionName.isEmpty()) {
collectionName = "StructCollection_" + GenerateUtil.getRandomString(10);
}

CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
.build();

// Primary key field
collectionSchema.addField(AddFieldReq.builder()
.fieldName(CommonData.fieldInt64)
.dataType(DataType.Int64)
.isPrimaryKey(true)
.autoID(false)
.build());

// Regular float vector field
collectionSchema.addField(AddFieldReq.builder()
.fieldName(CommonData.fieldFloatVector)
.dataType(DataType.FloatVector)
.dimension(dim)
.build());

// Struct array field with multiple sub-fields including vectors
collectionSchema.addField(AddFieldReq.builder()
.fieldName(CommonData.fieldStruct)
.description("struct array field with vectors")
.dataType(DataType.Array)
.elementType(DataType.Struct)
.maxCapacity(CommonData.structMaxCapacity)
.addStructField(AddFieldReq.builder()
.fieldName(CommonData.structFieldInt32)
.description("int32 field in struct")
.dataType(DataType.Int32)
.build())
.addStructField(AddFieldReq.builder()
.fieldName(CommonData.structFieldVarchar)
.description("varchar field in struct")
.dataType(DataType.VarChar)
.maxLength(1024)
.build())
.addStructField(AddFieldReq.builder()
.fieldName(CommonData.structFieldFloatVector1)
.description("first float vector in struct")
.dataType(DataType.FloatVector)
.dimension(dim)
.build())
.addStructField(AddFieldReq.builder()
.fieldName(CommonData.structFieldFloatVector2)
.description("second float vector in struct")
.dataType(DataType.FloatVector)
.dimension(dim)
.build())
.build());

CreateCollectionReq createCollectionReq = CreateCollectionReq.builder()
.collectionName(collectionName)
.collectionSchema(collectionSchema)
.enableDynamicField(false)
.numShards(1)
.build();

milvusClientV2.createCollection(createCollectionReq);
log.info("Created struct collection: " + collectionName);
return collectionName;
}

/**
* Generate data for struct collection
*
* @param startId start id
* @param count number of rows
* @param dim vector dimension
* @return list of JsonObject data
*/
public static List<JsonObject> generateStructData(long startId, long count, int dim) {
List<JsonObject> dataList = new ArrayList<>();
Random random = new Random();

for (long i = startId; i < startId + count; i++) {
JsonObject row = new JsonObject();
row.addProperty(CommonData.fieldInt64, i);

// Regular float vector
List<Float> vector = GenerateUtil.generateFloatVector(1, 6, dim).get(0);
row.add(CommonData.fieldFloatVector, new com.google.gson.Gson().toJsonTree(vector));

// Struct array - each row has 3-10 struct elements
int structCount = random.nextInt(8) + 3;
JsonArray structArray = new JsonArray();
for (int j = 0; j < structCount; j++) {
JsonObject structElement = new JsonObject();
structElement.addProperty(CommonData.structFieldInt32, random.nextInt(10000));
structElement.addProperty(CommonData.structFieldVarchar, "struct_desc_" + i + "_" + j);

// First vector in struct
List<Float> vec1 = GenerateUtil.generateFloatVector(1, 6, dim).get(0);
structElement.add(CommonData.structFieldFloatVector1, new com.google.gson.Gson().toJsonTree(vec1));

// Second vector in struct
List<Float> vec2 = GenerateUtil.generateFloatVector(1, 6, dim).get(0);
structElement.add(CommonData.structFieldFloatVector2, new com.google.gson.Gson().toJsonTree(vec2));

structArray.add(structElement);
}
row.add(CommonData.fieldStruct, structArray);

dataList.add(row);
}
return dataList;
}
Comment thread
yongpengli-z marked this conversation as resolved.

/**
* Create embedding list index for struct vector field
*
* @param collectionName collection name
* @param structFieldName struct field name
* @param vectorFieldName vector field name in struct
* @param indexName index name
* @param metricType metric type (MAX_SIM_COSINE, MAX_SIM_IP, MAX_SIM_L2)
*/
public static void createStructVectorIndex(String collectionName, String structFieldName,
String vectorFieldName, String indexName,
IndexParam.MetricType metricType) {
String fullFieldName = String.format("%s[%s]", structFieldName, vectorFieldName);
IndexParam indexParam = IndexParam.builder()
.fieldName(fullFieldName)
.indexName(indexName)
.indexType(IndexParam.IndexType.HNSW)
.metricType(metricType)
.extraParams(new HashMap<String, Object>() {{
put("M", 16);
put("efConstruction", 200);
}})
.build();
Comment thread
yongpengli-z marked this conversation as resolved.

milvusClientV2.createIndex(CreateIndexReq.builder()
.collectionName(collectionName)
.indexParams(Collections.singletonList(indexParam))
.build());
log.info("Created struct vector index: " + indexName + " on " + fullFieldName);
}

/**
* Generate EmbeddingList from struct query result
*
* @param structData struct field data from query result
* @param vectorFieldName vector field name in struct
* @return EmbeddingList
*/
public static EmbeddingList generateEmbeddingListFromStruct(List<Map<String, Object>> structData,
String vectorFieldName) {
EmbeddingList embeddingList = new EmbeddingList();
for (Map<String, Object> struct : structData) {
@SuppressWarnings("unchecked")
List<Float> vector = (List<Float>) struct.get(vectorFieldName);
embeddingList.add(new FloatVec(vector));
}
return embeddingList;
}

/**
* Generate random EmbeddingList for search
*
* @param vectorCount number of vectors in embedding list
* @param dim vector dimension
* @return EmbeddingList
*/
public static EmbeddingList generateRandomEmbeddingList(int vectorCount, int dim) {
EmbeddingList embeddingList = new EmbeddingList();
for (int i = 0; i < vectorCount; i++) {
List<Float> vector = GenerateUtil.generateFloatVector(1, 6, dim).get(0);
embeddingList.add(new FloatVec(vector));
}
return embeddingList;
}

}


Loading
Loading