Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@
import io.milvus.param.resourcegroup.*;
import io.milvus.param.role.*;
import io.milvus.response.*;
import io.milvus.v2.service.collection.response.DescribeCollectionResp;
import io.milvus.v2.service.vector.request.InsertReq;
import io.milvus.v2.utils.DataUtils;
import lombok.NonNull;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
Expand All @@ -68,22 +65,31 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
protected static final Logger logger = LoggerFactory.getLogger(AbstractMilvusGrpcClient.class);
protected LogLevel logLevel = LogLevel.Info;

private ConcurrentHashMap<String, DescribeCollectionResponse> cacheCollectionInfo = new ConcurrentHashMap<>();
protected ConcurrentHashMap<String, DescribeCollectionResponse> cacheCollectionInfo = new ConcurrentHashMap<>();

protected abstract MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub();

protected abstract MilvusServiceGrpc.MilvusServiceFutureStub futureStub();

protected abstract boolean clientIsReady();

protected abstract String currentDbName();

private String actualDbName(String overwriteName) {
if (StringUtils.isNotEmpty(overwriteName)) {
return overwriteName;
}
return currentDbName();
}

/**
* This method is for insert/upsert requests to reduce the rpc call of describeCollection()
* Always try to get the collection info from cache.
* If the cache doesn't have the collection info, call describeCollection() and cache it.
* If insert/upsert get server error, remove the cached collection info.
*/
private DescribeCollectionResponse getCollectionInfo(String databaseName, String collectionName, boolean forceUpdate) {
String key = combineCacheKey(databaseName, collectionName);
String key = GTsDict.CombineCollectionName(actualDbName(databaseName), collectionName);
DescribeCollectionResponse info = cacheCollectionInfo.get(key);
if (info == null || forceUpdate) {
String msg = String.format("Fail to describe collection '%s'", collectionName);
Expand All @@ -104,17 +110,6 @@ private DescribeCollectionResponse getCollectionInfo(String databaseName, String
return info;
}

private String combineCacheKey(String databaseName, String collectionName) {
if (collectionName == null || StringUtils.isBlank(collectionName)) {
throw new ParamException("Collection name is empty, not able to get collection info.");
}
String key = collectionName;
if (StringUtils.isNotEmpty(databaseName)) {
key = String.format("%s|%s", databaseName, collectionName);
}
return key;
}

/**
* insert/upsert return an error, but is not a RateLimit error,
* clean the cache so that the next insert will call describeCollection() to get the latest info.
Expand All @@ -127,7 +122,8 @@ private void cleanCacheIfFailed(Status status, String databaseName, String colle
}

private void removeCollectionCache(String databaseName, String collectionName) {
cacheCollectionInfo.remove(combineCacheKey(databaseName, collectionName));
String key = GTsDict.CombineCollectionName(actualDbName(databaseName), collectionName);
cacheCollectionInfo.remove(key);
}

private void waitForLoadingCollection(String databaseName, String collectionName, List<String> partitionNames,
Expand Down Expand Up @@ -658,7 +654,13 @@ public R<RpcStatus> dropCollection(@NonNull DropCollectionParam requestParam) {

Status response = blockingStub().dropCollection(dropCollectionRequest);
handleResponse(title, response);

// remove the collection schema cache
removeCollectionCache(dbName, collectionName);

// remove the last write timestamp for this collection
String key = GTsDict.CombineCollectionName(actualDbName(dbName), collectionName);
GTsDict.getInstance().removeCollectionTs(key);
return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG));
} catch (StatusRuntimeException e) {
logError("{} RPC failed! Exception:{}", title, e);
Expand Down Expand Up @@ -1570,22 +1572,27 @@ public R<MutationResult> delete(@NonNull DeleteParam requestParam) {
}

logDebug(requestParam.toString());
String title = String.format("DeleteRequest collectionName:%s", requestParam.getCollectionName());
String dbName = requestParam.getDatabaseName();
String collectionName = requestParam.getCollectionName();
String title = String.format("DeleteRequest collectionName:%s", collectionName);

try {
DeleteRequest.Builder builder = DeleteRequest.newBuilder()
.setBase(MsgBase.newBuilder().setMsgType(MsgType.Delete).build())
.setCollectionName(requestParam.getCollectionName())
.setCollectionName(collectionName)
.setPartitionName(requestParam.getPartitionName())
.setExpr(requestParam.getExpr());

if (StringUtils.isNotEmpty(requestParam.getDatabaseName())) {
builder.setDbName(requestParam.getDatabaseName());
if (StringUtils.isNotEmpty(dbName)) {
builder.setDbName(dbName);
}

MutationResult response = blockingStub().delete(builder.build());
handleResponse(title, response.getStatus());
GTsDict.getInstance().updateCollectionTs(requestParam.getCollectionName(), response.getTimestamp());

// update the last write timestamp for SESSION consistency
String key = GTsDict.CombineCollectionName(actualDbName(dbName), collectionName);
GTsDict.getInstance().updateCollectionTs(key, response.getTimestamp());
return R.success(response);
} catch (StatusRuntimeException e) {
logError("{} RPC failed! Exception:{}", title, e);
Expand Down Expand Up @@ -1639,10 +1646,14 @@ public R<MutationResult> insert(@NonNull InsertParam requestParam) {
return this.insert(requestParam);
}

// if illegal data, server fails to process insert, else succeed
// if illegal data, server fails to process insert, , clean the schema cache
// so that the next call of dml can update the cache
cleanCacheIfFailed(response.getStatus(), dbName, collectionName);
handleResponse(title, response.getStatus());
GTsDict.getInstance().updateCollectionTs(collectionName, response.getTimestamp());

// update the last write timestamp for SESSION consistency
String key = GTsDict.CombineCollectionName(actualDbName(dbName), collectionName);
GTsDict.getInstance().updateCollectionTs(key, response.getTimestamp());
return R.success(response);
} catch (StatusRuntimeException e) {
logError("{} RPC failed! Exception:{}", title, e);
Expand Down Expand Up @@ -1687,11 +1698,15 @@ public ListenableFuture<R<MutationResult>> insertAsync(InsertParam requestParam)
new FutureCallback<MutationResult>() {
@Override
public void onSuccess(MutationResult result) {
// if illegal data, server fails to process insert, else succeed
// if illegal data, server fails to process insert, clean the schema cache
// so that the next call of dml can update the cache
cleanCacheIfFailed(result.getStatus(), dbName, collectionName);
if (result.getStatus().getErrorCode() == ErrorCode.Success) {
logDebug("{} successfully!", title);
GTsDict.getInstance().updateCollectionTs(collectionName, result.getTimestamp());

// update the last write timestamp for SESSION consistency
String key = GTsDict.CombineCollectionName(actualDbName(dbName), collectionName);
GTsDict.getInstance().updateCollectionTs(key, result.getTimestamp());
} else {
logError("{} failed:\n{}", title, result.getStatus().getReason());
}
Expand Down Expand Up @@ -1760,10 +1775,14 @@ public R<MutationResult> upsert(UpsertParam requestParam) {
return this.upsert(requestParam);
}

// if illegal data, server fails to process upsert, else succeed
// if illegal data, server fails to process upsert, clean the schema cache
// so that the next call of dml can update the cache
cleanCacheIfFailed(response.getStatus(), dbName, collectionName);
handleResponse(title, response.getStatus());
GTsDict.getInstance().updateCollectionTs(collectionName, response.getTimestamp());

// update the last write timestamp for SESSION consistency
String key = GTsDict.CombineCollectionName(actualDbName(dbName), collectionName);
GTsDict.getInstance().updateCollectionTs(key, response.getTimestamp());
return R.success(response);
} catch (StatusRuntimeException e) {
logError("{} RPC failed! Exception:{}", title, e);
Expand Down Expand Up @@ -1807,11 +1826,15 @@ public ListenableFuture<R<MutationResult>> upsertAsync(UpsertParam requestParam)
new FutureCallback<MutationResult>() {
@Override
public void onSuccess(MutationResult result) {
// if illegal data, server fails to process upsert, else succeed
// if illegal data, server fails to process upsert, clean the schema cache
// so that the next call of dml can update the cache
cleanCacheIfFailed(result.getStatus(), dbName, collectionName);
if (result.getStatus().getErrorCode() == ErrorCode.Success) {
logDebug("{} successfully!", title);
GTsDict.getInstance().updateCollectionTs(collectionName, result.getTimestamp());

// update the last write timestamp for SESSION consistency
String key = GTsDict.CombineCollectionName(actualDbName(dbName), collectionName);
GTsDict.getInstance().updateCollectionTs(key, result.getTimestamp());
} else {
logError("{} failed:\n{}", title, result.getStatus().getReason());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,15 @@ public class MilvusServiceClient extends AbstractMilvusGrpcClient {
private final long rpcDeadlineMs;
private long timeoutMs = 0;
private RetryParam retryParam = RetryParam.newBuilder().build();
private String currentDatabaseName;

public MilvusServiceClient(@NonNull ConnectParam connectParam) {
this.rpcDeadlineMs = connectParam.getRpcDeadlineMs();

Metadata metadata = new Metadata();
metadata.put(Metadata.Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER), connectParam.getAuthorization());
if (StringUtils.isNotEmpty(connectParam.getDatabaseName())) {
currentDatabaseName = connectParam.getDatabaseName();
metadata.put(Metadata.Key.of("dbname", Metadata.ASCII_STRING_MARSHALLER), connectParam.getDatabaseName());
}

Expand Down Expand Up @@ -201,6 +203,7 @@ protected MilvusServiceClient(MilvusServiceClient src) {
this.timeoutMs = src.timeoutMs;
this.logLevel = src.logLevel;
this.retryParam = src.retryParam;
this.currentDatabaseName = src.currentDatabaseName;
}

@Override
Expand All @@ -222,6 +225,11 @@ public boolean clientIsReady() {
return channel != null && !channel.isShutdown() && !channel.isTerminated();
}

@Override
protected String currentDbName() {
return currentDatabaseName;
}

@Override
public void close(long maxWaitSeconds) throws InterruptedException {
channel.shutdownNow();
Expand Down
21 changes: 21 additions & 0 deletions sdk-core/src/main/java/io/milvus/common/utils/GTsDict.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@

package io.milvus.common.utils;

import io.milvus.exception.ParamException;
import org.apache.commons.lang3.StringUtils;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

Expand All @@ -37,6 +40,16 @@ public static GTsDict getInstance() {
return TS_DICT;
}

public static String CombineCollectionName(String databaseName, String collectionName) {
if (collectionName == null || StringUtils.isBlank(collectionName)) {
throw new ParamException("Collection name is empty, not able to get collection info.");
}
if (StringUtils.isEmpty(databaseName)) {
databaseName = "default";
}
return String.format("%s_%s", databaseName, collectionName);
}

private ConcurrentMap<String, Long> tsDict = new ConcurrentHashMap<>();

public void updateCollectionTs(String collectionName, long ts) {
Expand All @@ -49,4 +62,12 @@ public void updateCollectionTs(String collectionName, long ts) {
public Long getCollectionTs(String collectionName) {
return tsDict.get(collectionName);
}

public void removeCollectionTs(String collectionName) {
tsDict.remove(collectionName);
}

public void cleanAllCollectionTs() {
tsDict.clear();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,14 @@
import lombok.Builder;
import lombok.Data;
import lombok.NonNull;
import lombok.experimental.SuperBuilder;
import org.apache.commons.lang3.StringUtils;

import javax.net.ssl.SSLContext;
import java.net.URI;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;

@Data
@SuperBuilder
@Builder
public class ConnectConfig {
@NonNull
private String uri;
Expand Down
34 changes: 33 additions & 1 deletion sdk-core/src/main/java/io/milvus/v2/client/MilvusClientV2.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,24 @@ public class MilvusClientV2 {
public MilvusClientV2(ConnectConfig connectConfig) {
if (connectConfig != null) {
connect(connectConfig);

initServices(connectConfig.getDbName());

}
}

private void initServices(String dbName) {
this.databaseService.setCurrentDbName(dbName);
this.collectionService.setCurrentDbName(dbName);
this.indexService.setCurrentDbName(dbName);
this.vectorService.setCurrentDbName(dbName);
this.vectorService.cleanCollectionCache();
this.partitionService.setCurrentDbName(dbName);
this.rbacService.setCurrentDbName(dbName);
this.rgroupService.setCurrentDbName(dbName);
this.utilityService.setCurrentDbName(dbName);
}

/**
* connect to Milvus server
*
Expand Down Expand Up @@ -159,6 +175,22 @@ public void retryConfig(RetryConfig retryConfig) {
rpcUtils.retryConfig(retryConfig);
}

public MilvusClientV2 withRetry(RetryConfig retryConfig) {
rpcUtils.retryConfig(retryConfig);
return this;
}

public MilvusClientV2 withTimeout(long timeout, TimeUnit timeoutUnit) {
// the unit of rpcDeadlineMs is millisecond
// if the input timeout value is zero, rpcDeadlineMs is zero
// if the input timeout value is not zero and less than 1ms, it will be treated as 1ms
// if the input timeout value is larger than 1ms, it will be converted to an integer ms value
long nn = timeoutUnit.toNanos(timeout);
long ms = (nn == 0) ? 0 : (nn < 1000000 ? 1 : nn/1000000);
connectConfig.setRpcDeadlineMs(ms);
return this;
}

/////////////////////////////////////////////////////////////////////////////////////////////
// Database Operations
/////////////////////////////////////////////////////////////////////////////////////////////
Expand All @@ -170,10 +202,10 @@ public void useDatabase(@NonNull String dbName) throws InterruptedException {
// check if database exists
clientUtils.checkDatabaseExist(this.getRpcStub(), dbName);
try {
this.vectorService.cleanCollectionCache();
this.connectConfig.setDbName(dbName);
this.close(3);
this.connect(this.connectConfig);
this.initServices(dbName);
} catch (InterruptedException e){
logger.error("close connect error");
throw new RuntimeException(e);
Expand Down
13 changes: 13 additions & 0 deletions sdk-core/src/main/java/io/milvus/v2/service/BaseService.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import io.milvus.v2.utils.DataUtils;
import io.milvus.v2.utils.RpcUtils;
import io.milvus.v2.utils.VectorUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -37,6 +38,18 @@ public class BaseService {
public DataUtils dataUtils = new DataUtils();
public VectorUtils vectorUtils = new VectorUtils();
public ConvertUtils convertUtils = new ConvertUtils();
private String currentDbName;

public void setCurrentDbName(String dbName) {
currentDbName = dbName;
}

protected String actualDbName(String overwriteName) {
if (StringUtils.isNotEmpty(overwriteName)) {
return overwriteName;
}
return currentDbName;
}

protected void checkCollectionExist(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String collectionName) {
HasCollectionRequest request = HasCollectionRequest.newBuilder().setCollectionName(collectionName).build();
Expand Down
Loading
Loading