-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdata_loader.cpp
More file actions
99 lines (83 loc) · 2.37 KB
/
Copy pathdata_loader.cpp
File metadata and controls
99 lines (83 loc) · 2.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include <algorithm>
#include <condition_variable>
#include <cstdint>
#include <future>
#include <mutex>
#include <optional>
#include <queue>
#include <stdexcept>
#include <thread>
#include <utility>
#include "logging.h"
#include "tguf.h"
namespace tguf {
template <typename T>
AsyncDataLoader<T>::AsyncDataLoader(std::size_t prefetch_factor)
: prefetch_factor_(prefetch_factor) {}
/// Shuts the loader down by delegating to stop(), which joins the worker
/// thread (if one is running) and discards any queued batches. Joining here
/// matters: destroying a still-joinable std::thread calls std::terminate.
template <typename T>
AsyncDataLoader<T>::~AsyncDataLoader() {
  this->stop();
}
/// Launches the background producer thread.
///
/// The worker walks [start_idx, end_idx) in steps of at most `batch_size`.
/// For each batch it waits until fewer than `prefetch_factor_` batches are
/// queued, then launches `producer(i, current_batch_size)` via std::async and
/// queues the resulting future for next(). Once every batch has been queued
/// (or a stop was requested), the worker sets `stop_` and wakes consumers.
/// next() then hands out the remaining queued batches and returns
/// std::nullopt instead of blocking forever.
///
/// Calling start() while a previous run is active first stops that run. This
/// joins its worker and discards its queued batches. Previously, a joinable
/// std::thread was overwritten, which calls std::terminate.
///
/// @param start_idx  First index (inclusive) to load.
/// @param end_idx    Last index (exclusive) to load.
/// @param batch_size Maximum number of indices per batch; must be non-zero.
/// @param producer   Callable invoked as `producer(i, n)`. Each batch calls
///                   its own copy on a separate std::async thread, so any
///                   state shared between copies must be thread-safe.
/// @throws std::invalid_argument if batch_size is zero (the loop could never
///         advance).
template <typename T>
template <typename Producer>
auto AsyncDataLoader<T>::start(std::size_t start_idx, std::size_t end_idx,
                               std::size_t batch_size, Producer&& producer)
    -> void {
  if (batch_size == 0) {
    throw std::invalid_argument(
        "AsyncDataLoader::start: batch_size must be non-zero");
  }
  // Make restarting safe: join any previous worker and drop its batches.
  stop();

  std::lock_guard<std::mutex> guard(mtx_);
  // The thread is created while mtx_ is held. Its first lock of mtx_ waits
  // until this function returns, so it always sees `stop_ == false` below.
  worker_ = std::thread([this, start_idx, end_idx, batch_size,
                         fn = std::forward<Producer>(producer)]() mutable {
    // Advance by the actual batch size so `i` never passes end_idx (and so
    // cannot wrap around when end_idx is close to SIZE_MAX).
    for (std::size_t i = start_idx; i < end_idx;) {
      const std::size_t current_batch_size = std::min(batch_size, end_idx - i);
      // Wait for space in the prefetch buffer.
      std::unique_lock<std::mutex> lock(mtx_);
      cv_full_.wait(lock,
                    [this] { return q_.size() < prefetch_factor_ || stop_; });
      if (stop_) {
        break;
      }
      // Launch the task. 'fn' is copied into the async lambda.
      // NOTE(review): std::async throws std::system_error if it cannot start
      // a thread; that would escape this thread function (std::terminate).
      auto task = std::async(std::launch::async, [fn, i, current_batch_size] {
        return fn(i, current_batch_size);
      });
      q_.push(std::move(task));
      // Signal the consumer.
      lock.unlock();
      cv_empty_.notify_one();
      i += current_batch_size;
    }
    // Production is over (all batches queued, or stop requested). Mark the
    // loader stopped so blocked and future next() calls drain the queue and
    // then return std::nullopt.
    {
      std::lock_guard<std::mutex> lock(mtx_);
      stop_ = true;
    }
    cv_empty_.notify_all();
  });
  // Reset only after the thread exists. If std::thread's constructor throws,
  // `stop_` stays true from stop(), so consumers are not left waiting.
  stop_ = false;
}

/// Stops the producer, joins the worker thread and discards queued batches.
///
/// Wakes every thread blocked in the producer loop or in next(). Consumers
/// blocked on an empty queue then return std::nullopt, because no batch is
/// queued once `stop_` is set. Safe to call repeatedly, and also after the
/// worker has finished on its own. In that case `stop_` is already true, but
/// the finished thread must still be joined, so there is no early return on
/// `stop_`.
template <typename T>
auto AsyncDataLoader<T>::stop() -> void {
  std::thread worker;
  {
    std::lock_guard<std::mutex> lock(mtx_);
    stop_ = true;
    // Claim the worker under the lock so concurrent stop() calls never join
    // the same std::thread twice. At most one caller gets a joinable thread.
    worker = std::move(worker_);
  }
  cv_full_.notify_all();
  cv_empty_.notify_all();
  // Join outside mtx_: the worker needs mtx_ to observe `stop_` and exit.
  if (worker.joinable()) {
    worker.join();
  }
  // Destroying a future obtained from std::async blocks until its task has
  // finished, so this waits for in-flight batches.
  std::lock_guard<std::mutex> lock(mtx_);
  while (!q_.empty()) {
    q_.pop();
  }
}
/// Returns the oldest queued batch, blocking until one is available.
///
/// @return The producer's result for the next batch, or std::nullopt when
///         `stop_` is set and no batch is queued.
/// @throws Whatever the producer threw for this batch, rethrown by
///         std::future::get().
template <typename T>
auto AsyncDataLoader<T>::next() -> std::optional<T> {
  std::unique_lock<std::mutex> guard(mtx_);
  cv_empty_.wait(guard, [this] { return stop_ || !q_.empty(); });
  if (!q_.empty()) {
    auto pending = std::move(q_.front());
    q_.pop();
    // Release the mutex before waiting on the future: the producer can refill
    // the freed slot while this batch is still being computed.
    guard.unlock();
    cv_full_.notify_one();
    return pending.get();
  }
  // Stopped and nothing left to hand out.
  return std::nullopt;
}
} // namespace tguf