Skip to content

Commit e2aefa2

Browse files
committed
feat: support different weight load options
1 parent 65817cf commit e2aefa2

14 files changed

Lines changed: 339 additions & 64 deletions

File tree

csrc/engine/infer_engine.cpp

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#include "infer_engine.hpp"
22
#include "../config/config_factory.hpp"
33
#include "spdlog/spdlog.h"
4-
#include <future>
4+
#include <algorithm>
5+
#include <stdexcept>
6+
#include <string>
57

68
namespace infinilm::engine {
79

@@ -15,8 +17,16 @@ InferEngine::InferEngine(
1517
const cache::CacheConfig *cache_config,
1618
bool enable_graph_compiling,
1719
backends::AttentionBackend attention_backend,
18-
std::optional<infinicore::DataType> kv_cache_dtype) // Changed parameter
19-
: communication_group_(distributed_config, device_type), attention_backend_(attention_backend) {
20+
std::optional<infinicore::DataType> kv_cache_dtype,
21+
const std::string &weight_load_mode) // Changed parameter
22+
: communication_group_(distributed_config, device_type),
23+
attention_backend_(attention_backend),
24+
weight_load_mode_(weight_load_mode),
25+
weight_load_group_size_(2),
26+
weight_load_clone_(weight_load_mode == "grouped-clone") {
27+
if (weight_load_mode_ != "sync" && weight_load_mode_ != "async" && weight_load_mode_ != "grouped" && weight_load_mode_ != "grouped-clone") {
28+
throw std::invalid_argument("weight_load_mode must be one of: sync, async, grouped, grouped-clone");
29+
}
2030
if (cache_config != nullptr) {
2131
cache_config_ = cache_config->unique_copy();
2232
}
@@ -57,15 +67,32 @@ void InferEngine::load_param(const std::string &name, const infinicore::Tensor &
5767
}
5868

5969
// Loads a batch of parameters into every worker using the strategy selected by
// weight_load_mode_:
//   - "sync", or at most one worker: workers load one after another.
//   - "async": the load is submitted to all workers, then each is waited on.
//   - otherwise ("grouped" / "grouped-clone"): workers are processed in
//     consecutive groups of up to weight_load_group_size_, and a group is
//     fully waited on before the next group is started.
// weight_load_clone_ is forwarded to the workers on every path.
// (Diff view: code under an `N-` gutter marker is the removed pre-change text.)
void InferEngine::load_params(const std::unordered_map<std::string, infinicore::Tensor> &params) {
60-
std::vector<std::future<void>> futures;
61-
futures.reserve(workers_.size());
62-
for (auto &worker : workers_) {
63-
futures.emplace_back(std::async(std::launch::async, [&worker, &params] {
64-
worker->load_params(params);
65-
}));
70+
// Sequential path: each worker's blocking load_params finishes before the
// next worker starts.
if (workers_.size() <= 1 || weight_load_mode_ == "sync") {
71+
for (auto &worker : workers_) {
72+
worker->load_params(params, weight_load_clone_);
73+
}
74+
return;
75+
}
76+
77+
// Fully overlapped path: submit to every worker first, then wait on each.
if (weight_load_mode_ == "async") {
78+
for (auto &worker : workers_) {
79+
worker->load_params_async(params, weight_load_clone_);
80+
}
81+
for (auto &worker : workers_) {
82+
worker->wait();
83+
}
84+
return;
6685
}
67-
for (auto &future : futures) {
68-
future.get();
86+
87+
// Grouped path: clamp the group size to [1, number of workers].
const size_t group_size = std::max<size_t>(1, std::min(weight_load_group_size_, workers_.size()));
88+
for (size_t group_start = 0; group_start < workers_.size(); group_start += group_size) {
89+
// The last group may be shorter than group_size.
const size_t group_end = std::min(group_start + group_size, workers_.size());
90+
for (size_t i = group_start; i < group_end; ++i) {
91+
workers_[i]->load_params_async(params, weight_load_clone_);
92+
}
93+
// Wait on the whole group before submitting the next one.
for (size_t i = group_start; i < group_end; ++i) {
94+
workers_[i]->wait();
95+
}
6996
}
7097
}
7198

csrc/engine/infer_engine.hpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "rank_worker.hpp"
1010

1111
#include <optional>
12+
#include <string>
1213
#include <unordered_map>
1314
#include <vector>
1415

@@ -28,7 +29,8 @@ class InferEngine {
2829
const cache::CacheConfig *cache_config = nullptr,
2930
bool enable_graph_compiling = false,
3031
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default,
31-
std::optional<infinicore::DataType> kv_cache_dtype = std::nullopt);
32+
std::optional<infinicore::DataType> kv_cache_dtype = std::nullopt,
33+
const std::string &weight_load_mode = "async");
3234

3335
// Load a parameter to all workers (each can extract its shard inside RankWorker)
3436
void load_param(const std::string &name, const infinicore::Tensor &param);
@@ -63,6 +65,9 @@ class InferEngine {
6365
std::unique_ptr<cache::CacheConfig> cache_config_;
6466
std::shared_ptr<infinilm::config::ModelConfig> model_config_;
6567
backends::AttentionBackend attention_backend_ = backends::AttentionBackend::Default;
68+
std::string weight_load_mode_ = "async";
69+
size_t weight_load_group_size_ = 2;
70+
bool weight_load_clone_ = false;
6671
};
6772

6873
} // namespace infinilm::engine

csrc/engine/rank_worker.cpp

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,39 @@
77
#include <iostream>
88
#include <spdlog/spdlog.h>
99
#include <stdexcept>
10+
#include <string>
1011

1112
namespace infinilm::engine {
1213

14+
namespace {
15+
16+
// Clone helper for weight loading: allocates a new tensor with the same shape,
// dtype and device as `tensor`, fills it via copy_from(tensor), and returns it.
infinicore::Tensor clone_tensor_for_weight_load(const infinicore::Tensor &tensor) {
17+
auto cloned = infinicore::Tensor::empty(
18+
tensor->shape(),
19+
tensor->dtype(),
20+
tensor->device(),
21+
// NOTE(review): the meaning of this trailing flag is not visible here
// (possibly a pinned-memory switch) -- confirm against Tensor::empty.
false);
22+
cloned->copy_from(tensor);
23+
return cloned;
24+
}
25+
26+
// Builds the name -> tensor map that is handed to the model loader.
//   params        : the batch of parameters as submitted by the caller.
//   clone_weights : when false, `params` is returned as-is (the map itself is
//                   still copied, because the return type is by value); when
//                   true, every tensor is replaced by the result of
//                   clone_tensor_for_weight_load under the same name.
std::unordered_map<std::string, infinicore::Tensor> clone_params_for_weight_load(
27+
const std::unordered_map<std::string, infinicore::Tensor> &params,
28+
bool clone_weights) {
29+
if (!clone_weights) {
30+
return params;
31+
}
32+
33+
std::unordered_map<std::string, infinicore::Tensor> cloned_params;
34+
// Size the table once up front; one entry is inserted per input parameter.
cloned_params.reserve(params.size());
35+
for (const auto &[name, tensor] : params) {
36+
cloned_params.emplace(name, clone_tensor_for_weight_load(tensor));
37+
}
38+
return cloned_params;
39+
}
40+
41+
} // namespace
42+
1343
RankWorker::RankWorker(
1444
std::shared_ptr<infinilm::global_state::InfinilmConfig> infinilm_config,
1545
const distributed::RankInfo &rank_info,
@@ -91,26 +121,37 @@ void RankWorker::load_param(const std::string &name,
91121
//------------------------------------------------------
92122
// load_params -- synchronous batch load
93123
//------------------------------------------------------
94-
void RankWorker::load_params(const std::unordered_map<std::string, infinicore::Tensor> &params) {
124+
// Blocking batch load: submits the job through load_params_async and then
// waits on cv_ until the job is reported done or the worker is exiting.
// Throws std::runtime_error if should_exit_ is set once the wait returns; any
// exception raised by load_params_async propagates unchanged.
void RankWorker::load_params(const std::unordered_map<std::string, infinicore::Tensor> &params, bool clone_weights) {
125+
load_params_async(params, clone_weights);
126+
127+
std::unique_lock<std::mutex> lk(mutex_);
128+
cv_.wait(lk, [&] { return job_done_ || should_exit_; });
129+
130+
// Checked regardless of job_done_: a shutdown flag observed here is reported
// as a failure even if the job also completed.
if (should_exit_) {
131+
throw std::runtime_error("RankWorker stopped while loading parameters");
132+
}
133+
}
134+
135+
//------------------------------------------------------
136+
// load_params_async -- submit batch load without waiting
137+
//------------------------------------------------------
138+
// Queues a LOAD_BATCH job and returns without waiting for it to finish.
// While holding mutex_ it copies `params` into pending_params_, records
// `clone_weights`, and marks the job as pending; cv_ is notified after the
// lock is released.
// Throws std::runtime_error if the worker is shutting down (should_exit_) or
// if a previously submitted job has not completed yet.
// (Diff view: code under an `N-` gutter marker is the removed pre-change text.)
void RankWorker::load_params_async(const std::unordered_map<std::string, infinicore::Tensor> &params, bool clone_weights) {
95139
{
96140
std::lock_guard<std::mutex> lock(mutex_);
97141
if (should_exit_) {
98-
throw std::runtime_error("RankWorker is closing; cannot load_params");
142+
throw std::runtime_error("RankWorker is closing; cannot load_params_async");
143+
}
144+
// Reject a new submission while an earlier job is still outstanding.
if (has_job_ && !job_done_) {
145+
throw std::runtime_error("RankWorker already has a pending job");
99146
}
100147

101148
// Stored by copy; thread_loop later moves the map out of pending_params_.
pending_params_ = params;
149+
pending_weight_load_clone_ = clone_weights;
102150
job_cmd_ = Command::LOAD_BATCH;
103151
has_job_ = true;
104152
job_done_ = false;
105153
}
106154
cv_.notify_all();
107-
108-
std::unique_lock<std::mutex> lk(mutex_);
109-
cv_.wait(lk, [&] { return job_done_ || should_exit_; });
110-
111-
if (should_exit_) {
112-
throw std::runtime_error("RankWorker stopped while loading parameters");
113-
}
114155
}
115156

116157
//------------------------------------------------------
@@ -292,6 +333,7 @@ void RankWorker::thread_loop() {
292333
std::string local_param_name;
293334
infinicore::Tensor local_param;
294335
std::unordered_map<std::string, infinicore::Tensor> local_params;
336+
bool local_weight_load_clone = false;
295337
Input local_args;
296338
std::unique_ptr<cache::CacheConfig> local_cache_config;
297339

@@ -311,6 +353,8 @@ void RankWorker::thread_loop() {
311353
local_param = pending_param_;
312354
} else if (local_cmd == Command::LOAD_BATCH) {
313355
local_params = std::move(pending_params_);
356+
local_weight_load_clone = pending_weight_load_clone_;
357+
pending_weight_load_clone_ = false;
314358
pending_params_.clear();
315359
} else if (local_cmd == Command::PREPROCESS) {
316360

@@ -350,7 +394,8 @@ void RankWorker::thread_loop() {
350394

351395
} else if (local_cmd == Command::LOAD_BATCH) {
352396
try {
353-
model_->load_parameters_no_sync(local_params);
397+
auto params_for_load = clone_params_for_weight_load(local_params, local_weight_load_clone);
398+
model_->load_parameters_no_sync(params_for_load);
354399
infinicore::context::syncStream();
355400
} catch (const std::exception &e) {
356401
{

csrc/engine/rank_worker.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,9 @@ class RankWorker {
8585
void load_param(const std::string &name,
8686
const infinicore::Tensor &param);
8787

88-
void load_params(const std::unordered_map<std::string, infinicore::Tensor> &params);
88+
void load_params(const std::unordered_map<std::string, infinicore::Tensor> &params, bool clone_weights = false);
89+
90+
void load_params_async(const std::unordered_map<std::string, infinicore::Tensor> &params, bool clone_weights = false);
8991

9092
void process_weights_after_loading();
9193

@@ -144,6 +146,7 @@ class RankWorker {
144146
std::string pending_param_name_;
145147
infinicore::Tensor pending_param_;
146148
std::unordered_map<std::string, infinicore::Tensor> pending_params_;
149+
bool pending_weight_load_clone_ = false;
147150
Input pending_args_;
148151
std::unique_ptr<cache::CacheConfig> pending_cache_config_;
149152

csrc/pybind11/engine/engine.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,26 @@ inline void bind_infer_engine(py::module &m) {
3939
std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
4040
bool enable_graph_compiling,
4141
const std::string &attention_backend,
42-
std::optional<infinicore::DataType> kv_cache_dtype) {
42+
std::optional<infinicore::DataType> kv_cache_dtype,
43+
const std::string &weight_load_mode) {
4344
return std::make_shared<InferEngine>(
4445
model_path,
4546
dist,
4647
dev,
4748
cache_cfg ? cache_cfg.get() : nullptr,
4849
enable_graph_compiling,
4950
infinilm::backends::parse_attention_backend(attention_backend),
50-
kv_cache_dtype);
51+
kv_cache_dtype,
52+
weight_load_mode);
5153
}),
5254
py::arg("model_path") = "",
5355
py::arg("distributed_config") = distributed::DistConfig(),
5456
py::arg("device_type") = infinicore::context::getDevice().getType(),
5557
py::arg("cache_config") = py::none(),
5658
py::arg("enable_graph_compiling") = false,
5759
py::arg("attention_backend") = "default",
58-
py::arg("kv_cache_dtype") = py::none())
60+
py::arg("kv_cache_dtype") = py::none(),
61+
py::arg("weight_load_mode") = "async")
5962
.def("load_param", &InferEngine::load_param,
6063
py::arg("name"), py::arg("param"),
6164
"Load a parameter tensor into all workers (each worker picks its shard)")

examples/bench.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ def __init__(
169169
cache_config=None,
170170
enable_graph=False,
171171
attn_backend="default",
172+
weight_load_mode="async",
172173
) -> None:
173174
model_path = os.path.expanduser(model_path)
174175
# ---------------------------------------------------------------------------- #
@@ -182,6 +183,7 @@ def __init__(
182183
enable_graph_compiling=enable_graph,
183184
attention_backend=attn_backend,
184185
kv_cache_dtype=cfg.kv_cache_dtype,
186+
weight_load_mode=weight_load_mode,
185187
)
186188

187189
# ---------------------------------------------------------------------------- #
@@ -322,6 +324,7 @@ def run(
322324
cache_config=cache_config,
323325
enable_graph=enable_graph,
324326
attn_backend=attn_backend,
327+
weight_load_mode=cfg.weight_load_mode,
325328
)
326329

327330
# ---------------------------------------------------------------------------- #

examples/test_infer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def test(
1818
attn_backend="default",
1919
image_path=None,
2020
skip_load=False,
21+
weight_load_mode="async",
2122
):
2223
model_path = os.path.expanduser(model_path)
2324
# ---------------------------------------------------------------------------- #
@@ -39,6 +40,7 @@ def test(
3940
enable_graph=enable_graph,
4041
attn_backend=attn_backend,
4142
skip_load=skip_load,
43+
weight_load_mode=weight_load_mode,
4244
)
4345

4446
conversations = [
@@ -103,4 +105,5 @@ def test(
103105
attn_backend=cfg.attn,
104106
image_path=cfg.image,
105107
skip_load=cfg.skip_load,
108+
weight_load_mode=cfg.weight_load_mode,
106109
)

python/infinilm/base_config.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ class BaseConfig:
4444
"""InfiniLM Unified Config - Command line argument parser"""
4545

4646
def __init__(self):
47-
4847
self.parser = argparse.ArgumentParser(description="InfiniLM Unified Config")
4948
self._add_common_args()
5049
self.args, self.extra = self.parser.parse_known_args()
@@ -67,6 +66,7 @@ def __init__(self):
6766
self.max_cache_len = self.args.max_cache_len
6867
self.kv_cache_dtype = self.args.kv_cache_dtype
6968
self.skip_load = self.args.skip_load
69+
self.weight_load_mode = self.args.weight_load_mode
7070

7171
self.batch_size = self.args.batch_size
7272
self.max_batch_size = self.args.max_batch_size
@@ -146,6 +146,13 @@ def _add_common_args(self):
146146
self.parser.add_argument(
147147
"--skip-load", action="store_true", help="skip loading model weights"
148148
)
149+
self.parser.add_argument(
150+
"--weight-load-mode",
151+
type=str,
152+
default="async",
153+
choices=["async", "sync", "grouped", "grouped-clone"],
154+
help="weight loading mode: async keeps old behavior; grouped-clone is the stable 103B option",
155+
)
149156

150157
# --- Length and infer parameters ---
151158
self.parser.add_argument("--batch-size", type=int, default=1)

python/infinilm/infer_engine.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def read_hf_config(model_path):
2424
)
2525
return config_dict
2626

27+
2728
# config.json (required) defines model architecture, while generation_config.json
2829
# (optional) defines generation behavior. They are kept as separate readers
2930
# because: 1) config.json must exist and requires model_type validation,
@@ -37,6 +38,7 @@ def read_hf_generation_config(model_path):
3738
return json.load(f)
3839
return {}
3940

41+
4042
@dataclass
4143
class GenerationConfig:
4244
max_new_tokens: int | None = None
@@ -59,6 +61,7 @@ def __init__(
5961
enable_graph_compiling=False,
6062
attention_backend="default",
6163
kv_cache_dtype=None,
64+
weight_load_mode="async",
6265
):
6366
self.hf_config = read_hf_config(model_path)
6467
self.hf_generation_config = read_hf_generation_config(model_path)
@@ -79,6 +82,7 @@ def __init__(
7982
if kv_cache_dtype is not None
8083
else None
8184
),
85+
weight_load_mode,
8286
)
8387
self.use_cache = False
8488

@@ -369,10 +373,14 @@ def reset_cache(self, cache_config):
369373
super().reset_cache(cache_config)
370374

371375
def state_dict_keyname(self):
372-
return sorted({name for state_dict in super().state_dict() for name in state_dict.keys()})
376+
return sorted(
377+
{name for state_dict in super().state_dict() for name in state_dict.keys()}
378+
)
373379

374380
def load_state_dict(self, state_dict, strict=None):
375-
super().load_params({name: param._underlying for name, param in state_dict.items()})
381+
super().load_params(
382+
{name: param._underlying for name, param in state_dict.items()}
383+
)
376384

377385
def process_weights_after_loading(self):
378386
super().process_weights_after_loading()

python/infinilm/llm/cache_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ def allocate_blocks(
140140
"""
141141
if block_table is None:
142142
block_table = []
143+
if mm_token_index_mappings is None:
144+
mm_token_index_mappings = []
143145

144146
# Static args
145147
num_tokens = len(token_ids)

0 commit comments

Comments
 (0)