Skip to content

Commit ed10847

Browse files
issue/143 fix bench script, worker cleanup, compiler initial input
1 parent ee262bc commit ed10847

3 files changed

Lines changed: 58 additions & 14 deletions

File tree

csrc/engine/compiler/paged_compiler.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
#include "paged_compiler.hpp"
22

3+
namespace {
4+
// Todo: replace with Tensor::zeros when it is available
5+
// Fills every byte of `tensor`'s storage with zero.
//
// A zero-filled staging buffer of `tensor->nbytes()` bytes is built on the
// host and handed to infinicore::context::memcpyH2D, which writes it to
// `tensor->data()`.
inline void set_zeros(infinicore::Tensor &tensor) {
    const auto byte_count = tensor->nbytes();
    // Value-initialisation leaves every element of the staging buffer at 0.
    std::vector<uint8_t> host_zeros(byte_count);
    // NOTE(review): the trailing `false` is presumably the "async" flag, i.e.
    // the copy is expected to finish before `host_zeros` is destroyed at the
    // end of this function -- confirm against the memcpyH2D declaration.
    infinicore::context::memcpyH2D(tensor->data(), host_zeros.data(), byte_count, false);
}
9+
10+
} // namespace
311
namespace infinilm::engine {
412
PagedCompiler::PagedCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier)
513
: GraphCompiler(model, barrier) {
@@ -27,22 +35,28 @@ void PagedCompiler::compile() {
2735
compiled_map_decode_.clear();
2836
block_tables_holder_ = infinicore::Tensor::empty(
2937
{nblocks}, infinicore::DataType::I64, infinicore::context::getDevice());
38+
set_zeros(block_tables_holder_);
3039
for (size_t b : decode_batch_sizes_) {
3140
size_t block_per_req = nblocks / b;
3241
InfinilmModel::Input input;
3342
input.input_ids = infinicore::Tensor::empty({1, b}, infinicore::DataType::I64, infinicore::context::getDevice());
3443
input.position_ids = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
3544
input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
45+
set_zeros(input.input_ids.value());
46+
set_zeros(input.position_ids.value());
47+
set_zeros(input.total_sequence_lengths.value());
3648
std::vector<int64_t> total_sequence_lengths_vec(b, 1);
3749
infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int64_t), false);
3850
input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I64, infinicore::context::getDevice());
51+
set_zeros(input.input_offsets.value());
3952
std::vector<int64_t> input_offsets_vec(b + 1, 0);
4053
for (size_t i = 0; i <= b; i++) {
4154
input_offsets_vec[i] = i;
4255
}
4356
infinicore::context::memcpyH2D(input.input_offsets.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int64_t), false);
4457
input.block_tables = block_tables_holder_->as_strided({b, block_per_req}, {(ptrdiff_t)block_per_req, 1});
4558
input.slot_mapping = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
59+
set_zeros(input.slot_mapping.value());
4660

4761
barrier_->wait();
4862
infinicore::context::startGraphRecording();

csrc/engine/rank_worker.cpp

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -246,12 +246,12 @@ void RankWorker::thread_loop() {
246246
try {
247247
model_->load_parameter(local_param_name, local_param);
248248
} catch (const std::exception &e) {
249-
// convert exceptions to a safe behavior: set should_exit_ and notify caller
250-
std::lock_guard<std::mutex> lk(mutex_);
251-
should_exit_ = true;
252-
job_done_ = true;
249+
{
250+
std::lock_guard<std::mutex> lk(mutex_);
251+
should_exit_ = true;
252+
job_done_ = true;
253+
}
253254
cv_.notify_all();
254-
// rethrow so the thread can be joined and caller sees an error if desired (optional)
255255
spdlog::error("[{}] exception during load_parameter_: {}\n", info(), e.what());
256256
break;
257257
}
@@ -321,9 +321,11 @@ void RankWorker::thread_loop() {
321321
cv_.notify_all();
322322

323323
} catch (const std::exception &e) {
324-
std::lock_guard<std::mutex> lk(mutex_);
325-
should_exit_ = true;
326-
job_done_ = true;
324+
{
325+
std::lock_guard<std::mutex> lk(mutex_);
326+
should_exit_ = true;
327+
job_done_ = true;
328+
}
327329
cv_.notify_all();
328330
spdlog::error("[{}] exception during forward: {}\n", info(), e.what());
329331
break;
@@ -338,9 +340,11 @@ void RankWorker::thread_loop() {
338340
cv_.notify_all();
339341

340342
} catch (const std::exception &e) {
341-
std::lock_guard<std::mutex> lk(mutex_);
342-
should_exit_ = true;
343-
job_done_ = true;
343+
{
344+
std::lock_guard<std::mutex> lk(mutex_);
345+
should_exit_ = true;
346+
job_done_ = true;
347+
}
344348
cv_.notify_all();
345349
spdlog::error("[{}] exception during reset_cache: {}\n", info(), e.what());
346350
break;
@@ -357,9 +361,11 @@ void RankWorker::thread_loop() {
357361
cv_.notify_all();
358362

359363
} catch (const std::exception &e) {
360-
std::lock_guard<std::mutex> lk(mutex_);
361-
should_exit_ = true;
362-
job_done_ = true;
364+
{
365+
std::lock_guard<std::mutex> lk(mutex_);
366+
should_exit_ = true;
367+
job_done_ = true;
368+
}
363369
cv_.notify_all();
364370
spdlog::error("[{}] exception during compile: {}\n", info(), e.what());
365371
break;
@@ -369,6 +375,9 @@ void RankWorker::thread_loop() {
369375
// Shouldn't reach here (no-op)
370376
}
371377
} // while
378+
379+
// Some clean up should be done before exiting the thread
380+
compiler_.reset();
372381
} catch (const std::exception &e) {
373382
// Top-level exception: ensure any waiters are woken and the thread exits cleanly.
374383
{

examples/bench.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,21 @@ def get_args():
137137
action="store_true",
138138
help="Run nvidia test",
139139
)
140+
parser.add_argument(
141+
"--metax",
142+
action="store_true",
143+
help="Run metax test",
144+
)
145+
parser.add_argument(
146+
"--moore",
147+
action="store_true",
148+
help="Run moore test",
149+
)
150+
parser.add_argument(
151+
"--iluvatar",
152+
action="store_true",
153+
help="Run iluvatar test",
154+
)
140155
parser.add_argument(
141156
"--cambricon",
142157
action="store_true",
@@ -328,6 +343,12 @@ def run(
328343
device_str = "cpu"
329344
elif args.nvidia:
330345
device_str = "cuda"
346+
elif args.metax:
347+
device_str = "cuda"
348+
elif args.moore:
349+
device_str = "musa"
350+
elif args.iluvatar:
351+
device_str = "cuda"
331352
elif args.cambricon:
332353
device_str = "mlu"
333354
else:

0 commit comments

Comments
 (0)