Skip to content

Commit bdba122

Browse files
fix: test get_twice (#5)
1 parent cfdc991 commit bdba122

2 files changed

Lines changed: 62 additions & 29 deletions

File tree

src/lazy_blocks_vec.rs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@ impl<T, const ELEMENTS_PER_BLOCK: usize> LazyBlocksVec<T, ELEMENTS_PER_BLOCK> {
1414
}
1515
}
1616

17-
// todo: allow non-default construction
18-
impl<T: Default, const ELEMENTS_PER_BLOCK: usize> LazyBlocksVec<T, ELEMENTS_PER_BLOCK> {
19-
pub fn get_or_create_default(&mut self, index: usize) -> &mut T {
17+
impl<T, const ELEMENTS_PER_BLOCK: usize> LazyBlocksVec<T, ELEMENTS_PER_BLOCK> {
18+
pub fn get_mut(&mut self, index: usize) -> &mut Option<T> {
2019
let block_index = index / ELEMENTS_PER_BLOCK;
2120
let blocks_len = self.blocks.len();
2221

@@ -27,7 +26,7 @@ impl<T: Default, const ELEMENTS_PER_BLOCK: usize> LazyBlocksVec<T, ELEMENTS_PER_
2726
}
2827

2928
let block = self.blocks[block_index].get_or_insert_default();
30-
block.data[index - block_index * ELEMENTS_PER_BLOCK].get_or_insert_default()
29+
&mut block.data[index - block_index * ELEMENTS_PER_BLOCK]
3130
}
3231
}
3332

@@ -53,11 +52,11 @@ mod tests {
5352
let mut vec = LazyBlocksVec::<i32, 8>::default();
5453

5554
{
56-
let elem = vec.get_or_create_default(1);
57-
*elem = 1;
55+
let elem = vec.get_mut(1);
56+
*elem = Some(1);
5857
}
5958
assert_eq!(vec.blocks.len(), 32);
60-
assert_eq!(vec.get_or_create_default(1), &1);
61-
assert_eq!(vec.get_or_create_default(2), &0);
59+
assert_eq!(vec.get_mut(1), &Some(1));
60+
assert_eq!(vec.get_mut(2), &None);
6261
}
6362
}

src/tls.rs

Lines changed: 55 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ pub type TlsVec<T> = LazyBlocksVec<Arc<T>, 32>;
1616
pub type TlsRegistry<T> = LocalKey<RefCell<TlsVec<T>>>;
1717

1818
pub struct EnumerableTls<
19-
T: Default + 'static,
19+
T: 'static,
2020
IdProvider: GlobalProvider<Mutex<FreeIds>>,
2121
TlsProvider: GlobalProvider<TlsRegistry<T>>,
2222
> {
@@ -26,7 +26,7 @@ pub struct EnumerableTls<
2626
}
2727

2828
impl<
29-
T: Default + 'static,
29+
T: 'static,
3030
IdProvider: GlobalProvider<Mutex<FreeIds>>,
3131
TlsProvider: GlobalProvider<TlsRegistry<T>>,
3232
> Default for EnumerableTls<T, IdProvider, TlsProvider>
@@ -37,7 +37,7 @@ impl<
3737
}
3838

3939
impl<
40-
T: Default + 'static,
40+
T: 'static,
4141
IdProvider: GlobalProvider<Mutex<FreeIds>>,
4242
TlsProvider: GlobalProvider<TlsRegistry<T>>,
4343
> EnumerableTls<T, IdProvider, TlsProvider>
@@ -50,19 +50,6 @@ impl<
5050
}
5151
}
5252

53-
pub fn get_or_create(&self) -> Arc<T> {
54-
let wrapper = TlsProvider::global()
55-
.with_borrow_mut(|blocks| blocks.get_or_create_default(self.tls_id.inner()).clone());
56-
57-
{
58-
let mut guard = self.all_tls.lock().unwrap();
59-
guard.push(Arc::downgrade(&wrapper));
60-
guard.retain(|wrapper| wrapper.upgrade().is_some());
61-
}
62-
63-
wrapper
64-
}
65-
6653
pub fn for_each(&self, mut f: impl FnMut(Arc<T>)) {
6754
let mut wrappers_guard = self.all_tls.lock().unwrap();
6855
wrappers_guard.retain(|wrapper| {
@@ -76,6 +63,33 @@ impl<
7663
}
7764
}
7865

66+
impl<
67+
T: Default + 'static,
68+
IdProvider: GlobalProvider<Mutex<FreeIds>>,
69+
TlsProvider: GlobalProvider<TlsRegistry<T>>,
70+
> EnumerableTls<T, IdProvider, TlsProvider>
71+
{
72+
pub fn get_or_create(&self) -> Arc<T> {
73+
self.modify(|wrapper| wrapper.clone())
74+
}
75+
76+
pub fn modify<U>(&self, mut f: impl FnMut(&Arc<T>) -> U) -> U {
77+
TlsProvider::global().with_borrow_mut(|blocks| {
78+
let wrapper = blocks.get_mut(self.tls_id.inner());
79+
let register = wrapper.is_none();
80+
let wrapper = wrapper.get_or_insert_default();
81+
82+
if register {
83+
let mut guard = self.all_tls.lock().unwrap();
84+
guard.push(Arc::downgrade(wrapper));
85+
guard.retain(|wrapper| wrapper.upgrade().is_some());
86+
}
87+
88+
f(wrapper)
89+
})
90+
}
91+
}
92+
7993
#[macro_export]
8094
macro_rules! declare_enumerable_tls {
8195
($TlsType:ident, $Data:ty) => {
@@ -111,13 +125,13 @@ mod tests {
111125

112126
use static_assertions::assert_impl_all;
113127

114-
declare_enumerable_tls!(MyEnumerableTls, AtomicI32);
115-
116-
assert_impl_all!(MyEnumerableTls: Send, Sync);
117-
118128
#[test]
119129
fn test() {
120-
let tls = MyEnumerableTls::new();
130+
declare_enumerable_tls!(TestEnumerableTls, AtomicI32);
131+
132+
assert_impl_all!(TestEnumerableTls: Send, Sync);
133+
134+
let tls = TestEnumerableTls::new();
121135

122136
let sum_data = || {
123137
let mut sum = 0;
@@ -151,4 +165,24 @@ mod tests {
151165
sleep(Duration::from_millis(20));
152166
assert_eq!(sum_data(), 1);
153167
}
168+
169+
#[test]
170+
fn get_twice() {
171+
declare_enumerable_tls!(GetTwiceEnumerableTls, AtomicI32);
172+
173+
let tls = GetTwiceEnumerableTls::new();
174+
175+
let data0 = tls.get_or_create();
176+
data0.fetch_add(1, Ordering::SeqCst);
177+
178+
let data1 = tls.get_or_create();
179+
data1.fetch_add(1, Ordering::SeqCst);
180+
181+
let mut sum = 0;
182+
tls.for_each(|data| {
183+
sum += data.load(Ordering::SeqCst);
184+
});
185+
186+
assert_eq!(sum, 2);
187+
}
154188
}

0 commit comments

Comments
 (0)