Skip to content

Commit 7351cfd

Browse files
committed
add user defined interceptor
1 parent 981ae0a commit 7351cfd

11 files changed

Lines changed: 158 additions & 96 deletions

File tree

src/auth.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use tonic::service::Interceptor;
22
use tonic::{Request, Status};
33

4+
#[derive(Clone)]
45
pub struct TokenInterceptor {
56
api_key: Option<String>,
67
}

src/qdrant_client/collection.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
use std::future::Future;
22

33
use tonic::codegen::InterceptedService;
4+
use tonic::service::Interceptor;
45
use tonic::transport::Channel;
56
use tonic::Status;
67

7-
use crate::auth::TokenInterceptor;
88
use crate::qdrant::collections_client::CollectionsClient;
99
use crate::qdrant::{
1010
alias_operations, AliasOperations, ChangeAliases, CollectionClusterInfoRequest,
@@ -22,16 +22,16 @@ use crate::qdrant_client::{Qdrant, QdrantResult};
2222
/// configuration.
2323
///
2424
/// Documentation: <https://qdrant.tech/documentation/concepts/collections/>
25-
impl Qdrant {
25+
impl<I: Send + Sync + 'static + Clone + Interceptor> Qdrant<I> {
2626
pub(super) async fn with_collections_client<T, O: Future<Output = Result<T, Status>>>(
2727
&self,
28-
f: impl Fn(CollectionsClient<InterceptedService<Channel, TokenInterceptor>>) -> O,
28+
f: impl Fn(CollectionsClient<InterceptedService<Channel, I>>) -> O,
2929
) -> QdrantResult<T> {
3030
let result = self
3131
.channel
3232
.with_channel(
3333
|channel| {
34-
let service = self.with_api_key(channel);
34+
let service = self.with_interceptor(channel);
3535
let mut client =
3636
CollectionsClient::new(service).max_decoding_message_size(usize::MAX);
3737
if let Some(compression) = self.config.compression {

src/qdrant_client/config.rs

Lines changed: 85 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,30 @@
11
use std::time::Duration;
22

3-
use crate::{Qdrant, QdrantError};
3+
use tonic::service::Interceptor;
4+
5+
use crate::{auth::TokenInterceptor, Qdrant, QdrantError};
6+
7+
struct DefaultConfigValues {
8+
timeout: Duration,
9+
connect_timeout: Duration,
10+
keep_alive_while_idle: bool,
11+
compression: Option<CompressionEncoding>,
12+
check_compatibility: bool,
13+
pool_size: usize,
14+
}
15+
16+
impl Default for DefaultConfigValues {
17+
fn default() -> Self {
18+
Self {
19+
timeout: Duration::from_secs(5),
20+
connect_timeout: Duration::from_secs(5),
21+
keep_alive_while_idle: true,
22+
compression: None,
23+
check_compatibility: true,
24+
pool_size: 3,
25+
}
26+
}
27+
}
428

529
/// Qdrant client configuration
630
///
@@ -17,7 +41,7 @@ use crate::{Qdrant, QdrantError};
1741
/// .build();
1842
/// ```
1943
#[derive(Clone)]
20-
pub struct QdrantConfig {
44+
pub struct QdrantConfig<I: Send + Sync + 'static + Clone + Interceptor = TokenInterceptor> {
2145
/// Qdrant server URI to connect to
2246
pub uri: String,
2347

@@ -30,9 +54,6 @@ pub struct QdrantConfig {
3054
/// Whether to keep idle connections active
3155
pub keep_alive_while_idle: bool,
3256

33-
/// Optional API key or token to use for authorization
34-
pub api_key: Option<String>,
35-
3657
/// Optional compression schema to use for API requests
3758
pub compression: Option<CompressionEncoding>,
3859

@@ -42,51 +63,28 @@ pub struct QdrantConfig {
4263
/// Amount of concurrent connections.
4364
/// If set to 0 or 1, connection pools will be disabled.
4465
pub pool_size: usize,
66+
67+
/// The interceptor to use for modifying requests
68+
pub interceptor: I,
4569
}
4670

47-
impl QdrantConfig {
48-
/// Start configuring a Qdrant client with an URL
49-
///
50-
/// ```rust,no_run
51-
///# use qdrant_client::config::QdrantConfig;
52-
/// let client = QdrantConfig::from_url("http://localhost:6334").build();
53-
/// ```
54-
///
55-
/// This is normally done through [`Qdrant::from_url`](crate::Qdrant::from_url).
56-
pub fn from_url(url: &str) -> Self {
57-
QdrantConfig {
71+
impl<I: Send + Sync + 'static + Clone + Interceptor> QdrantConfig<I> {
72+
fn with_defaults(url: &str, interceptor: I) -> Self {
73+
let defaults = DefaultConfigValues::default();
74+
Self {
5875
uri: url.to_string(),
59-
..Self::default()
76+
timeout: defaults.timeout,
77+
connect_timeout: defaults.connect_timeout,
78+
keep_alive_while_idle: defaults.keep_alive_while_idle,
79+
compression: defaults.compression,
80+
check_compatibility: defaults.check_compatibility,
81+
pool_size: defaults.pool_size,
82+
interceptor,
6083
}
6184
}
6285

63-
/// Set an optional API key
64-
///
65-
/// # Examples
66-
///
67-
/// A typical use case might be getting the key from an environment variable:
68-
///
69-
/// ```rust,no_run
70-
/// use qdrant_client::Qdrant;
71-
///
72-
/// let client = Qdrant::from_url("http://localhost:6334")
73-
/// .api_key(std::env::var("QDRANT_API_KEY"))
74-
/// .build();
75-
/// ```
76-
///
77-
/// Or you might get it from some configuration:
78-
///
79-
/// ```rust,no_run
80-
///# use std::collections::HashMap;
81-
///# let config: HashMap<&str, String> = HashMap::new();
82-
///# use qdrant_client::Qdrant;
83-
/// let client = Qdrant::from_url("http://localhost:6334")
84-
/// .api_key(config.get("api_key"))
85-
/// .build();
86-
/// ```
87-
pub fn api_key(mut self, api_key: impl AsOptionApiKey) -> Self {
88-
self.api_key = api_key.api_key();
89-
self
86+
pub fn from_url_with_interceptor(url: &str, interceptor: I) -> Self {
87+
Self::with_defaults(url, interceptor)
9088
}
9189

9290
/// Keep the connection alive while idle
@@ -138,13 +136,6 @@ impl QdrantConfig {
138136
self
139137
}
140138

141-
/// Set an API key
142-
///
143-
/// Also see [`api_key()`](fn@Self::api_key).
144-
pub fn set_api_key(&mut self, api_key: &str) {
145-
self.api_key = Some(api_key.to_string());
146-
}
147-
148139
/// Set the timeout for this client
149140
///
150141
/// Also see [`timeout()`](fn@Self::timeout).
@@ -174,7 +165,7 @@ impl QdrantConfig {
174165
}
175166

176167
/// Build the configured [`Qdrant`] client
177-
pub fn build(self) -> Result<Qdrant, QdrantError> {
168+
pub fn build(self) -> Result<Qdrant<I>, QdrantError> {
178169
Qdrant::new(self)
179170
}
180171

@@ -190,21 +181,52 @@ impl QdrantConfig {
190181
}
191182
}
192183

184+
impl QdrantConfig<TokenInterceptor> {
185+
/// Start configuring a Qdrant client with an URL
186+
///
187+
/// ```rust,no_run
188+
///# use qdrant_client::config::QdrantConfig;
189+
/// let client = QdrantConfig::from_url("http://localhost:6334").build();
190+
/// ```
191+
///
192+
/// This is normally done through [`Qdrant::from_url`](crate::Qdrant::from_url).
193+
pub fn from_url(url: &str) -> Self {
194+
Self::with_defaults(url, TokenInterceptor::new(None))
195+
}
196+
197+
/// Set an optional API key
198+
///
199+
/// This method is only available when using the default TokenInterceptor.
200+
/// When you set an API key, it automatically configures the TokenInterceptor.
201+
///
202+
/// # Examples
203+
///
204+
/// ```rust,no_run
205+
/// use qdrant_client::Qdrant;
206+
///
207+
/// let client = Qdrant::from_url("http://localhost:6334")
208+
/// .api_key(std::env::var("QDRANT_API_KEY"))
209+
/// .build();
210+
/// ```
211+
pub fn api_key(mut self, api_key: impl AsOptionApiKey) -> Self {
212+
self.interceptor = TokenInterceptor::new(api_key.api_key());
213+
self
214+
}
215+
216+
/// Set an API key
217+
///
218+
/// Also see [`api_key()`](fn@Self::api_key).
219+
pub fn set_api_key(&mut self, api_key: &str) {
220+
self.interceptor = TokenInterceptor::new(Some(api_key.to_string()));
221+
}
222+
}
223+
193224
/// Default Qdrant client configuration.
194225
///
195226
/// Connects to `http://localhost:6334` without an API key.
196-
impl Default for QdrantConfig {
227+
impl Default for QdrantConfig<TokenInterceptor> {
197228
fn default() -> Self {
198-
Self {
199-
uri: String::from("http://localhost:6334"),
200-
timeout: Duration::from_secs(5),
201-
connect_timeout: Duration::from_secs(5),
202-
keep_alive_while_idle: true,
203-
api_key: None,
204-
compression: None,
205-
check_compatibility: true,
206-
pool_size: 3,
207-
}
229+
Self::with_defaults("http://localhost:6334", TokenInterceptor::new(None))
208230
}
209231
}
210232

src/qdrant_client/index.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use tonic::service::Interceptor;
2+
13
use crate::qdrant::{
24
CreateFieldIndexCollection, DeleteFieldIndexCollection, PointsOperationResponse,
35
};
@@ -8,7 +10,7 @@ use crate::qdrant_client::{Qdrant, QdrantResult};
810
/// Manage field and payload indices in collections.
911
///
1012
/// Documentation: <https://qdrant.tech/documentation/concepts/indexing/>
11-
impl Qdrant {
13+
impl<I: Send + Sync + 'static + Clone + Interceptor> Qdrant<I> {
1214
/// Create payload index in a collection.
1315
///
1416
/// ```no_run

src/qdrant_client/mod.rs

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use std::sync::Arc;
1717
use std::thread;
1818

1919
use tonic::codegen::InterceptedService;
20+
use tonic::service::Interceptor;
2021
use tonic::transport::{Channel, Uri};
2122
use tonic::Status;
2223

@@ -83,9 +84,9 @@ pub type QdrantBuilder = QdrantConfig;
8384
/// - [`upsert_points`](Self::upsert_points) - insert or update points
8485
/// - [`query`](Self::query) - query points with similarity search
8586
#[derive(Clone)]
86-
pub struct Qdrant {
87+
pub struct Qdrant<I: Send + Sync + 'static + Clone + Interceptor = TokenInterceptor> {
8788
/// Client configuration
88-
pub config: QdrantConfig,
89+
pub config: QdrantConfig<I>,
8990

9091
/// Internal connection pool
9192
channel: Arc<ChannelPool>,
@@ -94,11 +95,11 @@ pub struct Qdrant {
9495
/// # Construct and connect
9596
///
9697
/// Methods to construct a new Qdrant client.
97-
impl Qdrant {
98+
impl<I: Send + Sync + 'static + Clone + Interceptor> Qdrant<I> {
9899
/// Create a new Qdrant client.
99100
///
100101
/// Constructs the client and connects based on the given [`QdrantConfig`](config::QdrantConfig).
101-
pub fn new(config: QdrantConfig) -> QdrantResult<Self> {
102+
pub fn new(config: QdrantConfig<I>) -> QdrantResult<Self> {
102103
if config.check_compatibility {
103104
// create a temporary client to check compatibility
104105
let channel = ChannelPool::new(
@@ -162,38 +163,48 @@ impl Qdrant {
162163
Ok(client)
163164
}
164165

165-
/// Build a new Qdrant client with the given URL.
166+
/// Build a new Qdrant client with the given URL and custom interceptor.
166167
///
167168
/// ```no_run
168169
/// use qdrant_client::Qdrant;
170+
/// use tonic::service::Interceptor;
171+
/// use tonic::{Request, Status};
172+
///
173+
/// #[derive(Clone)]
174+
/// struct CustomInterceptor;
175+
/// impl Interceptor for CustomInterceptor {
176+
/// fn call(&mut self, req: Request<()>) -> Result<Request<()>, Status> {
177+
/// Ok(req)
178+
/// }
179+
/// }
169180
///
170181
///# async fn connect() -> Result<(), qdrant_client::QdrantError> {
171-
/// let client = Qdrant::from_url("http://localhost:6334").build()?;
182+
/// let client = Qdrant::from_url_with_interceptor(
183+
/// "http://localhost:6334",
184+
/// CustomInterceptor
185+
/// ).build()?;
172186
///# Ok(())
173187
///# }
174188
/// ```
175-
///
176-
/// See more ways to set up the client [here](Self#set-up).
177-
pub fn from_url(url: &str) -> QdrantBuilder {
178-
QdrantBuilder::from_url(url)
189+
pub fn from_url_with_interceptor(url: &str, interceptor: I) -> QdrantConfig<I> {
190+
QdrantConfig::<I>::from_url_with_interceptor(url, interceptor)
179191
}
180192

181-
/// Wraps a channel with a token interceptor
182-
fn with_api_key(&self, channel: Channel) -> InterceptedService<Channel, TokenInterceptor> {
183-
let interceptor = TokenInterceptor::new(self.config.api_key.clone());
184-
InterceptedService::new(channel, interceptor)
193+
/// Wraps a channel with the configured interceptor
194+
fn with_interceptor(&self, channel: Channel) -> InterceptedService<Channel, I> {
195+
InterceptedService::new(channel, self.config.interceptor.clone())
185196
}
186197

187198
// Access to raw root qdrant API
188199
async fn with_root_qdrant_client<T, O: Future<Output = Result<T, Status>>>(
189200
&self,
190-
f: impl Fn(qdrant_client::QdrantClient<InterceptedService<Channel, TokenInterceptor>>) -> O,
201+
f: impl Fn(qdrant_client::QdrantClient<InterceptedService<Channel, I>>) -> O,
191202
) -> QdrantResult<T> {
192203
let result = self
193204
.channel
194205
.with_channel(
195206
|channel| {
196-
let service = self.with_api_key(channel);
207+
let service = self.with_interceptor(channel);
197208
let mut client = qdrant_client::QdrantClient::new(service)
198209
.max_decoding_message_size(usize::MAX);
199210
if let Some(compression) = self.config.compression {
@@ -229,3 +240,21 @@ impl Qdrant {
229240
.await
230241
}
231242
}
243+
244+
impl Qdrant<TokenInterceptor> {
245+
/// Build a new Qdrant client with the given URL.
246+
///
247+
/// ```no_run
248+
/// use qdrant_client::Qdrant;
249+
///
250+
///# async fn connect() -> Result<(), qdrant_client::QdrantError> {
251+
/// let client = Qdrant::from_url("http://localhost:6334").build()?;
252+
///# Ok(())
253+
///# }
254+
/// ```
255+
///
256+
/// See more ways to set up the client [here](Self#set-up).
257+
pub fn from_url(url: &str) -> QdrantBuilder {
258+
QdrantBuilder::from_url(url)
259+
}
260+
}

src/qdrant_client/payload.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use tonic::service::Interceptor;
2+
13
use crate::qdrant::{
24
ClearPayloadPoints, DeletePayloadPoints, PointsOperationResponse, SetPayloadPoints,
35
};
@@ -8,7 +10,7 @@ use crate::qdrant_client::{Qdrant, QdrantResult};
810
/// Manage point payloads.
911
///
1012
/// Documentation: <https://qdrant.tech/documentation/concepts/payload/>
11-
impl Qdrant {
13+
impl<I: Send + Sync + 'static + Clone + Interceptor> Qdrant<I> {
1214
/// Set payload of points.
1315
///
1416
/// Sets only the given payload values on a point, leaving other existing payloads in place.

0 commit comments

Comments
 (0)