Skip to content

Commit e380559

Browse files
authored
Issue/414: Async model loader (#415)
Batch and parallelize model parameter loading by safetensor shard across rank workers. Add vLLM-style safetensors loading with shard prefetch/pipelining, state dict reuse, missing-key validation, progress restoration.
1 parent fb2f5a4 commit e380559

6 files changed

Lines changed: 79 additions & 3 deletions

File tree

csrc/engine/infer_engine.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "infer_engine.hpp"
22
#include "../config/config_factory.hpp"
33
#include "spdlog/spdlog.h"
4+
#include <future>
45

56
namespace infinilm::engine {
67

@@ -54,6 +55,20 @@ void InferEngine::load_param(const std::string &name, const infinicore::Tensor &
5455
worker->load_param(name, param);
5556
}
5657
}
58+
59+
// Dispatch one batch of parameters to every rank worker concurrently and block
// until each worker has finished loading it.
//
// params: parameter name -> tensor; the same map is handed to every worker.
// Throws: whatever the first failing worker's load_params() threw.
void InferEngine::load_params(const std::unordered_map<std::string, infinicore::Tensor> &params) {
    std::vector<std::future<void>> in_flight;
    in_flight.reserve(workers_.size());

    // One asynchronous task per worker. Capturing `params` and the worker by
    // reference is safe because every future is consumed below before this
    // function returns.
    for (auto &rank_worker : workers_) {
        auto load_task = [&rank_worker, &params] {
            rank_worker->load_params(params);
        };
        in_flight.emplace_back(std::async(std::launch::async, load_task));
    }

    // get() blocks until the task completes and rethrows any exception the
    // worker raised.
    for (auto &pending : in_flight) {
        pending.get();
    }
}
71+
5772
//------------------------------------------------------
5873
// load_param
5974
//------------------------------------------------------

csrc/engine/infer_engine.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "rank_worker.hpp"
1010

1111
#include <optional>
12+
#include <unordered_map>
1213
#include <vector>
1314

1415
namespace infinilm::engine {
@@ -32,6 +33,9 @@ class InferEngine {
3233
// Load a parameter to all workers (each can extract its shard inside RankWorker)
3334
void load_param(const std::string &name, const infinicore::Tensor &param);
3435

36+
// Load a batch of parameters to all workers, syncing each worker once after the batch.
37+
void load_params(const std::unordered_map<std::string, infinicore::Tensor> &params);
38+
3539
// process the weights after loading on all workers (e.g., for quantization)
3640
void process_weights_after_loading();
3741

csrc/engine/rank_worker.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,31 @@ void RankWorker::load_param(const std::string &name,
8888
}
8989
}
9090

91+
//------------------------------------------------------
92+
// load_params -- synchronous batch load
93+
//------------------------------------------------------
94+
// Submit a batch of parameters to the worker thread as a LOAD_BATCH job and
// block until the worker reports the job done (or the worker is shutting down).
//
// params: parameter name -> tensor; copied so the caller's map is left intact.
// Throws std::runtime_error if the worker is already closing when called, or
// if it stops (should_exit_ becomes true) before/while handling the batch.
void RankWorker::load_params(const std::unordered_map<std::string, infinicore::Tensor> &params) {
    // Stage the copy before taking the mutex: the critical section then only
    // moves the map into place, and pending_params_ is left untouched if the
    // copy itself throws.
    auto staged_params = params;

    {
        std::lock_guard<std::mutex> lock(mutex_);
        if (should_exit_) {
            throw std::runtime_error("RankWorker is closing; cannot load_params");
        }

        pending_params_ = std::move(staged_params);
        job_cmd_ = Command::LOAD_BATCH;
        has_job_ = true;
        job_done_ = false;
    }
    // Notify after releasing the lock so the woken thread can acquire it
    // immediately.
    cv_.notify_all();

    // Wait for the job to complete, or for the worker to begin shutting down.
    std::unique_lock<std::mutex> lk(mutex_);
    cv_.wait(lk, [&] { return job_done_ || should_exit_; });

    if (should_exit_) {
        throw std::runtime_error("RankWorker stopped while loading parameters");
    }
}
115+
91116
//------------------------------------------------------
92117
// process_weights_after_loading -- asynchronous
93118
//------------------------------------------------------
@@ -266,6 +291,7 @@ void RankWorker::thread_loop() {
266291
Command local_cmd = Command::INIT;
267292
std::string local_param_name;
268293
infinicore::Tensor local_param;
294+
std::unordered_map<std::string, infinicore::Tensor> local_params;
269295
Input local_args;
270296
std::unique_ptr<cache::CacheConfig> local_cache_config;
271297

@@ -283,6 +309,9 @@ void RankWorker::thread_loop() {
283309
if (local_cmd == Command::LOAD) {
284310
local_param_name = pending_param_name_;
285311
local_param = pending_param_;
312+
} else if (local_cmd == Command::LOAD_BATCH) {
313+
local_params = std::move(pending_params_);
314+
pending_params_.clear();
286315
} else if (local_cmd == Command::PREPROCESS) {
287316

288317
} else if (local_cmd == Command::RUN) {
@@ -319,6 +348,27 @@ void RankWorker::thread_loop() {
319348
}
320349
cv_.notify_all();
321350

351+
} else if (local_cmd == Command::LOAD_BATCH) {
352+
try {
353+
model_->load_parameters_no_sync(local_params);
354+
infinicore::context::syncStream();
355+
} catch (const std::exception &e) {
356+
{
357+
std::lock_guard<std::mutex> lk(mutex_);
358+
should_exit_ = true;
359+
job_done_ = true;
360+
}
361+
cv_.notify_all();
362+
spdlog::error("[{}] exception during load_parameters_: {}\n", info(), e.what());
363+
break;
364+
}
365+
366+
{
367+
std::lock_guard<std::mutex> lk(mutex_);
368+
job_done_ = true;
369+
}
370+
cv_.notify_all();
371+
322372
} else if (local_cmd == Command::PREPROCESS) {
323373
// Handle preprocess command
324374
try {

csrc/engine/rank_worker.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <random>
1616
#include <string>
1717
#include <thread>
18+
#include <unordered_map>
1819
#include <vector>
1920

2021
namespace infinilm::engine {
@@ -25,6 +26,7 @@ class RankWorker {
2526
enum class Command {
2627
INIT,
2728
LOAD,
29+
LOAD_BATCH,
2830
PREPROCESS,
2931
RUN,
3032
RESET_CACHE,
@@ -83,6 +85,8 @@ class RankWorker {
8385
void load_param(const std::string &name,
8486
const infinicore::Tensor &param);
8587

88+
void load_params(const std::unordered_map<std::string, infinicore::Tensor> &params);
89+
8690
void process_weights_after_loading();
8791

8892
// return the parameters (i.e. weights and biases).
@@ -139,6 +143,7 @@ class RankWorker {
139143
// Task payloads (protected by mutex)
140144
std::string pending_param_name_;
141145
infinicore::Tensor pending_param_;
146+
std::unordered_map<std::string, infinicore::Tensor> pending_params_;
142147
Input pending_args_;
143148
std::unique_ptr<cache::CacheConfig> pending_cache_config_;
144149

csrc/pybind11/engine/engine.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ inline void bind_infer_engine(py::module &m) {
5959
.def("load_param", &InferEngine::load_param,
6060
py::arg("name"), py::arg("param"),
6161
"Load a parameter tensor into all workers (each worker picks its shard)")
62+
.def("load_params", &InferEngine::load_params,
63+
py::arg("params"),
64+
"Load a batch of parameter tensors into all workers, syncing once per worker")
6265
.def("state_dict", [](InferEngine &self) {
6366
py::list state_dict_tp_all;
6467
for (const auto &state_dict_tp : self.state_dict()) {

python/infinilm/infer_engine.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -369,11 +369,10 @@ def reset_cache(self, cache_config):
369369
super().reset_cache(cache_config)
370370

371371
def state_dict_keyname(self):
    """Return the sorted, de-duplicated parameter names of this engine.

    Collects the keys of every state dict returned by the parent class's
    ``state_dict()`` (presumably one per rank worker -- confirm against the
    binding) and returns their union as a sorted list.
    """
    unique_names = set()
    for rank_state_dict in super().state_dict():
        unique_names.update(rank_state_dict.keys())
    return sorted(unique_names)
373373

374374
def load_state_dict(self, state_dict, strict=None):
    """Load every entry of ``state_dict`` through one batched ``load_params`` call.

    Args:
        state_dict: mapping of parameter name to a tensor wrapper exposing
            ``_underlying``; the underlying object is what gets passed on.
        strict: accepted for signature compatibility; not used here.
    """
    underlying_by_name = {}
    for name, param in state_dict.items():
        underlying_by_name[name] = param._underlying
    super().load_params(underlying_by_name)
377376

378377
def process_weights_after_loading(self):
379378
super().process_weights_after_loading()

0 commit comments

Comments
 (0)