Skip to content

Commit ffa9599

Browse files
Istrate Andrei-Eduard
authored and committed
FIX - Snapshot Memory Pressure: both save_to_cold and start_snapshot_task clone the entire database state to serialize it.
This effectively doubles memory consumption during the operation and can lead to OOM on large datasets.
1 parent 8f94d04 commit ffa9599

3 files changed

Lines changed: 129 additions & 59 deletions

File tree

src/network/broadcaster.rs

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::collections::HashMap;
22
use std::env;
33
use std::time::Duration;
44
use reqwest;
5-
use actix_web::{web, HttpResponse, Responder};
5+
use actix_web::{HttpRequest, HttpResponse, Responder, web};
66
use serde_json::json;
77
use crate::storage::engine::{
88
current_timestamp, ClusterData, NodeInfo, NodeStatus,
@@ -181,18 +181,50 @@ pub async fn heartbeat() -> impl Responder {
181181
}))
182182
}
183183

184+
// Define a safe upper bound for your cluster size
185+
const MAX_CLUSTER_SIZE: usize = 100;
186+
184187
/// Process membership updates received from other nodes
185188
pub async fn update_membership(
189+
req: HttpRequest, // Added to inspect headers
186190
cluster_data: web::Data<ClusterData>,
187191
payload: web::Json<HashMap<String, NodeInfo>>,
188192
) -> impl Responder {
193+
// 1. Authentication Check
194+
// In production, load this once at startup rather than on every request
195+
let expected_token = env::var("CLUSTER_SECRET")
196+
.unwrap_or_else(|_| "default_insecure_secret".to_string());
197+
198+
let is_authenticated = req.headers()
199+
.get("X-Cluster-Token")
200+
.and_then(|h| h.to_str().ok())
201+
.map_or(false, |token| token == expected_token);
202+
203+
if !is_authenticated {
204+
return HttpResponse::Unauthorized().json(json!({
205+
"error": "Unauthorized: Invalid or missing cluster token"
206+
}));
207+
}
208+
189209
let mut local_guard = cluster_data.nodes.write().await;
190210
let incoming = payload.into_inner();
191211
let mut updated = false;
192212

193213
for (node, info) in incoming.into_iter() {
194-
// For nodes not in our membership, add them
214+
// 2. Input Validation (Basic check to ensure it looks like a URL)
215+
if !node.starts_with("http://") && !node.starts_with("https://") {
216+
continue; // Ignore malformed node addresses
217+
}
218+
219+
// For nodes not in our membership, add them safely
195220
if !local_guard.contains_key(&node) {
221+
// 3. Hard Limit Check
222+
if local_guard.len() >= MAX_CLUSTER_SIZE {
223+
// Silently ignore or log a warning. Avoid heavy logging to prevent log-flooding DoS.
224+
eprintln!("Warning: Cluster at max capacity ({}). Ignoring new node: {}", MAX_CLUSTER_SIZE, node);
225+
continue;
226+
}
227+
196228
local_guard.insert(node, info);
197229
updated = true;
198230
continue;
@@ -212,7 +244,6 @@ pub async fn update_membership(
212244

213245
HttpResponse::Ok().finish()
214246
}
215-
216247
/// Get the current cluster membership state
217248
pub async fn get_membership(
218249
cluster_data: web::Data<ClusterData>,

src/storage/engine.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ pub async fn get_value(
208208
// 2️⃣ Authenticate external user via JWT
209209
let user = match extract_user_from_token(&req) {
210210
Ok(u) => u,
211-
Err(resp) => return HttpResponse::Unauthorized().body("You are not owning this record"),
211+
Err(_) => return HttpResponse::Unauthorized().body("You are not owning this record"),
212212
};
213213

214214
// 3️⃣ Request replicas (including local node) to find the latest version

src/storage/persistance.rs

Lines changed: 94 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::collections::HashMap;
22
use std::fs::{create_dir_all, File, read_dir};
3-
use std::io::{self, BufRead, BufReader, Read, Write};
3+
use std::io::{self, BufRead, BufReader, Write};
44
use std::path::Path;
55
use std::process;
66
use std::time::Duration;
@@ -108,59 +108,85 @@ pub struct WalEntry {
108108
// }
109109

110110
/// Saves the current in‑memory state (snapshot) for each table to disk.
111-
/// Also clears the WAL after snapshot.
111+
/// Processes one table at a time, and encrypts data in 1MB chunks to prevent OOM.
112112
pub async fn save_to_cold(state: web::Data<AppState>) -> io::Result<()> {
113-
// Clone the store so that we can release the lock.
114-
let store_snapshot = {
113+
let table_names: Vec<String> = {
115114
let store = state.store.read().await;
116-
store.clone()
115+
store.keys().cloned().collect()
117116
};
118-
let state_clone = state.clone();
119-
task::spawn_blocking(move || {
120-
for (table_name, table_data) in store_snapshot.into_iter() {
121-
let folder_path = Path::new(state_clone.base_dir).join(&table_name);
117+
118+
for table_name in table_names {
119+
// Grab the read lock just long enough to clone this specific table
120+
let table_data = {
121+
let store = state.store.read().await;
122+
match store.get(&table_name) {
123+
Some(data) => data.clone(),
124+
None => continue,
125+
}
126+
};
127+
128+
let state_clone = state.clone();
129+
let t_name = table_name.clone();
130+
131+
task::spawn_blocking(move || -> io::Result<()> {
132+
let folder_path = Path::new(state_clone.base_dir).join(&t_name);
122133
create_dir_all(&folder_path)?;
123134
let file_path = folder_path.join("storage.db");
124135
let mut file = File::create(&file_path)?;
125136

126-
// build your plaintext
127-
let mut plain_data = String::new();
128-
for (key, versioned_value) in table_data.iter() {
129-
plain_data.push_str(&format!(
130-
"{} = {}\n",
131-
key,
132-
serde_json::to_string(versioned_value)?
133-
));
134-
}
137+
let mut plain_buffer = String::new();
138+
// Define a chunk size limit (e.g., 1 MB)
139+
let chunk_size_limit = 1024 * 1024;
140+
141+
// Use into_iter() to consume the cloned table and free memory as we go
142+
for (key, versioned_value) in table_data.into_iter() {
143+
let json_str = serde_json::to_string(&versioned_value)
144+
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
145+
146+
plain_buffer.push_str(&format!("{} = {}\n", key, json_str));
135147

136-
// ✂️ unwrap the Result<String,_> here and map to io::Error
137-
let encrypted_data = encrypt(&plain_data)
138-
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
148+
// If our buffer hits the limit, encrypt, write, and clear
149+
if plain_buffer.len() >= chunk_size_limit {
150+
let encrypted_chunk = encrypt(&plain_buffer)
151+
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
152+
153+
// Write the base64 string followed by a newline
154+
writeln!(file, "{}", encrypted_chunk)?;
155+
plain_buffer.clear();
156+
}
157+
}
139158

140-
// write the actual String, not a Result<_,_>
141-
file.write_all(encrypted_data.as_bytes())?;
159+
// Flush any remaining records in the buffer
160+
if !plain_buffer.is_empty() {
161+
let encrypted_chunk = encrypt(&plain_buffer)
162+
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
163+
writeln!(file, "{}", encrypted_chunk)?;
164+
}
142165

143-
// Clear the WAL
166+
// Clear the WAL for this table
144167
let wal_path = folder_path.join("wal.log");
145168
File::create(&wal_path)?;
146-
}
147-
Ok::<(), io::Error>(())
148-
})
169+
170+
Ok(())
171+
})
149172
.await??;
173+
}
150174

151175
Ok(())
152176
}
153-
154177
/// Loads all tables from disk by replaying the snapshot and WAL.
155178
pub async fn load_all_tables(state: &web::Data<AppState>) -> io::Result<()> {
156179
let state_cloned = state.clone(); // clone to ensure 'static lifetime in blocking task
180+
157181
let new_store = task::spawn_blocking(move || {
158-
let base_path = Path::new(state_cloned.base_dir);
182+
let base_path = Path::new(&state_cloned.base_dir);
159183
let mut store: HashMap<String, HashMap<String, VersionedValue>> = HashMap::new();
184+
160185
if base_path.exists() && base_path.is_dir() {
161186
for entry in read_dir(base_path)? {
162187
let entry = entry?;
163188
let table_folder = entry.path();
189+
164190
if table_folder.is_dir() {
165191
let table_name = match table_folder.file_name() {
166192
Some(name) => name.to_string_lossy().to_string(),
@@ -169,44 +195,56 @@ pub async fn load_all_tables(state: &web::Data<AppState>) -> io::Result<()> {
169195

170196
let mut table_data: HashMap<String, VersionedValue> = HashMap::new();
171197
let snapshot_path = table_folder.join("storage.db");
198+
199+
// --- 1. Load Snapshot via Chunked Decryption ---
172200
if snapshot_path.exists() {
173-
let mut file = File::open(&snapshot_path)?;
174-
let mut encrypted_content = String::new();
175-
file.read_to_string(&mut encrypted_content)?;
176-
match decrypt(&encrypted_content) {
177-
Ok(decrypted_content) => {
178-
for line in decrypted_content.lines() {
179-
if let Some((key, json_str)) = line.split_once('=') {
180-
let key = key.trim().to_string();
181-
let json_str = json_str.trim();
182-
if let Ok(value) =
183-
serde_json::from_str::<VersionedValue>(json_str)
184-
{
185-
table_data.insert(key, value);
201+
let file = File::open(&snapshot_path)?;
202+
let reader = BufReader::new(file);
203+
204+
// Read line by line (each line is a Base64 encrypted chunk)
205+
for encrypted_line in reader.lines() {
206+
let encrypted_line = encrypted_line?;
207+
let trimmed_line = encrypted_line.trim();
208+
209+
if trimmed_line.is_empty() {
210+
continue;
211+
}
212+
213+
match decrypt(trimmed_line) {
214+
Ok(decrypted_content) => {
215+
// Process the decrypted block of plaintext
216+
for line in decrypted_content.lines() {
217+
if let Some((key, json_str)) = line.split_once('=') {
218+
let key = key.trim().to_string();
219+
let json_str = json_str.trim();
220+
if let Ok(value) =
221+
serde_json::from_str::<VersionedValue>(json_str)
222+
{
223+
table_data.insert(key, value);
224+
}
186225
}
187226
}
188227
}
189-
}
190-
Err(e) => {
191-
eprintln!(
192-
"Decryption failed for table {}: {}",
193-
table_name, e
194-
);
195-
process::exit(1);
228+
Err(e) => {
229+
eprintln!(
230+
"Decryption failed for table {} chunk: {}",
231+
table_name, e
232+
);
233+
process::exit(1);
234+
}
196235
}
197236
}
198237
}
199238

200-
// Replay WAL entries.
239+
// --- 2. Replay WAL entries ---
201240
let wal_path = table_folder.join("wal.log");
202241
if wal_path.exists() {
203242
let file = File::open(&wal_path)?;
204243
let reader = BufReader::new(file);
244+
205245
for line in reader.lines() {
206246
let line = line?;
207-
if let Ok(entry) =
208-
serde_json::from_str::<WalEntry>(&line)
209-
{
247+
if let Ok(entry) = serde_json::from_str::<WalEntry>(&line) {
210248
match entry.op.as_str() {
211249
"put" => {
212250
if let Some(val) = entry.value {
@@ -237,12 +275,13 @@ pub async fn load_all_tables(state: &web::Data<AppState>) -> io::Result<()> {
237275
}
238276
Ok::<HashMap<String, HashMap<String, VersionedValue>>, io::Error>(store)
239277
})
240-
.await??;
278+
.await??;
279+
280+
// Apply the newly built store to the live application state
241281
let mut store_write = state.store.write().await;
242282
*store_write = new_store;
243283
Ok(())
244284
}
245-
246285
/// Spawns a background Tokio task that periodically saves the current state to disk.
247286
pub async fn cold_save(state: web::Data<AppState>, interval: usize) {
248287
tokio::spawn(async move {

0 commit comments

Comments
 (0)