77#include < iostream>
88#include < spdlog/spdlog.h>
99#include < stdexcept>
10+ #include < string>
1011
1112namespace infinilm ::engine {
1213
14+ namespace {
15+
16+ infinicore::Tensor clone_tensor_for_weight_load (const infinicore::Tensor &tensor) {
17+ auto cloned = infinicore::Tensor::empty (
18+ tensor->shape (),
19+ tensor->dtype (),
20+ tensor->device (),
21+ false );
22+ cloned->copy_from (tensor);
23+ return cloned;
24+ }
25+
26+ std::unordered_map<std::string, infinicore::Tensor> clone_params_for_weight_load (
27+ const std::unordered_map<std::string, infinicore::Tensor> ¶ms,
28+ bool clone_weights) {
29+ if (!clone_weights) {
30+ return params;
31+ }
32+
33+ std::unordered_map<std::string, infinicore::Tensor> cloned_params;
34+ cloned_params.reserve (params.size ());
35+ for (const auto &[name, tensor] : params) {
36+ cloned_params.emplace (name, clone_tensor_for_weight_load (tensor));
37+ }
38+ return cloned_params;
39+ }
40+
41+ } // namespace
42+
1343RankWorker::RankWorker (
1444 std::shared_ptr<infinilm::global_state::InfinilmConfig> infinilm_config,
1545 const distributed::RankInfo &rank_info,
@@ -91,26 +121,37 @@ void RankWorker::load_param(const std::string &name,
91121// ------------------------------------------------------
92122// load_params -- synchronous batch load
93123// ------------------------------------------------------
94- void RankWorker::load_params (const std::unordered_map<std::string, infinicore::Tensor> ¶ms) {
124+ void RankWorker::load_params (const std::unordered_map<std::string, infinicore::Tensor> ¶ms, bool clone_weights) {
125+ load_params_async (params, clone_weights);
126+
127+ std::unique_lock<std::mutex> lk (mutex_);
128+ cv_.wait (lk, [&] { return job_done_ || should_exit_; });
129+
130+ if (should_exit_) {
131+ throw std::runtime_error (" RankWorker stopped while loading parameters" );
132+ }
133+ }
134+
135+ // ------------------------------------------------------
136+ // load_params_async -- submit batch load without waiting
137+ // ------------------------------------------------------
// Submits a LOAD_BATCH job and returns without waiting for it to finish.
//
// NOTE(review): this span is diff output. Lines whose marker ends in `-` are
// the removed tail of the old synchronous load_params; lines ending in `+`
// are additions; the rest is shared context.
//
// params         copied into pending_params_ while mutex_ is held.
// clone_weights  stored in pending_weight_load_clone_ for the same job.
//
// Throws std::runtime_error if should_exit_ is set, or if has_job_ is true
// while job_done_ is still false.
138+ void RankWorker::load_params_async (const std::unordered_map<std::string, infinicore::Tensor> ¶ms, bool clone_weights) {
// Inner scope bounds the lock_guard so mutex_ is released before notify_all.
95139 {
96140 std::lock_guard<std::mutex> lock (mutex_);
97141 if (should_exit_) {
98- throw std::runtime_error (" RankWorker is closing; cannot load_params" );
142+ throw std::runtime_error (" RankWorker is closing; cannot load_params_async" );
143+ }
// Refuse to overwrite a job that has been posted but not yet marked done.
144+ if (has_job_ && !job_done_) {
145+ throw std::runtime_error (" RankWorker already has a pending job" );
99146 }
100147
// Publish the job: payload first, then the command and the status flags.
101148 pending_params_ = params;
149+ pending_weight_load_clone_ = clone_weights;
102150 job_cmd_ = Command::LOAD_BATCH ;
103151 has_job_ = true ;
104152 job_done_ = false ;
105153 }
// Wake every thread waiting on cv_ now that the job fields are set.
106154 cv_.notify_all ();
107-
// Removed lines below: the blocking wait that used to follow the submit.
108- std::unique_lock<std::mutex> lk (mutex_);
109- cv_.wait (lk, [&] { return job_done_ || should_exit_; });
110-
111- if (should_exit_) {
112- throw std::runtime_error (" RankWorker stopped while loading parameters" );
113- }
114155}
115156
116157// ------------------------------------------------------
@@ -292,6 +333,7 @@ void RankWorker::thread_loop() {
292333 std::string local_param_name;
293334 infinicore::Tensor local_param;
294335 std::unordered_map<std::string, infinicore::Tensor> local_params;
336+ bool local_weight_load_clone = false ;
295337 Input local_args;
296338 std::unique_ptr<cache::CacheConfig> local_cache_config;
297339
@@ -311,6 +353,8 @@ void RankWorker::thread_loop() {
311353 local_param = pending_param_;
312354 } else if (local_cmd == Command::LOAD_BATCH ) {
313355 local_params = std::move (pending_params_);
356+ local_weight_load_clone = pending_weight_load_clone_;
357+ pending_weight_load_clone_ = false ;
314358 pending_params_.clear ();
315359 } else if (local_cmd == Command::PREPROCESS ) {
316360
@@ -350,7 +394,8 @@ void RankWorker::thread_loop() {
350394
351395 } else if (local_cmd == Command::LOAD_BATCH ) {
352396 try {
353- model_->load_parameters_no_sync (local_params);
397+ auto params_for_load = clone_params_for_weight_load (local_params, local_weight_load_clone);
398+ model_->load_parameters_no_sync (params_for_load);
354399 infinicore::context::syncStream ();
355400 } catch (const std::exception &e) {
356401 {
0 commit comments