Skip to content

Commit 6cc825d

Browse files
authored
Support Get() (#431)
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent 8c5fc03 commit 6cc825d

14 files changed

Lines changed: 299 additions & 41 deletions

File tree

examples/src/v2/hybrid_search.cpp

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,11 @@ main(int argc, char* argv[]) {
5555
util::CheckStatus("create collection: " + collection_name, status);
5656

5757
// create index
58-
milvus::IndexDesc index_dense(field_dense, "", milvus::IndexType::DISKANN, milvus::MetricType::COSINE);
59-
milvus::IndexDesc index_sparse(field_sparse, "", milvus::IndexType::SPARSE_INVERTED_INDEX, milvus::MetricType::IP);
60-
status = client->CreateIndex(milvus::CreateIndexRequest()
61-
.WithCollectionName(collection_name)
62-
.AddIndex(std::move(index_dense))
63-
.AddIndex(std::move(index_sparse)));
58+
std::vector<milvus::IndexDesc> indexes = {
59+
milvus::IndexDesc(field_dense, "", milvus::IndexType::DISKANN, milvus::MetricType::COSINE),
60+
milvus::IndexDesc(field_sparse, "", milvus::IndexType::SPARSE_INVERTED_INDEX, milvus::MetricType::IP)};
61+
status = client->CreateIndex(
62+
milvus::CreateIndexRequest().WithCollectionName(collection_name).WithIndexes(std::move(indexes)));
6463
util::CheckStatus("create indexes on collection", status);
6564

6665
// tell server prepare to load collection
@@ -102,35 +101,32 @@ main(int argc, char* argv[]) {
102101

103102
{
104103
// do hybrid search
105-
auto request =
106-
milvus::HybridSearchRequest()
107-
.WithCollectionName(collection_name)
108-
.WithLimit(10)
109-
.AddOutputField(field_flag)
110-
.AddOutputField(field_text)
111-
// .AddOutputField(field_sparse)
112-
// set to BOUNDED level to accept data inconsistence within a time window(default is 5 seconds)
113-
.WithConsistencyLevel(milvus::ConsistencyLevel::BOUNDED);
114-
115-
// sub search request 1 for dense vector
116104
auto sub_req1 = milvus::SubSearchRequest()
117105
.WithLimit(5)
118106
.WithAnnsField(field_dense)
119107
.WithFilter(field_flag + " == 5")
120108
.AddFloatVector(util::GenerateFloatVector(dimension));
121-
request.AddSubRequest(std::make_shared<milvus::SubSearchRequest>(std::move(sub_req1)));
122109

123-
// sub search request 2 for sparse vector
124110
auto sub_req2 = milvus::SubSearchRequest()
125111
.WithLimit(15)
126112
.WithAnnsField(field_sparse)
127113
.WithFilter(field_flag + " in [1, 3]")
128114
.AddSparseVector(util::GenerateSparseVector(50));
129-
request.AddSubRequest(std::make_shared<milvus::SubSearchRequest>(std::move(sub_req2)));
130115

131-
// define reranker
132116
auto reranker = std::make_shared<milvus::WeightedRerank>(std::vector<float>{0.5, 0.5});
133-
request.SetRerank(reranker);
117+
118+
auto request =
119+
milvus::HybridSearchRequest()
120+
.WithCollectionName(collection_name)
121+
.WithLimit(10)
122+
.AddSubRequest(std::make_shared<milvus::SubSearchRequest>(std::move(sub_req1)))
123+
.AddSubRequest(std::make_shared<milvus::SubSearchRequest>(std::move(sub_req2)))
124+
.WithRerank(reranker)
125+
.AddOutputField(field_flag)
126+
.AddOutputField(field_text)
127+
// .AddOutputField(field_sparse)
128+
// set to BOUNDED level to accept data inconsistence within a time window(default is 5 seconds)
129+
.WithConsistencyLevel(milvus::ConsistencyLevel::BOUNDED);
134130

135131
milvus::SearchResponse response;
136132
status = client->HybridSearch(request, response);

examples/src/v2/simple.cpp

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -62,22 +62,46 @@ main(int argc, char* argv[]) {
6262
std::cout << resp_insert.Results().InsertCount() << " rows inserted by row-based." << std::endl;
6363

6464
// search
65-
auto request = milvus::SearchRequest()
66-
.WithCollectionName(collection_name)
67-
.WithLimit(3)
68-
.WithAnnsField(field_vector)
69-
.AddFloatVector(util::GenerateFloatVector(dimension))
70-
.WithConsistencyLevel(milvus::ConsistencyLevel::STRONG);
71-
72-
milvus::SearchResponse response;
73-
status = client->Search(request, response);
74-
util::CheckStatus("search", status);
75-
76-
for (auto& result : response.Results().Results()) {
77-
std::cout << "Result of one target vector:" << std::endl;
65+
{
66+
auto request = milvus::SearchRequest()
67+
.WithCollectionName(collection_name)
68+
.WithLimit(3)
69+
.WithAnnsField(field_vector)
70+
.AddFloatVector(util::GenerateFloatVector(dimension))
71+
.WithConsistencyLevel(milvus::ConsistencyLevel::STRONG);
72+
73+
milvus::SearchResponse response;
74+
status = client->Search(request, response);
75+
util::CheckStatus("search", status);
76+
77+
for (auto& result : response.Results().Results()) {
78+
std::cout << "Result of one target vector:" << std::endl;
79+
milvus::EntityRows output_rows;
80+
status = result.OutputRows(output_rows);
81+
util::CheckStatus("get output rows", status);
82+
for (const auto& row : output_rows) {
83+
std::cout << "\t" << row << std::endl;
84+
}
85+
}
86+
}
87+
88+
// get records by ids
89+
{
90+
std::vector<int64_t> ids = {5, 1, 10, 8};
91+
auto request = milvus::GetRequest()
92+
.WithCollectionName(collection_name)
93+
.WithIDs(std::move(ids))
94+
.AddOutputField(field_vector);
95+
96+
milvus::GetResponse response;
97+
status = client->Get(request, response);
98+
util::CheckStatus("get", status);
99+
100+
auto query_results = response.Results();
78101
milvus::EntityRows output_rows;
79-
status = result.OutputRows(output_rows);
102+
status = query_results.OutputRows(output_rows);
80103
util::CheckStatus("get output rows", status);
104+
std::cout << "Get results:" << std::endl;
81105
for (const auto& row : output_rows) {
82106
std::cout << "\t" << row << std::endl;
83107
}

src/impl/MilvusClientImpl.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,7 +1215,7 @@ MilvusClientImpl::Search(const SearchArguments& arguments, SearchResults& result
12151215
CollectionDescPtr collection_desc;
12161216
getCollectionDesc(arguments.CollectionName(), false, collection_desc);
12171217
if (collection_desc != nullptr) {
1218-
pk_name = collection_desc->Schema().Name();
1218+
pk_name = collection_desc->Schema().PrimaryFieldName();
12191219
}
12201220
}
12211221
return ConvertSearchResults(response, pk_name, results);
@@ -1307,7 +1307,7 @@ MilvusClientImpl::HybridSearch(const HybridSearchArguments& arguments, SearchRes
13071307
CollectionDescPtr collection_desc;
13081308
getCollectionDesc(arguments.CollectionName(), false, collection_desc);
13091309
if (collection_desc != nullptr) {
1310-
pk_name = collection_desc->Schema().Name();
1310+
pk_name = collection_desc->Schema().PrimaryFieldName();
13111311
}
13121312
}
13131313
return ConvertSearchResults(response, pk_name, results);

src/impl/MilvusClientV2Impl.cpp

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,7 +1348,7 @@ MilvusClientV2Impl::Search(const SearchRequest& request, SearchResponse& respons
13481348
CollectionDescPtr collection_desc;
13491349
getCollectionDesc(request.DatabaseName(), request.CollectionName(), false, collection_desc);
13501350
if (collection_desc != nullptr) {
1351-
pk_name = collection_desc->Schema().Name();
1351+
pk_name = collection_desc->Schema().PrimaryFieldName();
13521352
}
13531353
}
13541354
auto status = ConvertSearchResults(rpc_response, pk_name, results);
@@ -1451,7 +1451,7 @@ MilvusClientV2Impl::HybridSearch(const HybridSearchRequest& request, HybridSearc
14511451
CollectionDescPtr collection_desc;
14521452
getCollectionDesc(request.DatabaseName(), request.CollectionName(), false, collection_desc);
14531453
if (collection_desc != nullptr) {
1454-
pk_name = collection_desc->Schema().Name();
1454+
pk_name = collection_desc->Schema().PrimaryFieldName();
14551455
}
14561456
}
14571457
auto status = ConvertSearchResults(rpc_response, pk_name, results);
@@ -1481,6 +1481,44 @@ MilvusClientV2Impl::Query(const QueryRequest& request, QueryResponse& response)
14811481
post);
14821482
}
14831483

1484+
Status
1485+
MilvusClientV2Impl::Get(const GetRequest& request, GetResponse& response) {
1486+
CollectionDescPtr collection_desc;
1487+
auto status = getCollectionDesc(request.DatabaseName(), request.CollectionName(), false, collection_desc);
1488+
if (!status.IsOk()) {
1489+
return status;
1490+
}
1491+
if (collection_desc == nullptr) {
1492+
return {StatusCode::UNKNOWN_ERROR, "Unable to get collection schema"};
1493+
}
1494+
auto pk_name = collection_desc->Schema().PrimaryFieldName();
1495+
1496+
nlohmann::json filter_template;
1497+
const auto& id_array = request.IDs();
1498+
if (id_array.IsIntegerID()) {
1499+
filter_template = id_array.IntIDArray();
1500+
} else {
1501+
filter_template = id_array.StrIDArray();
1502+
}
1503+
1504+
std::set<std::string> partition_names = request.PartitionNames(); // this is a copy
1505+
std::set<std::string> output_fields = request.OutputFields(); // this is a copy
1506+
1507+
// use filter template to pass the id array
1508+
static const std::string ids_key = "pks_to_get";
1509+
auto filter = pk_name + " in {" + ids_key + "}";
1510+
auto actual_request = QueryRequest()
1511+
.WithDatabaseName(request.DatabaseName())
1512+
.WithCollectionName(request.CollectionName())
1513+
.WithPartitionNames(std::move(partition_names))
1514+
.WithConsistencyLevel(request.GetConsistencyLevel())
1515+
.WithFilter(filter)
1516+
.AddFilterTemplate(ids_key, filter_template)
1517+
.WithOutputFields(std::move(output_fields));
1518+
1519+
return Query(actual_request, response);
1520+
}
1521+
14841522
Status
14851523
MilvusClientV2Impl::QueryIterator(QueryIteratorRequest& request, QueryIteratorPtr& iterator) {
14861524
auto status = iteratorPrepare(request);

src/impl/MilvusClientV2Impl.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,9 @@ class MilvusClientV2Impl : public MilvusClientV2 {
198198
Status
199199
Query(const QueryRequest& request, QueryResponse& response) final;
200200

201+
Status
202+
Get(const GetRequest& request, GetResponse& response) final;
203+
201204
Status
202205
QueryIterator(QueryIteratorRequest& request, QueryIteratorPtr& response) final;
203206

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Licensed to the LF AI & Data foundation under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing, software
12+
// distributed under the License is distributed on an "AS IS" BASIS,
13+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
// See the License for the specific language governing permissions and
15+
// limitations under the License.
16+
17+
#include "milvus/request/dql/GetRequest.h"
18+
19+
#include <memory>
20+
21+
namespace milvus {
22+
23+
const IDArray&
24+
GetRequest::IDs() const {
25+
return ids_;
26+
}
27+
28+
void
29+
GetRequest::SetIDs(std::vector<int64_t>&& id_array) {
30+
ids_ = IDArray(std::move(id_array));
31+
}
32+
33+
void
34+
GetRequest::SetIDs(std::vector<std::string>&& id_array) {
35+
ids_ = IDArray(std::move(id_array));
36+
}
37+
38+
GetRequest&
39+
GetRequest::WithIDs(std::vector<int64_t>&& id_array) {
40+
SetIDs(std::move(id_array));
41+
return *this;
42+
}
43+
44+
GetRequest&
45+
GetRequest::WithIDs(std::vector<std::string>&& id_array) {
46+
SetIDs(std::move(id_array));
47+
return *this;
48+
}
49+
50+
} // namespace milvus

src/impl/request/dql/SearchRequest.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,18 @@ SearchRequest::WithStrictGroupSize(bool strict_group_size) {
169169
return *this;
170170
}
171171

172+
SearchRequest&
173+
SearchRequest::WithRadius(double radius) {
174+
SetRadius(radius);
175+
return *this;
176+
}
177+
178+
SearchRequest&
179+
SearchRequest::WithRangeFilter(double filter) {
180+
SetRangeFilter(filter);
181+
return *this;
182+
}
183+
172184
const FunctionScorePtr&
173185
SearchRequest::Rerank() const {
174186
return ranker_;

src/impl/types/SearchResults.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,4 +164,15 @@ SearchResults::Results() const {
164164
return nq_results_;
165165
}
166166

167+
const std::vector<float>&
168+
SearchResults::Recalls() const {
169+
return recalls_;
170+
}
171+
172+
SearchResults&
173+
SearchResults::WithRecalls(std::vector<float>&& recalls) {
174+
recalls_ = std::move(recalls);
175+
return *this;
176+
}
177+
167178
} // namespace milvus

src/impl/utils/DqlUtils.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1355,7 +1355,12 @@ ConvertSearchResults(const proto::milvus::SearchResults& rpc_results, const std:
13551355
offset += item_topk;
13561356
}
13571357

1358-
results = SearchResults(std::move(single_results));
1358+
std::vector<float> recalls;
1359+
for (auto recall : result_data.recalls()) {
1360+
recalls.push_back(recall);
1361+
}
1362+
1363+
results = SearchResults(std::move(single_results)).WithRecalls(std::move(recalls));
13591364
return Status::OK();
13601365
}
13611366

src/include/milvus/MilvusClientV2.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
#include "request/dml/DeleteRequest.h"
4848
#include "request/dml/InsertRequest.h"
4949
#include "request/dml/UpsertRequest.h"
50+
#include "request/dql/GetRequest.h"
5051
#include "request/dql/HybridSearchRequest.h"
5152
#include "request/dql/QueryIteratorRequest.h"
5253
#include "request/dql/QueryRequest.h"
@@ -689,6 +690,16 @@ class MilvusClientV2 {
689690
virtual Status
690691
Query(const QueryRequest& request, QueryResponse& response) = 0;
691692

693+
/**
694+
* @brief Query with primary keys, and results in a list of records.
695+
*
696+
* @param [in] request input parameters
697+
* @param [out] response output results
698+
* @return Status operation successfully or not
699+
*/
700+
virtual Status
701+
Get(const GetRequest& request, GetResponse& response) = 0;
702+
692703
/**
693704
* @brief Get QueryIterator object based on scalar field(s) by filtering expression.
694705
* Don't disconnect the MilvusClientV2 when the iterator is in using.

0 commit comments

Comments
 (0)