Skip to content

Commit a2fd09b

Browse files
committed
Support optimize() interface
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent bda9a55 commit a2fd09b

File tree

8 files changed

+1090
-0
lines changed

8 files changed

+1090
-0
lines changed
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package io.milvus.v2;
21+
22+
import com.google.gson.Gson;
23+
import com.google.gson.JsonObject;
24+
import io.milvus.v1.CommonUtils;
25+
import io.milvus.v2.client.ConnectConfig;
26+
import io.milvus.v2.client.MilvusClientV2;
27+
import io.milvus.v2.common.DataType;
28+
import io.milvus.v2.common.IndexParam;
29+
import io.milvus.v2.service.collection.request.AddFieldReq;
30+
import io.milvus.v2.service.collection.request.CreateCollectionReq;
31+
import io.milvus.v2.service.collection.request.DropCollectionReq;
32+
import io.milvus.v2.service.collection.request.LoadCollectionReq;
33+
import io.milvus.v2.service.index.request.CreateIndexReq;
34+
import io.milvus.v2.service.utility.OptimizeTask;
35+
import io.milvus.v2.service.utility.request.FlushReq;
36+
import io.milvus.v2.service.utility.request.GetQuerySegmentInfoReq;
37+
import io.milvus.v2.service.utility.request.OptimizeReq;
38+
import io.milvus.v2.service.utility.response.GetQuerySegmentInfoResp;
39+
import io.milvus.v2.service.utility.response.OptimizeResp;
40+
import io.milvus.v2.service.vector.request.InsertReq;
41+
import io.milvus.v2.service.vector.response.InsertResp;
42+
43+
import java.util.*;
44+
45+
public class OptimizeExample {
46+
private static final String COLLECTION_NAME = "java_sdk_example_optimize_v2";
47+
private static final String ID_FIELD = "id";
48+
private static final String VECTOR_FIELD = "vector";
49+
private static final int VECTOR_DIM = 512;
50+
private static final int TOTAL_ROWS = 1_000_000;
51+
private static final int BATCH_SIZE = 10_000;
52+
53+
public static void main(String[] args) throws InterruptedException {
54+
ConnectConfig config = ConnectConfig.builder()
55+
.uri("http://localhost:19530")
56+
.build();
57+
MilvusClientV2 client = new MilvusClientV2(config);
58+
System.out.println(client.getServerVersion());
59+
60+
// Step 1: Drop and create collection
61+
System.out.println("========== Step 1: Create collection ==========");
62+
client.dropCollection(DropCollectionReq.builder()
63+
.collectionName(COLLECTION_NAME)
64+
.build());
65+
66+
CreateCollectionReq.CollectionSchema schema = CreateCollectionReq.CollectionSchema.builder()
67+
.build();
68+
schema.addField(AddFieldReq.builder()
69+
.fieldName(ID_FIELD)
70+
.dataType(DataType.Int64)
71+
.isPrimaryKey(Boolean.TRUE)
72+
.autoID(Boolean.TRUE)
73+
.build());
74+
schema.addField(AddFieldReq.builder()
75+
.fieldName(VECTOR_FIELD)
76+
.dataType(DataType.FloatVector)
77+
.dimension(VECTOR_DIM)
78+
.build());
79+
80+
client.createCollection(CreateCollectionReq.builder()
81+
.collectionName(COLLECTION_NAME)
82+
.collectionSchema(schema)
83+
.build());
84+
System.out.printf("Collection '%s' created%n", COLLECTION_NAME);
85+
86+
// Step 2: Insert one million rows, size is 2GB when dimension is 512
87+
System.out.println("========== Step 2: Insert 1,000,000 rows ==========");
88+
Gson gson = new Gson();
89+
int totalInserted = 0;
90+
for (int batch = 0; batch < TOTAL_ROWS / BATCH_SIZE; batch++) {
91+
List<JsonObject> rows = new ArrayList<>();
92+
for (int i = 0; i < BATCH_SIZE; i++) {
93+
JsonObject row = new JsonObject();
94+
row.add(VECTOR_FIELD, gson.toJsonTree(CommonUtils.generateFloatVector(VECTOR_DIM)));
95+
rows.add(row);
96+
}
97+
InsertResp resp = client.insert(InsertReq.builder()
98+
.collectionName(COLLECTION_NAME)
99+
.data(rows)
100+
.build());
101+
totalInserted += (int) resp.getInsertCnt();
102+
if ((batch + 1) % 10 == 0) {
103+
System.out.printf(" Inserted %d / %d rows%n", totalInserted, TOTAL_ROWS);
104+
}
105+
}
106+
client.flush(FlushReq.builder().collectionNames(Collections.singletonList(COLLECTION_NAME)).build());
107+
System.out.printf("Total inserted: %d rows%n", totalInserted);
108+
109+
// Step 3: Create IVF_FLAT index
110+
System.out.println("========== Step 3: Create IVF_FLAT index ==========");
111+
Map<String, Object> extraParams = new HashMap<>();
112+
extraParams.put("nlist", 32);
113+
IndexParam indexParam = IndexParam.builder()
114+
.fieldName(VECTOR_FIELD)
115+
.indexType(IndexParam.IndexType.IVF_FLAT)
116+
.metricType(IndexParam.MetricType.L2)
117+
.extraParams(extraParams)
118+
.build();
119+
client.createIndex(CreateIndexReq.builder()
120+
.collectionName(COLLECTION_NAME)
121+
.indexParams(Collections.singletonList(indexParam))
122+
.timeout(100000L)
123+
.build());
124+
System.out.println("IVF_FLAT index created");
125+
126+
// Step 4: Load collection
127+
System.out.println("========== Step 4: Load collection ==========");
128+
client.loadCollection(LoadCollectionReq.builder()
129+
.collectionName(COLLECTION_NAME)
130+
.build());
131+
System.out.println("Collection loaded");
132+
133+
// Step 5: Check segments before optimize
134+
System.out.println("========== Step 5: Query segment info (before optimize) ==========");
135+
printSegmentInfo(client);
136+
137+
// Step 6: Optimize with targetSize=4GB, synchronous
138+
// Data will be merged into one segment because total size is 2GB, which is smaller than targetSize.
139+
// In standalone Milvus, performance will be the best if data is merged into one segment.
140+
// But in cluster Milvus, it's recommended to have multiple segments for better load balancing and query performance,
141+
// so you need to carefully set targetSize based on your data size and cluster configuration.
142+
System.out.println("========== Step 6: Optimize (targetSize=4GB, sync) ==========");
143+
long startTime = System.currentTimeMillis();
144+
OptimizeTask task = client.optimize(OptimizeReq.builder()
145+
.collectionName(COLLECTION_NAME)
146+
.targetSize("4GB")
147+
.build());
148+
OptimizeResp result = task.getResult(null);
149+
long elapsed = System.currentTimeMillis() - startTime;
150+
System.out.printf("Optimize completed in %.1f seconds%n", elapsed / 1000.0);
151+
System.out.printf(" Status: %s%n", result.getStatus());
152+
System.out.printf(" Compaction ID: %d%n", result.getCompactionId());
153+
System.out.printf(" Progress: %s%n", result.getProgress());
154+
155+
// Step 8: Check segments after optimize
156+
System.out.println("========== Step 8: Query segment info (after optimize) ==========");
157+
while (true) {
158+
int segmentCount = printSegmentInfo(client);
159+
if (segmentCount == 1) {
160+
System.out.println("Optimization successful, only one segment remains");
161+
break;
162+
}
163+
System.out.println("Waiting for optimization to complete...");
164+
Thread.sleep(1000);
165+
}
166+
167+
client.close(5);
168+
}
169+
170+
private static int printSegmentInfo(MilvusClientV2 client) {
171+
GetQuerySegmentInfoResp segResp = client.getQuerySegmentInfo(
172+
GetQuerySegmentInfoReq.builder()
173+
.collectionName(COLLECTION_NAME)
174+
.build());
175+
List<GetQuerySegmentInfoResp.QuerySegmentInfo> segments = segResp.getSegmentInfos();
176+
System.out.printf(" Total segments: %d%n", segments.size());
177+
long totalRows = 0;
178+
for (GetQuerySegmentInfoResp.QuerySegmentInfo seg : segments) {
179+
System.out.printf(" Segment %d: rows=%d, state=%s, level=%s%n",
180+
seg.getSegmentID(), seg.getNumOfRows(), seg.getState(), seg.getLevel());
181+
totalRows += seg.getNumOfRows();
182+
}
183+
System.out.printf(" Total rows across segments: %d%n", totalRows);
184+
return segments.size();
185+
}
186+
}

0 commit comments

Comments (0)