Skip to content

Commit 5aad5d3

Browse files
committed
fix cache + enhance RespBuf struct with Eq and Hash traits
1 parent 18ccdf8 commit 5aad5d3

2 files changed

Lines changed: 81 additions & 67 deletions

File tree

src/cache.rs

Lines changed: 80 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,20 @@ use crate::{
88
StringCommands, ZRangeOptions,
99
},
1010
resp::{
11-
BulkString, Command, CommandArgsMut, RespBuf, RespDeserializer, RespSerializer, Response,
12-
Value, cmd,
11+
BulkString, Command, CommandArgsMut, FastPathCommandBuilder, RespDeserializer,
12+
RespResponse, Response,
1313
},
1414
};
15-
use bytes::BytesMut;
15+
use bytes::Bytes;
1616
use dashmap::DashMap;
1717
use futures_util::StreamExt;
18-
use serde::{Deserialize, Serialize, de::DeserializeOwned};
19-
use std::{fmt::Write, sync::Arc, time::Duration};
18+
use serde::{Serialize, de::DeserializeOwned};
19+
use std::{sync::Arc, time::Duration};
2020

2121
/// Re-export the moka cache builder.
2222
pub use moka::future::CacheBuilder;
2323

24-
type SubCache = DashMap<Command, RespBuf>;
24+
type SubCache = DashMap<Bytes, RespResponse>;
2525
type MokaCache = moka::future::Cache<BulkString, Arc<SubCache>>;
2626
type MokaCacheBuilder = moka::future::CacheBuilder<BulkString, Arc<SubCache>, MokaCache>;
2727

@@ -134,65 +134,78 @@ impl Cache {
134134
/// Executes the `MGET` command with client-side caching.
135135
pub async fn mget<R: Response + DeserializeOwned>(&self, keys: impl Serialize) -> Result<R> {
136136
let prepared_command = self.client.mget::<R>(keys);
137-
let mut collection_buf = BytesMut::new();
138-
let _ =
139-
collection_buf.write_fmt(format_args!("*{}\r\n", prepared_command.command.num_args()));
140-
141-
for arg in (0..prepared_command.command.num_args())
142-
.filter_map(|i| prepared_command.command.get_arg(i))
143-
{
144-
let key = BulkString::from(arg.to_vec());
145-
146-
let Some(values) = self.cache.get(&key).await else {
147-
collection_buf.clear();
148-
break;
149-
};
150-
151-
let prepared_command = self.client.get::<R>(arg);
152-
let Some(buf) = values.get(&prepared_command.command) else {
153-
collection_buf.clear();
154-
break;
155-
};
156-
157-
collection_buf.extend(buf.iter());
137+
let mut responses = Vec::with_capacity(prepared_command.command.num_args());
138+
let mut missing_indices = Vec::new();
139+
let mut missing_keys = Vec::new();
140+
141+
// 1. check cache
142+
for (i, arg) in prepared_command.command.args().enumerate() {
143+
let key = BulkString::from(arg.clone());
144+
145+
if let Some(values) = self.cache.get(&key).await
146+
&& let Some(response) = values.get(FastPathCommandBuilder::get(key.clone()).bytes())
147+
{
148+
log::debug!(
149+
"[{}] Cache hit on key `{}`",
150+
self.client.connection_tag(),
151+
key
152+
);
153+
responses.push(response.clone());
154+
} else {
155+
log::debug!(
156+
"[{}] Cache miss on key `{}`",
157+
self.client.connection_tag(),
158+
key
159+
);
160+
responses.push(RespResponse::null());
161+
missing_indices.push(i);
162+
missing_keys.push(key);
163+
}
158164
}
159165

160-
if !collection_buf.is_empty() {
161-
log::debug!("[{}] Cache hit on mget", self.client.connection_tag(),);
166+
// 2. Fetch missing keys from Redis server if any
167+
if !missing_keys.is_empty() {
168+
let missing_prepared_command = self.client.mget::<R>(missing_keys);
169+
let response = self
170+
.client
171+
.internal_send(missing_prepared_command.command, None)
172+
.await?;
173+
let Ok(array_iter) = response.clone().into_array_iter() else {
174+
return Err(Error::Client(ClientError::ExpectedArrayForMGet));
175+
};
162176

163-
let mut deserializer = RespDeserializer::new(&collection_buf);
164-
return R::deserialize(&mut deserializer);
165-
}
177+
for (idx_in_missing, response) in array_iter.enumerate() {
178+
let original_idx = missing_indices[idx_in_missing];
179+
180+
let Some(key) = prepared_command
181+
.command
182+
.get_arg(original_idx)
183+
.map(BulkString::from)
184+
else {
185+
break;
186+
};
187+
188+
// Insert into cache
189+
self.cache
190+
.entry(key.clone())
191+
.or_insert_with(async { Arc::new(DashMap::new()) })
192+
.await
193+
.value()
194+
.insert(
195+
FastPathCommandBuilder::get(key).bytes().clone(),
196+
response.clone(),
197+
);
166198

167-
let buf = self
168-
.client
169-
.send(prepared_command.command.clone(), None)
170-
.await?;
171-
let mut deserializer = RespDeserializer::new(&buf);
172-
let Value::Array(values) = Value::deserialize(&mut deserializer)? else {
173-
return Err(Error::Client(ClientError::ExpectedArrayForMGet));
174-
};
175-
176-
for (value, key) in values.iter().zip(
177-
(0..prepared_command.command.num_args())
178-
.filter_map(|i| prepared_command.command.get_arg(i)),
179-
) {
180-
let mut serializer = RespSerializer::new();
181-
value.serialize(&mut serializer)?;
182-
183-
// Insert into cache
184-
self.cache
185-
.entry(key.to_vec().into())
186-
.or_insert_with(async { Arc::new(DashMap::new()) })
187-
.await
188-
.value()
189-
.insert(
190-
cmd("GET").arg(key).into(),
191-
RespBuf::new(serializer.get_output().into()),
192-
);
199+
responses[original_idx] = response;
200+
}
201+
} else {
202+
log::debug!("[{}] Cache hit on mget", self.client.connection_tag());
193203
}
194204

195-
R::deserialize(&Value::Array(values))
205+
// 3. deserialize
206+
let response = RespResponse::owned_array(responses);
207+
let deserializer = RespDeserializer::new(response.view());
208+
R::deserialize(deserializer)
196209
}
197210

198211
/// Executes the `GETRANGE` command with client-side caching.
@@ -458,15 +471,15 @@ impl Cache {
458471
R: Response + DeserializeOwned,
459472
{
460473
if let Some(values) = self.cache.get(&key).await
461-
&& let Some(buf) = values.get(&command)
474+
&& let Some(response) = values.get(command.bytes())
462475
{
463476
log::debug!(
464477
"[{}] Cache hit on key `{}`",
465478
self.client.connection_tag(),
466479
key
467480
);
468-
let mut deserializer = RespDeserializer::new(&buf);
469-
return R::deserialize(&mut deserializer);
481+
let deserializer = RespDeserializer::new(response.view());
482+
return R::deserialize(deserializer);
470483
}
471484

472485
// Cache miss: fetch from Redis
@@ -476,17 +489,18 @@ impl Cache {
476489
key
477490
);
478491

479-
let buf = self.client.send(command.clone(), None).await?;
480-
let mut deserializer = RespDeserializer::new(&buf);
481-
let deserialized = R::deserialize(&mut deserializer)?;
492+
let command_bytes = command.bytes().clone();
493+
let response = self.client.internal_send(command, None).await?;
494+
let deserializer = RespDeserializer::new(response.view());
495+
let deserialized = R::deserialize(deserializer)?;
482496

483497
// Insert into cache
484498
self.cache
485499
.entry(key)
486500
.or_insert_with(async { Arc::new(DashMap::new()) })
487501
.await
488502
.value()
489-
.insert(command, buf);
503+
.insert(command_bytes, response);
490504

491505
Ok(deserialized)
492506
}

src/resp/resp_buf.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use serde::de::DeserializeOwned;
1010
use std::{fmt, ops::Deref};
1111

1212
/// Represents a [RESP](https://redis.io/docs/reference/protocol-spec/) Buffer incoming from the network
13-
#[derive(Clone, Default, PartialEq)]
13+
#[derive(Clone, Default, PartialEq, Eq, Hash)]
1414
pub struct RespBuf(Bytes);
1515

1616
impl RespBuf {

0 commit comments

Comments
 (0)