Skip to content

Commit 725b71d

Browse files
WeakValueHashMap: add sharded storage to reduce mutex contention
1 parent b678e67 commit 725b71d

File tree

2 files changed

+79
-56
lines changed

2 files changed

+79
-56
lines changed

Common/interface/WeakValueHashMap.hpp

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include <unordered_map>
3333
#include <memory>
3434
#include <mutex>
35+
#include <vector>
3536

3637
#include "../../Platforms/Basic/interface/DebugUtilities.hpp"
3738

@@ -70,7 +71,14 @@ class WeakValueHashMap
7071
class Impl;
7172

7273
public:
73-
WeakValueHashMap() = default;
74+
explicit WeakValueHashMap(size_t NumShards = 1) :
75+
m_pImpl(NumShards != 0 ? NumShards : 1)
76+
{
77+
for (std::shared_ptr<Impl>& pImpl : m_pImpl)
78+
{
79+
pImpl = std::make_shared<Impl>();
80+
}
81+
}
7482

7583
// clang-format off
7684
WeakValueHashMap (const WeakValueHashMap&) = delete;
@@ -136,16 +144,23 @@ class WeakValueHashMap
136144

137145
ValueHandle Get(const KeyType& Key) const
138146
{
139-
return m_pImpl->Get(Key);
147+
const size_t ShardIdx = GetShardIndex(Key);
148+
return m_pImpl[ShardIdx]->Get(Key);
140149
}
141150

142151
template <typename... ArgsType>
143152
ValueHandle GetOrInsert(const KeyType& Key, ArgsType&&... Args) const
144153
{
145-
return m_pImpl->GetOrInsert(Key, std::forward<ArgsType>(Args)...);
154+
const size_t ShardIdx = GetShardIndex(Key);
155+
return m_pImpl[ShardIdx]->GetOrInsert(Key, std::forward<ArgsType>(Args)...);
146156
}
147157

148158
private:
159+
size_t GetShardIndex(const KeyType& Key) const
160+
{
161+
return m_pImpl.size() > 1 ? m_Hasher(Key) % m_pImpl.size() : 0;
162+
}
163+
149164
class Impl : public std::enable_shared_from_this<Impl>
150165
{
151166
public:
@@ -184,7 +199,8 @@ class WeakValueHashMap
184199
// Create shared_ptr with deleter that erases the Key from the map when
185200
// the last reference to the value is destroyed
186201
std::shared_ptr<ValueType> pNewValue{
187-
new ValueType{std::forward<ArgsType>(Args)...},
202+
// Use parentheses to avoid initializer_list surprises.
203+
new ValueType(std::forward<ArgsType>(Args)...),
188204
[WeakSelf = std::move(WeakSelf), Key](ValueType* pValue) {
189205
// Get a pointer before deleting the value as the value keeps the map alive.
190206
std::shared_ptr<Impl> Self = WeakSelf.lock();
@@ -249,7 +265,8 @@ class WeakValueHashMap
249265
std::mutex m_Mtx;
250266
std::unordered_map<KeyType, std::weak_ptr<ValueType>, Hasher, Keyeq> m_Map;
251267
};
252-
std::shared_ptr<Impl> m_pImpl = std::make_shared<Impl>();
268+
std::vector<std::shared_ptr<Impl>> m_pImpl;
269+
Hasher m_Hasher;
253270
};
254271

255272
} // namespace Diligent

Tests/DiligentCoreTest/src/Common/WeakValueHashMapTest.cpp

Lines changed: 57 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ TEST(Common_WeakValueHashMap, GetOrInsert)
5858
auto Handle1 = Map.GetOrInsert(1, "Value");
5959

6060
// Release map while the handle is still alive
61-
Map = {};
61+
Map = WeakValueHashMap<int, std::string>{};
6262

6363
EXPECT_TRUE(Handle1);
6464
EXPECT_STREQ(Handle1->c_str(), "Value");
@@ -71,7 +71,7 @@ TEST(Common_WeakValueHashMap, GetOrInsert)
7171
auto Handle1 = Map.GetOrInsert(1, "Value");
7272

7373
// Release map while the handle is still alive
74-
Map = {};
74+
Map = WeakValueHashMap<int, std::string>{};
7575

7676
WeakValueHashMap<int, std::string>::ValueHandle Handle2{std::move(Handle1)};
7777
EXPECT_FALSE(Handle1);
@@ -86,7 +86,7 @@ TEST(Common_WeakValueHashMap, GetOrInsert)
8686
auto Handle1 = Map.GetOrInsert(1, "Value");
8787

8888
// Release map while the handle is still alive
89-
Map = {};
89+
Map = WeakValueHashMap<int, std::string>{};
9090

9191
WeakValueHashMap<int, std::string>::ValueHandle Handle2;
9292
Handle2 = std::move(Handle1);
@@ -207,72 +207,78 @@ static constexpr int kNumParallelKeys = 16384;
207207
// Test that multiple threads can concurrently get or insert values into the map
208208
TEST(Common_WeakValueHashMap, ParallelGetOrInsert1)
209209
{
210-
std::vector<std::thread> Threads(kNumThreads);
211-
212-
Threading::Signal StartSignal;
213-
214-
WeakValueHashMap<int, std::string> Map;
215-
for (size_t t = 0; t < kNumThreads; ++t)
210+
for (size_t NumShards : {1, 2, 4})
216211
{
217-
Threads[t] = std::thread{
218-
[&Map, &StartSignal]() //
219-
{
220-
StartSignal.Wait(true, kNumThreads);
212+
std::vector<std::thread> Threads(kNumThreads);
221213

222-
for (int k = 0; k < kNumParallelKeys; ++k)
214+
Threading::Signal StartSignal;
215+
216+
WeakValueHashMap<int, std::string> Map{NumShards};
217+
for (size_t t = 0; t < kNumThreads; ++t)
218+
{
219+
Threads[t] = std::thread{
220+
[&Map, &StartSignal]() //
223221
{
224-
std::string Value = "Value" + std::to_string(k);
222+
StartSignal.Wait(true, kNumThreads);
225223

226-
auto Handle = Map.GetOrInsert(k, Value);
227-
EXPECT_TRUE(Handle);
228-
EXPECT_EQ(*Handle, Value);
229-
}
230-
}};
231-
}
224+
for (int k = 0; k < kNumParallelKeys; ++k)
225+
{
226+
std::string Value = "Value" + std::to_string(k);
232227

233-
StartSignal.Trigger(true);
234-
for (auto& Thread : Threads)
235-
{
236-
Thread.join();
228+
auto Handle = Map.GetOrInsert(k, Value);
229+
EXPECT_TRUE(Handle);
230+
EXPECT_EQ(*Handle, Value);
231+
}
232+
}};
233+
}
234+
235+
StartSignal.Trigger(true);
236+
for (auto& Thread : Threads)
237+
{
238+
Thread.join();
239+
}
237240
}
238241
}
239242

240243
// Similar to the previous test, but all values are kept alive
241244
TEST(Common_WeakValueHashMap, ParallelGetOrInsert2)
242245
{
243-
std::vector<std::thread> Threads(kNumThreads);
244-
245-
Threading::Signal StartSignal;
246+
for (size_t NumShards : {1, 2, 4})
247+
{
248+
std::vector<std::thread> Threads(kNumThreads);
246249

247-
std::vector<WeakValueHashMap<int, std::string>::ValueHandle> Handles(kNumThreads * kNumParallelKeys);
250+
Threading::Signal StartSignal;
248251

249-
WeakValueHashMap<int, std::string> Map;
250-
for (size_t t = 0; t < kNumThreads; ++t)
251-
{
252-
Threads[t] = std::thread{
253-
[&Map, &StartSignal, &Handles](size_t ThreadId) //
254-
{
255-
StartSignal.Wait(true, kNumThreads);
252+
std::vector<WeakValueHashMap<int, std::string>::ValueHandle> Handles(kNumThreads * kNumParallelKeys);
256253

257-
for (int k = 0; k < kNumParallelKeys; ++k)
254+
WeakValueHashMap<int, std::string> Map{NumShards};
255+
for (size_t t = 0; t < kNumThreads; ++t)
256+
{
257+
Threads[t] = std::thread{
258+
[&Map, &StartSignal, &Handles](size_t ThreadId) //
258259
{
259-
std::string Value = "Value" + std::to_string(k);
260+
StartSignal.Wait(true, kNumThreads);
260261

261-
auto Handle = Map.GetOrInsert(k, Value);
262-
EXPECT_TRUE(Handle);
263-
EXPECT_EQ(*Handle, Value);
262+
for (int k = 0; k < kNumParallelKeys; ++k)
263+
{
264+
std::string Value = "Value" + std::to_string(k);
264265

265-
Handles[ThreadId * kNumParallelKeys + k] = std::move(Handle);
266-
}
267-
},
268-
t,
269-
};
270-
}
266+
auto Handle = Map.GetOrInsert(k, Value);
267+
EXPECT_TRUE(Handle);
268+
EXPECT_EQ(*Handle, Value);
271269

272-
StartSignal.Trigger(true);
273-
for (auto& Thread : Threads)
274-
{
275-
Thread.join();
270+
Handles[ThreadId * kNumParallelKeys + k] = std::move(Handle);
271+
}
272+
},
273+
t,
274+
};
275+
}
276+
277+
StartSignal.Trigger(true);
278+
for (auto& Thread : Threads)
279+
{
280+
Thread.join();
281+
}
276282
}
277283
}
278284

0 commit comments

Comments
 (0)