Skip to content

Commit 2d4b618

Browse files
Marlon Costayhmo
authored andcommitted
feat(query): make hybrid_search ranker optional with RRF default
1 parent 5a4a3aa commit 2d4b618

1 file changed

Lines changed: 41 additions & 39 deletions

File tree

src/query.rs

Lines changed: 41 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
//! let req2 = AnnSearchRequest::new(vec![vector2], "field2".to_string(), params2, 10);
3131
//! let ranker = WeightedRanker::new(vec![0.7, 0.3]);
3232
//!
33-
//! let results = client.hybrid_search("my_collection", vec![req1, req2], Box::new(ranker), None).await?;
33+
//! let results = client.hybrid_search("my_collection", vec![req1, req2], Some(Box::new(ranker)), None).await?;
3434
//! ```
3535
3636
use std::collections::HashMap;
@@ -387,7 +387,7 @@ pub type HybridSearchOptions = SearchOptions;
387387
/// .limit(100)
388388
/// .offset(0);
389389
/// ```
390-
#[derive(Debug, Clone)]
390+
#[derive(Debug, Clone, Default)]
391391
pub struct QueryOptions {
392392
output_fields: Vec<String>,
393393
partition_names: Vec<String>,
@@ -414,20 +414,6 @@ pub enum IdType {
414414
VarChar(Vec<String>),
415415
}
416416

417-
impl Default for QueryOptions {
418-
fn default() -> Self {
419-
Self {
420-
output_fields: Vec::new(),
421-
partition_names: Vec::new(),
422-
guarantee_timestamp: 0,
423-
query_params: vec![],
424-
consistency_level: 0,
425-
use_default_consistency: false,
426-
expr_template_values: HashMap::new(),
427-
}
428-
}
429-
}
430-
431417
impl QueryOptions {
432418
/// Creates a new QueryOptions instance with default values
433419
///
@@ -832,6 +818,7 @@ pub struct SearchOptions {
832818
pub(crate) search_params: Vec<KeyValuePair>,
833819
pub(crate) partition_names: Vec<String>,
834820
pub(crate) anns_field: Vec<String>,
821+
/// Ranker for hybrid search; can be set here or passed to `hybrid_search()`.
835822
pub(crate) ranker: Option<Box<dyn BaseRanker>>,
836823
pub(crate) expr_template_values: HashMap<String, proto::schema::TemplateValue>,
837824
pub(crate) other_params: Option<Vec<KeyValuePair>>,
@@ -902,6 +889,20 @@ impl SearchOptions {
902889
Self::default().partitions(partitions)
903890
}
904891

892+
/// Sets the ranker for hybrid search (alternative to passing ranker to `hybrid_search()`).
893+
///
894+
/// # Arguments
895+
///
896+
/// * `ranker` - Ranker implementation (e.g. `WeightedRanker`, `RrfRanker`)
897+
///
898+
/// # Returns
899+
///
900+
/// Self for method chaining
901+
pub fn ranker(mut self, ranker: Box<dyn BaseRanker>) -> Self {
902+
self.ranker = Some(ranker);
903+
self
904+
}
905+
905906
/// Adds radius parameter for range search
906907
///
907908
/// # Arguments
@@ -1297,12 +1298,7 @@ impl Client {
12971298
nq: data.len() as _,
12981299
placeholder_group: get_place_holder_group(&data)?,
12991300
dsl_type: DslType::BoolExprV1 as _,
1300-
output_fields: options
1301-
.output_fields
1302-
.clone()
1303-
.into_iter()
1304-
.map(|f| f.into())
1305-
.collect(),
1301+
output_fields: options.output_fields.clone(),
13061302
search_params,
13071303
travel_timestamp: 0,
13081304
guarantee_timestamp: self
@@ -1386,7 +1382,7 @@ impl Client {
13861382
///
13871383
/// * `collection_name` - Name of the collection to search
13881384
/// * `reqs` - Vector of ANN search requests
1389-
/// * `ranker` - Ranking algorithm to combine results
1385+
/// * `ranker` - Optional ranking algorithm; if `None`, uses options.ranker or default RRF (k=60)
13901386
/// * `options` - Optional search configuration
13911387
///
13921388
/// # Returns
@@ -1418,19 +1414,25 @@ impl Client {
14181414
/// );
14191415
///
14201416
/// let ranker = WeightedRanker::new(vec![0.7, 0.3]);
1421-
/// let results = client.hybrid_search("my_collection", vec![req1, req2], Box::new(ranker), None).await?;
1417+
/// let results = client.hybrid_search("my_collection", vec![req1, req2], Some(Box::new(ranker)), None).await?;
1418+
/// // Or pass ranker via options; if omitted, RRF with k=60 is used
1419+
/// let options = SearchOptions::new().ranker(Box::new(WeightedRanker::new(vec![0.7, 0.3])));
1420+
/// let results = client.hybrid_search("my_collection", vec![req1, req2], None, Some(options)).await?;
14221421
/// ```
14231422
pub async fn hybrid_search<S>(
14241423
&self,
14251424
collection_name: S,
14261425
reqs: Vec<AnnSearchRequest>,
1427-
ranker: Box<dyn BaseRanker>,
1426+
ranker: Option<Box<dyn BaseRanker>>,
14281427
options: Option<HybridSearchOptions>,
14291428
) -> Result<Vec<SearchResult<'_>>>
14301429
where
14311430
S: Into<String>,
14321431
{
14331432
let options = options.unwrap_or_default();
1433+
let effective_ranker = ranker
1434+
.or(options.ranker)
1435+
.unwrap_or_else(|| Box::new(RrfRanker::new(60.0)) as Box<dyn BaseRanker>);
14341436
let collection_name = collection_name.into();
14351437
let collection = self.collection_cache.get(&collection_name).await?;
14361438

@@ -1520,7 +1522,7 @@ impl Client {
15201522
}
15211523

15221524
// Prepare ranker parameters
1523-
let rank_params = prepare_rank_params(&vec![], ranker.get_params());
1525+
let rank_params = prepare_rank_params(&[], effective_ranker.get_params());
15241526

15251527
// Create HybridSearchRequest
15261528
let request = proto::milvus::HybridSearchRequest {
@@ -1654,7 +1656,7 @@ impl Client {
16541656
if data_type == DataType::VarChar {
16551657
let ids: Vec<String> = pks.iter().map(|entry| format!("'{}'", entry)).collect();
16561658
let expr = format!("{pk_field_name} in {:?}", ids);
1657-
return Ok(expr);
1659+
Ok(expr)
16581660
} else {
16591661
let mut ids: Vec<i64> = Vec::new();
16601662
for (i, entry) in pks.iter().enumerate() {
@@ -1669,7 +1671,7 @@ impl Client {
16691671
}
16701672
}
16711673
let expr = format!("{pk_field_name} in {:?}", ids);
1672-
return Ok(expr);
1674+
Ok(expr)
16731675
}
16741676
}
16751677

@@ -1721,7 +1723,7 @@ impl Client {
17211723
IdType::VarChar(ids_string) => ids_string,
17221724
};
17231725

1724-
let ids: Vec<String> = ids.into_iter().map(|x| x.into()).collect();
1726+
let ids: Vec<String> = ids.into_iter().collect();
17251727

17261728
//If ids is empty,return an empty vec
17271729
if ids.is_empty() {
@@ -1731,7 +1733,7 @@ impl Client {
17311733
let collection = self.collection_cache.get(&collection_name).await?;
17321734
let expr = self.pack_pks_expr(&collection, ids)?;
17331735
let option = options.unwrap_or_default();
1734-
Ok(self.query(collection_name, expr.as_str(), &option).await?)
1736+
self.query(collection_name, expr.as_str(), &option).await
17351737
}
17361738
}
17371739

@@ -1759,7 +1761,7 @@ pub fn get_place_holder_group(vectors: &Vec<Value>) -> Result<Vec<u8>> {
17591761
group.encode(&mut buf).map_err(|e| {
17601762
SuperError::Unexpected(format!("Failed to encode placeholder group: {}", e))
17611763
})?;
1762-
return Ok(buf.to_vec());
1764+
Ok(buf.to_vec())
17631765
}
17641766

17651767
/// Converts vector data to placeholder value format
@@ -1785,9 +1787,9 @@ fn get_place_holder_value(vectors: &Vec<Value>) -> Result<PlaceholderValue> {
17851787
values: Vec::new(),
17861788
};
17871789
// if no vectors, return an empty one
1788-
if vectors.len() == 0 {
1790+
if vectors.is_empty() {
17891791
return Ok(place_holder);
1790-
};
1792+
}
17911793

17921794
match vectors[0] {
17931795
Value::FloatArray(_) => place_holder.r#type = PlaceholderType::FloatVector as _,
@@ -1818,7 +1820,7 @@ fn get_place_holder_value(vectors: &Vec<Value>) -> Result<PlaceholderValue> {
18181820
}
18191821
};
18201822
}
1821-
return Ok(place_holder);
1823+
Ok(place_holder)
18221824
}
18231825

18241826
/// Extracts a parameter value from search parameters with a default fallback
@@ -1835,7 +1837,7 @@ fn get_place_holder_value(vectors: &Vec<Value>) -> Result<PlaceholderValue> {
18351837
/// # Returns
18361838
///
18371839
/// Parameter value as string, or default value if not found
1838-
fn extract_param(search_params: &Vec<KeyValuePair>, key: &str, default: &str) -> String {
1840+
fn extract_param(search_params: &[KeyValuePair], key: &str, default: &str) -> String {
18391841
search_params
18401842
.iter()
18411843
.find(|param| param.key == key)
@@ -1904,15 +1906,15 @@ fn get_params(search_params: &Vec<KeyValuePair>) -> String {
19041906
///
19051907
/// Combined rank parameters with defaults and optional parameters
19061908
fn prepare_rank_params(
1907-
search_params: &Vec<KeyValuePair>,
1909+
search_params: &[KeyValuePair],
19081910
rank_params: Vec<KeyValuePair>,
19091911
) -> Vec<KeyValuePair> {
19101912
let mut final_rank_params = rank_params;
19111913

19121914
// Parameters with default values
1913-
let limit = extract_param(&search_params, "limit", "10");
1914-
let round_decimal = extract_param(&search_params, "round_decimal", "-1");
1915-
let offset = extract_param(&search_params, "offset", "0");
1915+
let limit = extract_param(search_params, "limit", "10");
1916+
let round_decimal = extract_param(search_params, "round_decimal", "-1");
1917+
let offset = extract_param(search_params, "offset", "0");
19161918

19171919
final_rank_params.push(KeyValuePair {
19181920
key: "limit".to_string(),

0 commit comments

Comments
 (0)