Skip to content

Commit 6dfbb60

Browse files
committed
Optimize client pool
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent eb10051 commit 6dfbb60

File tree

6 files changed

+551
-57
lines changed

6 files changed

+551
-57
lines changed
Lines changed: 326 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,326 @@
1+
package io.milvus.pool;
2+
3+
import io.milvus.v2.exception.ErrorCode;
4+
import io.milvus.v2.exception.MilvusClientException;
5+
import org.apache.commons.pool2.impl.GenericKeyedObjectPool;
6+
import org.jetbrains.annotations.NotNull;
7+
import org.slf4j.Logger;
8+
import org.slf4j.LoggerFactory;
9+
10+
import java.util.Objects;
11+
import java.util.concurrent.*;
12+
import java.util.concurrent.atomic.AtomicInteger;
13+
import java.util.concurrent.atomic.AtomicLong;
14+
import java.util.concurrent.locks.Lock;
15+
import java.util.concurrent.locks.ReentrantLock;
16+
17+
/**
 * Caches clients borrowed from a {@link GenericKeyedObjectPool} for a single pool key and
 * load-balances calls across them. A background timer samples the call rate once per interval
 * and scales the cached set up (borrow more clients) or down (retire one client) accordingly.
 *
 * <p>Thread-safety: {@code activeClientList}/{@code retireClientList} are
 * {@link CopyOnWriteArrayList}s iterated without locks; {@code totalCallNumber} is atomic; a fair
 * {@link ReentrantLock} guards only the one-time creation of the first client in
 * {@link #getClient()}. {@code lastCheckMs} and {@code fetchClientPerSecond} are written only by
 * the single scheduler thread.
 *
 * @param <T> the pooled client type (e.g. a Milvus gRPC client)
 */
public class ClientCache<T> {
    /** Per-client calls-per-second above which additional clients are borrowed from the pool. */
    public static final int THRESHOLD_INCREASE = 100;
    /** Per-client calls-per-second below which one client is retired back to the pool. */
    public static final int THRESHOLD_DECREASE = 50;

    private static final Logger logger = LoggerFactory.getLogger(ClientCache.class);
    // The single pool key this cache serves; all borrow/return calls use it.
    private final String key;
    private final GenericKeyedObjectPool<String, T> clientPool;
    // Clients currently eligible to serve calls.
    private final CopyOnWriteArrayList<ClientWrapper<T>> activeClientList = new CopyOnWriteArrayList<>();
    // Clients removed from rotation, waiting for their refCount to drain to zero before
    // being returned to the pool by returnRetiredClients().
    private final CopyOnWriteArrayList<ClientWrapper<T>> retireClientList = new CopyOnWriteArrayList<>();
    private final ScheduledExecutorService scheduler;
    // Calls observed since the last checkQPS() tick; reset each tick.
    private final AtomicLong totalCallNumber = new AtomicLong(0L);
    // Fair lock so only one thread creates the very first client (see getClient()).
    private final Lock clientListLock;
    // Timestamp (ms) of the previous checkQPS() tick; written only by the scheduler thread.
    private long lastCheckMs = 0L;
    // Most recently measured overall calls-per-second; exposed via fetchClientPerSecond().
    private float fetchClientPerSecond = 0.0F;

    /**
     * Creates a cache for {@code key} backed by {@code pool} and starts the QPS-check timer
     * with a 1-second interval.
     *
     * @param key  the pool key this cache serves
     * @param pool the keyed object pool to borrow clients from / return clients to
     */
    protected ClientCache(String key, GenericKeyedObjectPool<String, T> pool) {
        this.key = key;
        this.clientPool = pool;
        this.clientListLock = new ReentrantLock(true);

        ThreadFactory threadFactory = new ThreadFactory() {
            @Override
            public Thread newThread(@NotNull Runnable r) {
                Thread t = new Thread(r);
                t.setPriority(Thread.MAX_PRIORITY); // set the highest priority for the timer
                return t;
            }
        };
        // Single-threaded scheduler: checkQPS() is therefore never run concurrently with itself.
        this.scheduler = Executors.newScheduledThreadPool(1, threadFactory);

        startTimer(1000L);
    }

    /**
     * Warms up the pool and this cache: pre-creates the pool's configured minimum number of
     * idle clients for this key and borrows them all into {@code activeClientList}.
     *
     * @throws MilvusClientException wrapping any pool failure, with {@code ErrorCode.CLIENT_ERROR}
     */
    public void preparePool() {
        try {
            // preparePool() will create minIdlePerKey MilvusClient objects in advance, put the pre-created clients
            // into activeClientList
            clientPool.preparePool(this.key);
            int minIdlePerKey = clientPool.getMinIdlePerKey();
            for (int i = 0; i < minIdlePerKey; i++) {
                activeClientList.add(new ClientWrapper<>(clientPool.borrowObject(this.key)));
            }

            if (logger.isDebugEnabled()) {
                logger.debug("ClientCache key: {} cache clients: {} ", key, activeClientList.size());
                logger.debug("Pool initialize idle: {} active: {} ", clientPool.getNumIdle(key), clientPool.getNumActive(key));
            }
//            System.out.printf("Key: %s, cache client: %d%n", key, activeClientList.size());
//            System.out.printf("Pool idle %d, active %d%n", clientPool.getNumIdle(key), clientPool.getNumActive(key));
        } catch (Exception e) {
            logger.error("Failed to prepare pool {}, exception: ", key, e);
            throw new MilvusClientException(ErrorCode.CLIENT_ERROR, e);
        }
    }

    // this method is called in an interval, it does the following tasks:
    // - if QPS is high, borrow client from the pool and put into activeClientList
    // - if QPS is low, pick a client from activeClientList and put into retireClientList
    //
    // Most of gRPC implementations uses a single long-lived HTTP/2 connection, each HTTP/2 connections have a limit
    // on the number of concurrent streams which is default 100. When the number of active RPCs on the connection
    // reaches this limit, additional RPCs are queued in the client and must wait for active RPCs to finish
    // before they are sent.
    //
    // Treat qps >= 75 as high, qps <= 50 as low
    // NOTE(review): the code below actually compares against THRESHOLD_INCREASE (100) /
    // THRESHOLD_DECREASE (50), not 75/50 as this comment says — confirm which is intended.
    private void checkQPS() {
        if (activeClientList.isEmpty()) {
            // reset the last check time point
            lastCheckMs = System.currentTimeMillis();
            return;
        }

        // Snapshot the per-interval call count, then normalize it to calls-per-second,
        // both overall and per cached client.
        long totalCallNum = totalCallNumber.get();
        float perClientCall = (float) totalCallNum / activeClientList.size();
        long timeGapMs = System.currentTimeMillis() - lastCheckMs;
        if (timeGapMs == 0) {
            timeGapMs = 1; // avoid zero
        }
        float perClientPerSecond = perClientCall * 1000 / timeGapMs;
        this.fetchClientPerSecond = (float) (totalCallNum * 1000) / timeGapMs;
        if (logger.isDebugEnabled()) {

            logger.debug("ClientCache key: {} fetchClientPerSecond: {} perClientPerSecond: {}, cached clients: {}",
                    key, fetchClientPerSecond, perClientPerSecond, activeClientList.size());
            logger.debug("Pool idle: {} active: {} ", clientPool.getNumIdle(key), clientPool.getNumActive(key));
        }
//        System.out.printf("Key: %s, fetchClientPerSecond: %.2f, perClientPerSecond: %.2f, cache client: %d%n", key, fetchClientPerSecond, perClientPerSecond, activeClientList.size());
//        System.out.printf("Pool idle %d, active %d%n", clientPool.getNumIdle(key), clientPool.getNumActive(key));

        // reset the counter and the last check time point
        totalCallNumber.set(0L);
        lastCheckMs = System.currentTimeMillis();

        if (perClientPerSecond >= THRESHOLD_INCREASE) {
            // try to create more clients to reduce the perClientPerSecond to under THRESHOLD_INCREASE
            // add no more than 3 clients since the qps could change during we're adding new clients
            // the next call of checkQPS() will add more clients if the perClientPerSecond is still high
            int expectedNum = (int) Math.ceil((double) totalCallNum / THRESHOLD_INCREASE);
            int moreNum = expectedNum - activeClientList.size();
            if (moreNum > 3) {
                moreNum = 3;
            }

            for (int k = 0; k < moreNum; k++) {
                T client = fetchFromPool();
                // if the pool reaches MaxTotalPerKey, the new client is null
                if (client == null) {
                    break;
                }

                ClientWrapper<T> wrapper = new ClientWrapper<>(client);
                activeClientList.add(wrapper);

                if (logger.isDebugEnabled()) {
                    logger.debug("ClientCache key: {} borrows a client", key);
                }
//                System.out.printf("Key: %s borrows a client%n", key);
            }
        }

        if (activeClientList.size() > 1 && perClientPerSecond <= THRESHOLD_DECREASE) {
            // if activeClientList has only one client, no need to retire it
            // otherwise, retire the max load client
            // (the size() > 1 guard also guarantees activeClientList never becomes empty
            // once populated, which getClient() relies on)
            int maxLoad = -1000;
            int maxIndex = -1;
            for (int i = 0; i < activeClientList.size(); i++) {
                ClientWrapper<T> wrapper = activeClientList.get(i);
                int refCount = wrapper.getRefCount();
                if (refCount > maxLoad) {
                    maxLoad = refCount;
                    maxIndex = i;
                }
            }
            if (maxIndex >= 0) {
                ClientWrapper<T> wrapper = activeClientList.get(maxIndex);
                activeClientList.remove(maxIndex);
                retireClientList.add(wrapper);
            }
        }

        // return the retired client to pool if ref count is zero
        returnRetiredClients();
    }

    /**
     * Returns every retired client whose reference count has drained to zero back to the pool
     * and drops it from {@code retireClientList}. Clients still in use stay retired until a
     * later tick.
     */
    private void returnRetiredClients() {
        retireClientList.removeIf(wrapper -> {
            if (wrapper.getRefCount() <= 0) {
                returnToPool(wrapper.getClient());

                if (logger.isDebugEnabled()) {
                    logger.debug("ClientCache key: {} returns a client", key);
                }
//                System.out.printf("Key: %s returns a client%n", key);
                return true;
            }
            return false;
        });
    }

    /**
     * Schedules {@link #checkQPS()} at a fixed rate. Intervals below 1000 ms are clamped to
     * 1000 ms. Also initializes {@code lastCheckMs} so the first tick measures a full interval.
     *
     * @param interval the check period in milliseconds (minimum 1000)
     */
    private void startTimer(long interval) {
        if (interval < 1000L) {
            interval = 1000L; // min 1000
        }

        lastCheckMs = System.currentTimeMillis();
        scheduler.scheduleAtFixedRate(new Runnable() {
            @Override
            public void run() {
                checkQPS();
            }
        }, interval, interval, TimeUnit.MILLISECONDS);
    }

    /**
     * Stops the QPS-check timer. Note: the scheduler thread is non-daemon, so this must be
     * called for the JVM to exit cleanly once the cache is no longer needed.
     */
    public void stopTimer() {
        scheduler.shutdown();
    }

    /**
     * Returns the least-loaded cached client. Increments the wrapper's reference count as a
     * side effect (via {@code ClientWrapper.getClient()}); callers must pair each call with
     * {@link #returnClient(Object)}.
     *
     * @return a client, or {@code null} if the pool could not supply the first client
     *         (e.g. MaxTotalPerKey reached or borrow timeout)
     */
    public T getClient() {
        totalCallNumber.incrementAndGet();
        if (activeClientList.isEmpty()) {
            // multiple threads can run into this section, add a lock to ensure only one thread can fetch the first
            // client object, this section is entered only one time, the lock doesn't affect major performance
            clientListLock.lock();
            try {
                if (activeClientList.isEmpty()) {
                    T client = fetchFromPool();
                    if (client == null) {
                        return null; // reach MaxTotalPerKey?
                    }
                    ClientWrapper<T> wrapper = new ClientWrapper<>(client);
                    activeClientList.add(wrapper);
                    return wrapper.getClient();
                }
            } finally {
                clientListLock.unlock();
            }
        }

        // round-robin is not a good choice because the activeClientList is occasionally changed.
        // here we return the minimum load client, the for loop of CopyOnWriteArrayList is high performance
        // typically, the activeClientList is not a large list since a dozen of clients can take thousands of qps,
        // I suppose the loop is a cheap operation.
        int minLoad = Integer.MAX_VALUE;
        ClientWrapper<T> wrapper = null;
        for (ClientWrapper<T> tempWrapper : activeClientList) {
            if (tempWrapper.getRefCount() < minLoad) {
                minLoad = tempWrapper.getRefCount();
                wrapper = tempWrapper;
            }
        }
        if (wrapper == null) {
            // should not be here
            // (checkQPS() never removes the last active client, so the list stays non-empty
            // once populated; this is a defensive fallback)
            wrapper = activeClientList.get(0);
        }

        return wrapper.getClient();
    }

    /**
     * Releases a client previously obtained from {@link #getClient()}: decrements the matching
     * wrapper's reference count. The lookup relies on {@code ClientWrapper.equals()} accepting
     * a raw client object. If the client is in neither list (already returned to the pool),
     * this is a silent no-op.
     *
     * @param grpcClient the client to release
     */
    public void returnClient(T grpcClient) {
        // for loop of CopyOnWriteArrayList is thread safe
        // this method only decrement the call number, the checkQPS timer will retire client accordingly
        for (ClientWrapper<T> wrapper : activeClientList) {
            if (wrapper.equals(grpcClient)) {
                wrapper.returnClient();
                return;
            }
        }
        for (ClientWrapper<T> wrapper : retireClientList) {
            if (wrapper.equals(grpcClient)) {
                wrapper.returnClient();
                return;
            }
        }
    }

    /**
     * Borrows a client from the pool for this cache's key.
     *
     * @return the borrowed client, or {@code null} when the cache already holds
     *         MaxTotalPerKey clients (active + retired) or the borrow fails/times out
     */
    private T fetchFromPool() {
        try {
            if (activeClientList.size() + retireClientList.size() >= clientPool.getMaxTotalPerKey()) {
                return null;
            }
            return clientPool.borrowObject(this.key);
        } catch (Exception e) {
            // the pool might return timeout exception if it could not get a client in PoolConfig.maxBlockWaitDuration
            logger.error("Failed to get client, exception: ", e);
            return null; // return null, let the ClientCache to handle
        }
    }

    /**
     * Returns a client to the pool.
     *
     * @param grpcClient the client to return
     * @throws MilvusClientException wrapping any pool failure, with {@code ErrorCode.CLIENT_ERROR}
     */
    private void returnToPool(T grpcClient) {
        try {
            clientPool.returnObject(this.key, grpcClient);
        } catch (Exception e) {
            // the pool might return exception if the key doesn't exist or the grpcClient doesn't belong to this pool
            logger.error("Failed to return client, exception: ", e);
            throw new MilvusClientException(ErrorCode.CLIENT_ERROR, e);
        }
    }

    /**
     * @return the overall calls-per-second measured at the most recent QPS-check tick
     */
    public float fetchClientPerSecond() {
        return this.fetchClientPerSecond;
    }

    /**
     * Pairs a client with an atomic in-flight reference count. The count is incremented by
     * {@link #getClient()} and decremented by {@link #returnClient()}; the cache retires a
     * wrapper back to the pool only once the count reaches zero.
     */
    private static class ClientWrapper<T> {
        private final T client;
        private final AtomicInteger refCount = new AtomicInteger(0);

        public ClientWrapper(T client) {
            this.client = client;
        }

        @Override
        public int hashCode() {
            // the hash code of ClientWrapper is equal to MilvusClient hash code
            return this.client.hashCode();
        }

        // NOTE(review): this equals() deliberately matches both another ClientWrapper AND a raw
        // client object (so ClientCache.returnClient() can look wrappers up by client). That
        // breaks the symmetry requirement of the equals() contract (client.equals(wrapper) is
        // false) — acceptable only because it is used solely for the internal list scans here.
        @Override
        public boolean equals(Object obj) {
            if (this == obj) return true;

            if (obj == null) {
                return false;
            }

            // obj is ClientWrapper
            if (this.getClass() == obj.getClass()) {
                return Objects.equals(this.client, ((ClientWrapper<?>) obj).client);
            }

            // obj is MilvusClient
            if (this.client != null && this.client.getClass() == obj.getClass()) {
                return Objects.equals(this.client, obj);
            }
            return false;
        }

        // Returns the wrapped client and increments the in-flight reference count.
        public T getClient() {
            this.refCount.incrementAndGet();
            return this.client;
        }

        // Decrements the in-flight reference count (may go below zero if called unpaired).
        public void returnClient() {
            this.refCount.decrementAndGet();
        }

        public int getRefCount() {
            return refCount.get();
        }
    }
}

0 commit comments

Comments
 (0)