-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlib.rs
More file actions
119 lines (112 loc) · 4.34 KB
/
lib.rs
File metadata and controls
119 lines (112 loc) · 4.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
//! # datafusion-vector-search-ext
//!
//! A DataFusion extension that integrates USearch HNSW approximate nearest
//! neighbour (ANN) search as a first-class SQL operator.
//!
//! Queries that match the `ORDER BY distance_fn(...) LIMIT k` pattern are
//! transparently rewritten by an optimizer rule into a native USearch index
//! call — no query rewrite required from the caller. WHERE clause filters
//! are handled adaptively: high-selectivity filters use in-graph predicate
//! filtering; low-selectivity filters bypass HNSW entirely and run an exact
//! brute-force search over the valid subset.
//!
//! # Quick setup
//!
//! ```rust,ignore
//! use std::sync::Arc;
//! use datafusion::execution::context::SessionStateBuilder;
//! use datafusion::prelude::SessionContext;
//! use datafusion_vector_search_ext::{
//! USearchIndexConfig, USearchRegistry, USearchQueryPlanner, register_all,
//! };
//! use usearch::MetricKind;
//!
//! // 1. Build or load the index.
//! let cfg = USearchIndexConfig::new(768, MetricKind::L2sq);
//! let index = cfg.load_index("my_table.index")?;
//!
//! // 2. Wrap your table in a PointLookupProvider.
//! let provider = Arc::new(MyTableProvider::new(...));
//!
//! // 3. Register with USearchRegistry.
//! let mut registry = USearchRegistry::new();
//! registry.add("my_table", Arc::new(index), provider.clone(), "id", MetricKind::L2sq)?;
//! let registry = registry.into_arc();
//!
//! // 4. Build SessionContext with the custom query planner.
//! let state = SessionStateBuilder::new()
//! .with_query_planner(Arc::new(USearchQueryPlanner::new(registry.clone())))
//! .build();
//! let ctx = SessionContext::new_with_state(state);
//!
//! // 5. Register UDFs, UDTF, and optimizer rule.
//! register_all(&ctx, registry)?;
//!
//! // Also register your table so DataFusion can resolve column names.
//! ctx.register_table("my_table", provider)?;
//! ```
//!
//! Queries now use the HNSW index automatically:
//!
//! ```sql
//! SELECT id, l2_distance(vector, ARRAY[...]) AS dist
//! FROM my_table
//! WHERE category = 'nlp'
//! ORDER BY dist ASC
//! LIMIT 10
//! ```
pub mod keys;
pub mod lookup;
pub mod node;
pub mod planner;
pub mod registry;
pub mod rule;
pub mod udf;
pub mod udtf;
#[cfg(feature = "parquet-provider")]
pub mod parquet_provider;
#[cfg(feature = "sqlite-provider")]
pub mod sqlite_provider;
pub use keys::{DatasetLayout, pack_key, unpack_key};
pub use lookup::{HashKeyProvider, PointLookupProvider};
pub use node::{DistanceType, USearchNode};
pub use planner::{USearchExec, USearchExecPlanner, USearchQueryPlanner};
pub use registry::{
RegisteredTable, USearchIndexConfig, USearchRegistry, USearchTableConfig, VectorIndexMeta,
VectorIndexResolver,
};
pub use rule::USearchRule;
pub use udf::{cosine_distance_udf, l2_distance_udf, negative_dot_product_udf};
pub use udtf::VectorSearchVectorUDTF;
#[cfg(feature = "parquet-provider")]
pub use parquet_provider::ParquetLookupProvider;
#[cfg(feature = "sqlite-provider")]
pub use sqlite_provider::SqliteLookupProvider;
use std::sync::Arc;
use datafusion::common::Result;
use datafusion::logical_expr::ScalarUDF;
use datafusion::prelude::SessionContext;
/// Wire every component of this extension into a DataFusion [`SessionContext`].
///
/// The following are registered, in this order:
///
/// 1. Scalar distance UDFs:
///    - `l2_distance(col, query)` — squared Euclidean distance (L2sq)
///    - `cosine_distance(col, query)` — cosine distance
///    - `negative_dot_product(col, query)` — negated inner product
/// 2. The table function
///    `vector_search_vector('conn.schema.table', 'column', ARRAY[...], k)` —
///    explicit ANN search returning full rows + `_distance`
///    (cache-only for async-backed resolvers; does not trigger async loads).
/// 3. The [`USearchRule`] optimizer rewrite rule.
///
/// The same `registry` handle is shared by the table function and the
/// optimizer rule.
///
/// # Prerequisites
///
/// This function does not install the [`USearchQueryPlanner`]. It must
/// already have been set when the `SessionState` was built, via
/// `SessionStateBuilder::with_query_planner`, before calling this function.
///
/// # Errors
///
/// None of the registration calls made here are fallible, so this currently
/// always returns `Ok(())`.
pub fn register_all(ctx: &SessionContext, registry: Arc<dyn VectorIndexResolver>) -> Result<()> {
    // Scalar distance functions, registered in declaration order.
    let distance_udfs = [
        ScalarUDF::new_from_impl(l2_distance_udf()),
        ScalarUDF::new_from_impl(cosine_distance_udf()),
        ScalarUDF::new_from_impl(negative_dot_product_udf()),
    ];
    for udf in distance_udfs {
        ctx.register_udf(udf);
    }

    // Explicit ANN table function; it gets its own handle to the resolver.
    let search_fn = Arc::new(VectorSearchVectorUDTF::new(Arc::clone(&registry)));
    ctx.register_udtf("vector_search_vector", search_fn);

    // The optimizer rule takes ownership of the remaining resolver handle.
    let rewrite_rule = Arc::new(USearchRule::new(registry));
    ctx.add_optimizer_rule(rewrite_rule);

    Ok(())
}