Skip to content

Commit 7992a74

Browse files
committed
fix: coalesce locator cache in-flight lookups (Fixes #399)
1 parent 172d4b1 commit 7992a74

File tree

2 files changed

+262
-11
lines changed

2 files changed

+262
-11
lines changed

crates/pet-conda/src/lib.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -391,9 +391,9 @@ impl Locator for Conda {
391391
reporter.report_environment(&env);
392392

393393
// Also check for a mamba/micromamba manager in the same directory and report it.
394-
// Reporting inside the closure minimizes the TOCTOU window compared to a
395-
// separate contains_key check, though concurrent threads may still
396-
// briefly both invoke the closure before the write-lock double-check.
394+
// LocatorCache coalesces concurrent lookups for this conda_dir, so mamba
395+
// discovery and its reporting side effect run at most once per in-flight
396+
// key.
397397
let _ = self.mamba_managers.get_or_insert_with(conda_dir.clone(), || {
398398
let mgr = get_mamba_manager(conda_dir);
399399
if let Some(ref m) = mgr {

crates/pet-core/src/cache.rs

Lines changed: 259 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66
//! Provides a thread-safe cache wrapper that consolidates common caching patterns
77
//! used across multiple locators in the codebase.
88
9-
use std::{collections::HashMap, hash::Hash, path::PathBuf, sync::RwLock};
9+
use std::{
10+
collections::HashMap,
11+
hash::Hash,
12+
path::PathBuf,
13+
sync::{Arc, Condvar, Mutex, RwLock},
14+
};
1015

1116
use crate::{manager::EnvManager, python_environment::PythonEnvironment};
1217

@@ -17,13 +22,71 @@ use crate::{manager::EnvManager, python_environment::PythonEnvironment};
1722
/// returned from the cache.
1823
pub struct LocatorCache<K, V> {
1924
cache: RwLock<HashMap<K, V>>,
25+
in_flight: Mutex<HashMap<K, Arc<InFlightEntry<V>>>>,
26+
}
27+
28+
struct InFlightEntry<V> {
29+
result: Mutex<Option<Option<V>>>,
30+
changed: Condvar,
31+
}
32+
33+
struct InFlightOwnerGuard<'a, K: Eq + Hash, V> {
34+
key: Option<K>,
35+
entry: Arc<InFlightEntry<V>>,
36+
in_flight: &'a Mutex<HashMap<K, Arc<InFlightEntry<V>>>>,
37+
}
38+
39+
enum InFlightClaim<'a, K: Eq + Hash, V> {
40+
Owner(InFlightOwnerGuard<'a, K, V>),
41+
Waiter(Arc<InFlightEntry<V>>),
42+
}
43+
44+
impl<V> InFlightEntry<V> {
45+
fn new() -> Self {
46+
Self {
47+
result: Mutex::new(None),
48+
changed: Condvar::new(),
49+
}
50+
}
51+
}
52+
53+
impl<K: Eq + Hash, V> InFlightOwnerGuard<'_, K, V> {
54+
fn complete(mut self, result: Option<V>) {
55+
self.publish_result(result);
56+
}
57+
58+
fn publish_result(&mut self, result: Option<V>) {
59+
*self
60+
.entry
61+
.result
62+
.lock()
63+
.expect("locator cache in-flight result lock poisoned") = Some(result);
64+
65+
if let Some(key) = self.key.take() {
66+
self.in_flight
67+
.lock()
68+
.expect("locator cache in-flight lock poisoned")
69+
.remove(&key);
70+
}
71+
72+
self.entry.changed.notify_all();
73+
}
74+
}
75+
76+
impl<K: Eq + Hash, V> Drop for InFlightOwnerGuard<'_, K, V> {
77+
fn drop(&mut self) {
78+
if self.key.is_some() {
79+
self.publish_result(None);
80+
}
81+
}
2082
}
2183

2284
impl<K: Eq + Hash, V: Clone> LocatorCache<K, V> {
2385
/// Creates a new empty cache.
2486
pub fn new() -> Self {
2587
Self {
2688
cache: RwLock::new(HashMap::new()),
89+
in_flight: Mutex::new(HashMap::new()),
2790
}
2891
}
2992

@@ -68,35 +131,96 @@ impl<K: Eq + Hash, V: Clone> LocatorCache<K, V> {
68131
/// Returns a cloned value for the given key if it exists, otherwise computes
69132
/// and inserts the value using the provided closure.
70133
///
71-
/// This method first checks with a read lock, then upgrades to a write lock
72-
/// if the value needs to be computed and inserted.
134+
/// This method first checks with a read lock. If the key is missing, it
135+
/// claims a per-key in-flight slot before computing the value so concurrent
136+
/// callers for the same key wait for the first computation instead of
137+
/// running duplicate closures with duplicate side effects. `None` results
138+
/// are shared with current waiters but are not stored in the cache, so later
139+
/// calls can retry the computation.
73140
#[must_use]
74141
pub fn get_or_insert_with<F>(&self, key: K, f: F) -> Option<V>
75142
where
76143
F: FnOnce() -> Option<V>,
77144
K: Clone,
78145
{
79-
// First check with read lock
146+
// First check with read lock.
80147
{
81148
let cache = self.cache.read().expect("locator cache lock poisoned");
82149
if let Some(value) = cache.get(&key) {
83150
return Some(value.clone());
84151
}
85152
}
86153

154+
let in_flight = match self.claim_in_flight(&key) {
155+
InFlightClaim::Owner(in_flight) => in_flight,
156+
InFlightClaim::Waiter(entry) => return Self::wait_for_in_flight(entry),
157+
};
158+
159+
// Check again after claiming the in-flight slot. Another thread may have
160+
// completed the same key while this thread was waiting.
161+
{
162+
let cache = self.cache.read().expect("locator cache lock poisoned");
163+
if let Some(value) = cache.get(&key) {
164+
let result = Some(value.clone());
165+
in_flight.complete(result.clone());
166+
return result;
167+
}
168+
}
169+
87170
// Compute the value (outside of any lock)
88-
if let Some(value) = f() {
171+
let result = if let Some(value) = f() {
89172
// Acquire write lock and insert
90173
let mut cache = self.cache.write().expect("locator cache lock poisoned");
91174
// Double-check in case another thread inserted while we were computing
92175
if let Some(existing) = cache.get(&key) {
93-
return Some(existing.clone());
176+
Some(existing.clone())
177+
} else {
178+
cache.insert(key, value.clone());
179+
Some(value)
94180
}
95-
cache.insert(key, value.clone());
96-
Some(value)
97181
} else {
98182
None
183+
};
184+
185+
in_flight.complete(result.clone());
186+
result
187+
}
188+
189+
fn claim_in_flight(&self, key: &K) -> InFlightClaim<'_, K, V>
190+
where
191+
K: Clone,
192+
{
193+
let mut in_flight = self
194+
.in_flight
195+
.lock()
196+
.expect("locator cache in-flight lock poisoned");
197+
198+
if let Some(entry) = in_flight.get(key) {
199+
return InFlightClaim::Waiter(entry.clone());
200+
}
201+
202+
let entry = Arc::new(InFlightEntry::new());
203+
in_flight.insert(key.clone(), entry.clone());
204+
InFlightClaim::Owner(InFlightOwnerGuard {
205+
key: Some(key.clone()),
206+
entry,
207+
in_flight: &self.in_flight,
208+
})
209+
}
210+
211+
fn wait_for_in_flight(entry: Arc<InFlightEntry<V>>) -> Option<V> {
212+
let mut result = entry
213+
.result
214+
.lock()
215+
.expect("locator cache in-flight result lock poisoned");
216+
while result.is_none() {
217+
result = entry
218+
.changed
219+
.wait(result)
220+
.expect("locator cache in-flight condvar poisoned");
99221
}
222+
223+
result.clone().unwrap()
100224
}
101225

102226
/// Clears all entries from the cache.
@@ -160,6 +284,12 @@ pub type ManagerCache = LocatorCache<PathBuf, EnvManager>;
160284
#[cfg(test)]
161285
mod tests {
162286
use super::*;
287+
use std::sync::{
288+
atomic::{AtomicUsize, Ordering},
289+
mpsc, Arc, Barrier, Mutex,
290+
};
291+
use std::thread;
292+
use std::time::Duration;
163293

164294
#[test]
165295
fn test_cache_get_and_insert() {
@@ -192,6 +322,127 @@ mod tests {
192322
assert!(!cache.contains_key(&"key2".to_string()));
193323
}
194324

325+
#[test]
326+
fn test_cache_get_or_insert_with_runs_one_closure_per_key() {
327+
let cache: Arc<LocatorCache<String, i32>> = Arc::new(LocatorCache::new());
328+
let barrier = Arc::new(Barrier::new(3));
329+
let calls = Arc::new(AtomicUsize::new(0));
330+
let (started_tx, started_rx) = mpsc::channel();
331+
let (release_tx, release_rx) = mpsc::channel();
332+
let release_rx = Arc::new(Mutex::new(release_rx));
333+
let mut handles = vec![];
334+
335+
for _ in 0..2 {
336+
let cache = cache.clone();
337+
let barrier = barrier.clone();
338+
let calls = calls.clone();
339+
let started_tx = started_tx.clone();
340+
let release_rx = release_rx.clone();
341+
handles.push(thread::spawn(move || {
342+
barrier.wait();
343+
cache.get_or_insert_with("key".to_string(), || {
344+
calls.fetch_add(1, Ordering::SeqCst);
345+
started_tx.send(()).unwrap();
346+
release_rx
347+
.lock()
348+
.unwrap()
349+
.recv_timeout(Duration::from_secs(5))
350+
.unwrap();
351+
Some(42)
352+
})
353+
}));
354+
}
355+
356+
barrier.wait();
357+
started_rx.recv_timeout(Duration::from_secs(5)).unwrap();
358+
assert_eq!(calls.load(Ordering::SeqCst), 1);
359+
assert!(started_rx.try_recv().is_err());
360+
361+
release_tx.send(()).unwrap();
362+
release_tx.send(()).unwrap();
363+
364+
let mut results = handles
365+
.into_iter()
366+
.map(|handle| handle.join().unwrap())
367+
.collect::<Vec<_>>();
368+
results.sort();
369+
370+
assert_eq!(results, vec![Some(42), Some(42)]);
371+
assert_eq!(calls.load(Ordering::SeqCst), 1);
372+
}
373+
374+
#[test]
375+
fn test_cache_get_or_insert_with_shares_concurrent_none_result() {
376+
let cache: Arc<LocatorCache<String, i32>> = Arc::new(LocatorCache::new());
377+
let barrier = Arc::new(Barrier::new(3));
378+
let calls = Arc::new(AtomicUsize::new(0));
379+
let (started_tx, started_rx) = mpsc::channel();
380+
let (release_tx, release_rx) = mpsc::channel();
381+
let release_rx = Arc::new(Mutex::new(release_rx));
382+
let mut handles = vec![];
383+
384+
for _ in 0..2 {
385+
let cache = cache.clone();
386+
let barrier = barrier.clone();
387+
let calls = calls.clone();
388+
let started_tx = started_tx.clone();
389+
let release_rx = release_rx.clone();
390+
handles.push(thread::spawn(move || {
391+
barrier.wait();
392+
cache.get_or_insert_with("key".to_string(), || {
393+
calls.fetch_add(1, Ordering::SeqCst);
394+
started_tx.send(()).unwrap();
395+
release_rx
396+
.lock()
397+
.unwrap()
398+
.recv_timeout(Duration::from_secs(5))
399+
.unwrap();
400+
None
401+
})
402+
}));
403+
}
404+
405+
barrier.wait();
406+
started_rx.recv_timeout(Duration::from_secs(5)).unwrap();
407+
assert_eq!(calls.load(Ordering::SeqCst), 1);
408+
assert!(started_rx.try_recv().is_err());
409+
410+
release_tx.send(()).unwrap();
411+
release_tx.send(()).unwrap();
412+
413+
let results = handles
414+
.into_iter()
415+
.map(|handle| handle.join().unwrap())
416+
.collect::<Vec<_>>();
417+
418+
assert_eq!(results, vec![None, None]);
419+
assert_eq!(calls.load(Ordering::SeqCst), 1);
420+
assert!(!cache.contains_key(&"key".to_string()));
421+
422+
assert_eq!(
423+
cache.get_or_insert_with("key".to_string(), || Some(42)),
424+
Some(42)
425+
);
426+
}
427+
428+
#[test]
429+
fn test_cache_get_or_insert_with_panic_releases_in_flight_key() {
430+
let cache: LocatorCache<String, i32> = LocatorCache::new();
431+
432+
let result = std::panic::catch_unwind(|| {
433+
let _ = cache.get_or_insert_with("key".to_string(), || -> Option<i32> {
434+
panic!("boom");
435+
});
436+
});
437+
438+
assert!(result.is_err());
439+
assert!(!cache.contains_key(&"key".to_string()));
440+
assert_eq!(
441+
cache.get_or_insert_with("key".to_string(), || Some(42)),
442+
Some(42)
443+
);
444+
}
445+
195446
#[test]
196447
fn test_cache_clear() {
197448
let cache: LocatorCache<String, i32> = LocatorCache::new();

0 commit comments

Comments
 (0)