Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions src/lazy_blocks_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@ impl<T, const ELEMENTS_PER_BLOCK: usize> LazyBlocksVec<T, ELEMENTS_PER_BLOCK> {
}
}

// todo: allow non-default construction
impl<T: Default, const ELEMENTS_PER_BLOCK: usize> LazyBlocksVec<T, ELEMENTS_PER_BLOCK> {
pub fn get_or_create_default(&mut self, index: usize) -> &mut T {
impl<T, const ELEMENTS_PER_BLOCK: usize> LazyBlocksVec<T, ELEMENTS_PER_BLOCK> {
pub fn get_mut(&mut self, index: usize) -> &mut Option<T> {
let block_index = index / ELEMENTS_PER_BLOCK;
let blocks_len = self.blocks.len();

Expand All @@ -27,7 +26,7 @@ impl<T: Default, const ELEMENTS_PER_BLOCK: usize> LazyBlocksVec<T, ELEMENTS_PER_
}

let block = self.blocks[block_index].get_or_insert_default();
block.data[index - block_index * ELEMENTS_PER_BLOCK].get_or_insert_default()
&mut block.data[index - block_index * ELEMENTS_PER_BLOCK]
}
}

Expand All @@ -53,11 +52,11 @@ mod tests {
let mut vec = LazyBlocksVec::<i32, 8>::default();

{
let elem = vec.get_or_create_default(1);
*elem = 1;
let elem = vec.get_mut(1);
*elem = Some(1);
}
assert_eq!(vec.blocks.len(), 32);
assert_eq!(vec.get_or_create_default(1), &1);
assert_eq!(vec.get_or_create_default(2), &0);
assert_eq!(vec.get_mut(1), &Some(1));
assert_eq!(vec.get_mut(2), &None);
}
}
76 changes: 55 additions & 21 deletions src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub type TlsVec<T> = LazyBlocksVec<Arc<T>, 32>;
pub type TlsRegistry<T> = LocalKey<RefCell<TlsVec<T>>>;

pub struct EnumerableTls<
T: Default + 'static,
T: 'static,
IdProvider: GlobalProvider<Mutex<FreeIds>>,
TlsProvider: GlobalProvider<TlsRegistry<T>>,
> {
Expand All @@ -26,7 +26,7 @@ pub struct EnumerableTls<
}

impl<
T: Default + 'static,
T: 'static,
IdProvider: GlobalProvider<Mutex<FreeIds>>,
TlsProvider: GlobalProvider<TlsRegistry<T>>,
> Default for EnumerableTls<T, IdProvider, TlsProvider>
Expand All @@ -37,7 +37,7 @@ impl<
}

impl<
T: Default + 'static,
T: 'static,
IdProvider: GlobalProvider<Mutex<FreeIds>>,
TlsProvider: GlobalProvider<TlsRegistry<T>>,
> EnumerableTls<T, IdProvider, TlsProvider>
Expand All @@ -50,19 +50,6 @@ impl<
}
}

pub fn get_or_create(&self) -> Arc<T> {
let wrapper = TlsProvider::global()
.with_borrow_mut(|blocks| blocks.get_or_create_default(self.tls_id.inner()).clone());

{
let mut guard = self.all_tls.lock().unwrap();
guard.push(Arc::downgrade(&wrapper));
guard.retain(|wrapper| wrapper.upgrade().is_some());
}

wrapper
}

pub fn for_each(&self, mut f: impl FnMut(Arc<T>)) {
let mut wrappers_guard = self.all_tls.lock().unwrap();
wrappers_guard.retain(|wrapper| {
Expand All @@ -76,6 +63,33 @@ impl<
}
}

impl<
T: Default + 'static,
IdProvider: GlobalProvider<Mutex<FreeIds>>,
TlsProvider: GlobalProvider<TlsRegistry<T>>,
> EnumerableTls<T, IdProvider, TlsProvider>
{
pub fn get_or_create(&self) -> Arc<T> {
self.modify(|wrapper| wrapper.clone())
}

pub fn modify<U>(&self, mut f: impl FnMut(&Arc<T>) -> U) -> U {
TlsProvider::global().with_borrow_mut(|blocks| {
let wrapper = blocks.get_mut(self.tls_id.inner());
let register = wrapper.is_none();
let wrapper = wrapper.get_or_insert_default();

if register {
let mut guard = self.all_tls.lock().unwrap();
guard.push(Arc::downgrade(wrapper));
guard.retain(|wrapper| wrapper.upgrade().is_some());
}

f(wrapper)
})
}
}

#[macro_export]
macro_rules! declare_enumerable_tls {
($TlsType:ident, $Data:ty) => {
Expand Down Expand Up @@ -111,13 +125,13 @@ mod tests {

use static_assertions::assert_impl_all;

declare_enumerable_tls!(MyEnumerableTls, AtomicI32);

assert_impl_all!(MyEnumerableTls: Send, Sync);

#[test]
fn test() {
let tls = MyEnumerableTls::new();
declare_enumerable_tls!(TestEnumerableTls, AtomicI32);

assert_impl_all!(TestEnumerableTls: Send, Sync);

let tls = TestEnumerableTls::new();

let sum_data = || {
let mut sum = 0;
Expand Down Expand Up @@ -151,4 +165,24 @@ mod tests {
sleep(Duration::from_millis(20));
assert_eq!(sum_data(), 1);
}

#[test]
fn get_twice() {
declare_enumerable_tls!(GetTwiceEnumerableTls, AtomicI32);

let tls = GetTwiceEnumerableTls::new();

let data0 = tls.get_or_create();
data0.fetch_add(1, Ordering::SeqCst);

let data1 = tls.get_or_create();
data1.fetch_add(1, Ordering::SeqCst);

let mut sum = 0;
tls.for_each(|data| {
sum += data.load(Ordering::SeqCst);
});

assert_eq!(sum, 2);
}
}