@@ -88,6 +88,31 @@ void RankWorker::load_param(const std::string &name,
8888 }
8989}
9090
91+ // ------------------------------------------------------
92+ // load_params -- synchronous batch load
93+ // ------------------------------------------------------
94+ void RankWorker::load_params (const std::unordered_map<std::string, infinicore::Tensor> ¶ms) {
95+ {
96+ std::lock_guard<std::mutex> lock (mutex_);
97+ if (should_exit_) {
98+ throw std::runtime_error (" RankWorker is closing; cannot load_params" );
99+ }
100+
101+ pending_params_ = params;
102+ job_cmd_ = Command::LOAD_BATCH ;
103+ has_job_ = true ;
104+ job_done_ = false ;
105+ }
106+ cv_.notify_all ();
107+
108+ std::unique_lock<std::mutex> lk (mutex_);
109+ cv_.wait (lk, [&] { return job_done_ || should_exit_; });
110+
111+ if (should_exit_) {
112+ throw std::runtime_error (" RankWorker stopped while loading parameters" );
113+ }
114+ }
115+
91116// ------------------------------------------------------
92117// process_weights_after_loading -- asynchronous
93118// ------------------------------------------------------
@@ -266,6 +291,7 @@ void RankWorker::thread_loop() {
266291 Command local_cmd = Command::INIT ;
267292 std::string local_param_name;
268293 infinicore::Tensor local_param;
294+ std::unordered_map<std::string, infinicore::Tensor> local_params;
269295 Input local_args;
270296 std::unique_ptr<cache::CacheConfig> local_cache_config;
271297
@@ -283,6 +309,9 @@ void RankWorker::thread_loop() {
283309 if (local_cmd == Command::LOAD ) {
284310 local_param_name = pending_param_name_;
285311 local_param = pending_param_;
312+ } else if (local_cmd == Command::LOAD_BATCH ) {
313+ local_params = std::move (pending_params_);
314+ pending_params_.clear ();
286315 } else if (local_cmd == Command::PREPROCESS ) {
287316
288317 } else if (local_cmd == Command::RUN ) {
@@ -319,6 +348,27 @@ void RankWorker::thread_loop() {
319348 }
320349 cv_.notify_all ();
321350
351+ } else if (local_cmd == Command::LOAD_BATCH ) {
352+ try {
353+ model_->load_parameters_no_sync (local_params);
354+ infinicore::context::syncStream ();
355+ } catch (const std::exception &e) {
356+ {
357+ std::lock_guard<std::mutex> lk (mutex_);
358+ should_exit_ = true ;
359+ job_done_ = true ;
360+ }
361+ cv_.notify_all ();
362+ spdlog::error (" [{}] exception during load_parameters_: {}\n " , info (), e.what ());
363+ break ;
364+ }
365+
366+ {
367+ std::lock_guard<std::mutex> lk (mutex_);
368+ job_done_ = true ;
369+ }
370+ cv_.notify_all ();
371+
322372 } else if (local_cmd == Command::PREPROCESS ) {
323373 // Handle preprocess command
324374 try {
0 commit comments