diff --git a/src/lazy_blocks_vec.rs b/src/lazy_blocks_vec.rs index 887e2f8..743b7e4 100644 --- a/src/lazy_blocks_vec.rs +++ b/src/lazy_blocks_vec.rs @@ -14,9 +14,8 @@ impl LazyBlocksVec { } } -// todo: allow non-default construction -impl LazyBlocksVec { - pub fn get_or_create_default(&mut self, index: usize) -> &mut T { +impl LazyBlocksVec { + pub fn get_mut(&mut self, index: usize) -> &mut Option { let block_index = index / ELEMENTS_PER_BLOCK; let blocks_len = self.blocks.len(); @@ -27,7 +26,7 @@ impl LazyBlocksVec::default(); { - let elem = vec.get_or_create_default(1); - *elem = 1; + let elem = vec.get_mut(1); + *elem = Some(1); } assert_eq!(vec.blocks.len(), 32); - assert_eq!(vec.get_or_create_default(1), &1); - assert_eq!(vec.get_or_create_default(2), &0); + assert_eq!(vec.get_mut(1), &Some(1)); + assert_eq!(vec.get_mut(2), &None); } } diff --git a/src/tls.rs b/src/tls.rs index 8ecea9f..95962d7 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -16,7 +16,7 @@ pub type TlsVec = LazyBlocksVec, 32>; pub type TlsRegistry = LocalKey>>; pub struct EnumerableTls< - T: Default + 'static, + T: 'static, IdProvider: GlobalProvider>, TlsProvider: GlobalProvider>, > { @@ -26,7 +26,7 @@ pub struct EnumerableTls< } impl< - T: Default + 'static, + T: 'static, IdProvider: GlobalProvider>, TlsProvider: GlobalProvider>, > Default for EnumerableTls @@ -37,7 +37,7 @@ impl< } impl< - T: Default + 'static, + T: 'static, IdProvider: GlobalProvider>, TlsProvider: GlobalProvider>, > EnumerableTls @@ -50,19 +50,6 @@ impl< } } - pub fn get_or_create(&self) -> Arc { - let wrapper = TlsProvider::global() - .with_borrow_mut(|blocks| blocks.get_or_create_default(self.tls_id.inner()).clone()); - - { - let mut guard = self.all_tls.lock().unwrap(); - guard.push(Arc::downgrade(&wrapper)); - guard.retain(|wrapper| wrapper.upgrade().is_some()); - } - - wrapper - } - pub fn for_each(&self, mut f: impl FnMut(Arc)) { let mut wrappers_guard = self.all_tls.lock().unwrap(); wrappers_guard.retain(|wrapper| { @@ -76,6 +63,33 @@ impl< } } +impl< + T: Default + 'static, + IdProvider: GlobalProvider>, + TlsProvider: GlobalProvider>, +> EnumerableTls +{ + pub fn get_or_create(&self) -> Arc { + self.modify(|wrapper| wrapper.clone()) + } + + pub fn modify(&self, mut f: impl FnMut(&Arc) -> U) -> U { + TlsProvider::global().with_borrow_mut(|blocks| { + let wrapper = blocks.get_mut(self.tls_id.inner()); + let register = wrapper.is_none(); + let wrapper = wrapper.get_or_insert_default(); + + if register { + let mut guard = self.all_tls.lock().unwrap(); + guard.push(Arc::downgrade(wrapper)); + guard.retain(|wrapper| wrapper.upgrade().is_some()); + } + + f(wrapper) + }) + } +} + #[macro_export] macro_rules! declare_enumerable_tls { ($TlsType:ident, $Data:ty) => { @@ -111,13 +125,13 @@ mod tests { use static_assertions::assert_impl_all; - declare_enumerable_tls!(MyEnumerableTls, AtomicI32); - - assert_impl_all!(MyEnumerableTls: Send, Sync); - #[test] fn test() { - let tls = MyEnumerableTls::new(); + declare_enumerable_tls!(TestEnumerableTls, AtomicI32); + + assert_impl_all!(TestEnumerableTls: Send, Sync); + + let tls = TestEnumerableTls::new(); let sum_data = || { let mut sum = 0; @@ -151,4 +165,24 @@ mod tests { sleep(Duration::from_millis(20)); assert_eq!(sum_data(), 1); } + + #[test] + fn get_twice() { + declare_enumerable_tls!(GetTwiceEnumerableTls, AtomicI32); + + let tls = GetTwiceEnumerableTls::new(); + + let data0 = tls.get_or_create(); + data0.fetch_add(1, Ordering::SeqCst); + + let data1 = tls.get_or_create(); + data1.fetch_add(1, Ordering::SeqCst); + + let mut sum = 0; + tls.for_each(|data| { + sum += data.load(Ordering::SeqCst); + }); + + assert_eq!(sum, 2); + } }