|
2 | 2 |
|
3 | 3 | #include "../models/model_factory.hpp" |
4 | 4 |
|
| 5 | +#include "infinicore/ops.hpp" |
| 6 | + |
5 | 7 | #include <iostream> |
6 | 8 | #include <spdlog/spdlog.h> |
7 | 9 | #include <stdexcept> |
@@ -95,7 +97,7 @@ std::unordered_map<std::string, infinicore::nn::Parameter> RankWorker::state_dic |
95 | 97 | //------------------------------------------------------ |
96 | 98 | // run -- asynchronous |
97 | 99 | //------------------------------------------------------ |
98 | | -void RankWorker::run(const InfinilmModel::Input &args) { |
| 100 | +void RankWorker::run(const Input &args) { |
99 | 101 | std::lock_guard<std::mutex> lock(mutex_); |
100 | 102 |
|
101 | 103 | if (should_exit_) { |
@@ -156,7 +158,7 @@ void RankWorker::close() { |
156 | 158 | //------------------------------------------------------ |
157 | 159 | // get_output (thread safe) |
158 | 160 | //------------------------------------------------------ |
159 | | -InfinilmModel::Output RankWorker::get_output() { |
| 161 | +RankWorker::Output RankWorker::get_output() { |
160 | 162 | std::lock_guard<std::mutex> lock(mutex_); |
161 | 163 | return output_; |
162 | 164 | } |
@@ -204,7 +206,7 @@ void RankWorker::thread_loop() { |
204 | 206 | local_param_name = pending_param_name_; |
205 | 207 | local_param = pending_param_; |
206 | 208 | } else if (local_cmd == Command::RUN) { |
207 | | - local_args = pending_args_; |
| 209 | + local_args = pending_args_.to_model_input(); |
208 | 210 | } else if (local_cmd == Command::RESET_CACHE) { |
209 | 211 | if (pending_cache_config_ != nullptr) { |
210 | 212 | local_cache_config = pending_cache_config_->unique_copy(); |
@@ -239,12 +241,40 @@ void RankWorker::thread_loop() { |
239 | 241 |
|
240 | 242 | } else if (local_cmd == Command::RUN) { |
241 | 243 | try { |
242 | | - auto out = model_->forward(local_args); |
243 | | - infinicore::context::syncStream(); |
244 | | - |
245 | 244 | { |
246 | 245 | std::lock_guard<std::mutex> lk(mutex_); |
247 | | - output_ = std::move(out); |
| 246 | + |
| 247 | + auto logits{model_->forward(local_args).logits}; |
| 248 | + |
| 249 | + if (rank_info_.tp_rank == 0) { |
| 250 | + // Perform random sampling. |
| 251 | + auto temperature{pending_args_.temperature}; |
| 252 | + auto top_p{pending_args_.top_p}; |
| 253 | + auto top_k{pending_args_.top_k}; |
| 254 | + auto random_val{pending_args_.random_val}; |
| 255 | + |
| 256 | + const auto &logits_shape{logits->shape()}; |
| 257 | + const auto &batch_size{logits_shape[0]}; |
| 258 | + const auto &vocab_size{logits_shape[2]}; |
| 259 | + |
| 260 | + auto output_ids{infinicore::Tensor::empty({batch_size}, infinicore::DataType::I32, rank_info_.device)}; |
| 261 | + |
| 262 | + for (auto i{decltype(batch_size)(0)}; i < batch_size; ++i) { |
| 263 | + auto score{logits->narrow({{0, i, 1}})->view({vocab_size})}; |
| 264 | + auto out{output_ids->narrow({{0, i, 1}})->view({})}; |
| 265 | + infinicore::op::random_sample_( |
| 266 | + out, score, random_val, top_p, top_k, temperature); |
| 267 | + } |
| 268 | + |
| 269 | + output_ids = output_ids->to(infinicore::Device::cpu()); |
| 270 | + |
| 271 | + infinicore::context::syncStream(); |
| 272 | + |
| 273 | + auto out{Output{output_ids}}; |
| 274 | + |
| 275 | + output_ = std::move(out); |
| 276 | + } |
| 277 | + |
248 | 278 | job_done_ = true; |
249 | 279 | } |
250 | 280 | cv_.notify_all(); |
|
0 commit comments