Skip to content

Commit 96e53db

Browse files
authored
issue/160: 梳理 InferEngine 相关接口
* 将 `cpp.LlamaForCausalLM` 提出,变为 `infinilm.infer_engine.InferEngine` * 将 `Config` 构造逻辑拆分至 `AutoConfig` 中 * 在 `examples` 脚本中直接构造 `InferEngine` * 将 `random_sample` 计算放入模型中 * 为 `InferEngine` 单独实现 `generate` * 允许通过 `GenerationConfig` 传递 `temperature`、`top_k`、`top_p` * 将 `random_sample` 处理从 `LlamaForCausalLM` 中转移到 `RankWorker` 里 * 在 `InferEngine.generate` 中直接 `append(output_id)` * 修复 commit `13aa90c57de369f9985593c0066b6b06a7508b24` 引入的分布式卡死问题 * 将 `InferEngine.forward` 的接口与 C++ 层的 `InferEngine.Input` 对齐 * 提供了 `_measure_and_log_time` 参数来开启之前的 `generate` 内部计时功能
1 parent 23b1306 commit 96e53db

File tree

14 files changed

+383
-257
lines changed

14 files changed

+383
-257
lines changed

csrc/engine/infer_engine.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,14 @@ infinilm::InfinilmModel::Input InferEngine::Input::to_model_input() const {
6363
InferEngine::Output InferEngine::forward(const InferEngine::Input &input) {
6464
// Trigger each worker to run inference
6565
for (auto &worker : workers_) {
66-
worker->run(input.to_model_input());
66+
worker->run(input);
6767
}
6868
// Wait for all workers
6969
for (auto &worker : workers_) {
7070
worker->wait();
7171
}
7272

73-
return {workers_[0]->get_output().logits};
73+
return workers_[0]->get_output();
7474
}
7575

7676
//------------------------------------------------------

csrc/engine/infer_engine.hpp

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,28 +13,9 @@ namespace infinilm::engine {
1313

1414
class InferEngine {
1515
public:
16-
struct Input {
17-
/// Token IDs tensor of shape `[batch, seq_len]`.
18-
std::optional<infinicore::Tensor> input_ids;
19-
/// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
20-
std::optional<infinicore::Tensor> position_ids;
21-
/// Past Lengths of cached sequence for each request, of shape `[num_requests]`.
22-
std::optional<infinicore::Tensor> cache_lengths;
23-
/// Input Lengths of each request in a continous-batched sequence, of shape `[num_requests]`.
24-
std::optional<infinicore::Tensor> input_lengths;
25-
/// Offsets of each request in a continous-batched sequence, of shape `[num_requests]`.
26-
std::optional<infinicore::Tensor> input_offsets;
27-
/// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
28-
std::optional<infinicore::Tensor> block_tables;
29-
/// Slot ids for each token `[seq]`. Used for paged cache.
30-
std::optional<infinicore::Tensor> slot_mapping;
16+
using Input = RankWorker::Input;
3117

32-
infinilm::InfinilmModel::Input to_model_input() const;
33-
};
34-
35-
struct Output {
36-
infinicore::Tensor logits;
37-
};
18+
using Output = RankWorker::Output;
3819

3920
// Updated constructor: accept CacheConfig instead of CacheType
4021
InferEngine(

csrc/engine/rank_worker.cpp

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#include "../models/model_factory.hpp"
44

5+
#include "infinicore/ops.hpp"
6+
57
#include <iostream>
68
#include <spdlog/spdlog.h>
79
#include <stdexcept>
@@ -95,7 +97,7 @@ std::unordered_map<std::string, infinicore::nn::Parameter> RankWorker::state_dic
9597
//------------------------------------------------------
9698
// run -- asynchronous
9799
//------------------------------------------------------
98-
void RankWorker::run(const InfinilmModel::Input &args) {
100+
void RankWorker::run(const Input &args) {
99101
std::lock_guard<std::mutex> lock(mutex_);
100102

101103
if (should_exit_) {
@@ -156,7 +158,7 @@ void RankWorker::close() {
156158
//------------------------------------------------------
157159
// get_output (thread safe)
158160
//------------------------------------------------------
159-
InfinilmModel::Output RankWorker::get_output() {
161+
RankWorker::Output RankWorker::get_output() {
160162
std::lock_guard<std::mutex> lock(mutex_);
161163
return output_;
162164
}
@@ -204,7 +206,7 @@ void RankWorker::thread_loop() {
204206
local_param_name = pending_param_name_;
205207
local_param = pending_param_;
206208
} else if (local_cmd == Command::RUN) {
207-
local_args = pending_args_;
209+
local_args = pending_args_.to_model_input();
208210
} else if (local_cmd == Command::RESET_CACHE) {
209211
if (pending_cache_config_ != nullptr) {
210212
local_cache_config = pending_cache_config_->unique_copy();
@@ -239,12 +241,40 @@ void RankWorker::thread_loop() {
239241

240242
} else if (local_cmd == Command::RUN) {
241243
try {
242-
auto out = model_->forward(local_args);
243-
infinicore::context::syncStream();
244-
245244
{
246245
std::lock_guard<std::mutex> lk(mutex_);
247-
output_ = std::move(out);
246+
247+
auto logits{model_->forward(local_args).logits};
248+
249+
if (rank_info_.tp_rank == 0) {
250+
// Perform random sampling.
251+
auto temperature{pending_args_.temperature};
252+
auto top_p{pending_args_.top_p};
253+
auto top_k{pending_args_.top_k};
254+
auto random_val{pending_args_.random_val};
255+
256+
const auto &logits_shape{logits->shape()};
257+
const auto &batch_size{logits_shape[0]};
258+
const auto &vocab_size{logits_shape[2]};
259+
260+
auto output_ids{infinicore::Tensor::empty({batch_size}, infinicore::DataType::I32, rank_info_.device)};
261+
262+
for (auto i{decltype(batch_size)(0)}; i < batch_size; ++i) {
263+
auto score{logits->narrow({{0, i, 1}})->view({vocab_size})};
264+
auto out{output_ids->narrow({{0, i, 1}})->view({})};
265+
infinicore::op::random_sample_(
266+
out, score, random_val, top_p, top_k, temperature);
267+
}
268+
269+
output_ids = output_ids->to(infinicore::Device::cpu());
270+
271+
infinicore::context::syncStream();
272+
273+
auto out{Output{output_ids}};
274+
275+
output_ = std::move(out);
276+
}
277+
248278
job_done_ = true;
249279
}
250280
cv_.notify_all();

csrc/engine/rank_worker.hpp

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,37 @@ class RankWorker {
2323
};
2424

2525
public:
26+
struct Input {
27+
/// Token IDs tensor of shape `[batch, seq_len]`.
28+
std::optional<infinicore::Tensor> input_ids;
29+
/// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
30+
std::optional<infinicore::Tensor> position_ids;
31+
/// Past Lengths of cached sequence for each request, of shape `[num_requests]`.
32+
std::optional<infinicore::Tensor> cache_lengths;
33+
/// Input Lengths of each request in a continuous-batched sequence, of shape `[num_requests]`.
34+
std::optional<infinicore::Tensor> input_lengths;
35+
/// Offsets of each request in a continuous-batched sequence, of shape `[num_requests]`.
36+
std::optional<infinicore::Tensor> input_offsets;
37+
/// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
38+
std::optional<infinicore::Tensor> block_tables;
39+
/// Slot ids for each token `[seq]`. Used for paged cache.
40+
std::optional<infinicore::Tensor> slot_mapping;
41+
42+
float temperature{1};
43+
44+
int top_k{50};
45+
46+
float top_p{1};
47+
48+
float random_val{0.1};
49+
50+
infinilm::InfinilmModel::Input to_model_input() const;
51+
};
52+
53+
struct Output {
54+
infinicore::Tensor output_ids;
55+
};
56+
2657
RankWorker(const InfinilmModel::Config &model_config,
2758
const distributed::RankInfo &rank_info,
2859
const cache::CacheConfig *cache_config);
@@ -35,7 +66,7 @@ class RankWorker {
3566
std::unordered_map<std::string, infinicore::nn::Parameter> state_dict();
3667

3768
// Submit a run (forward) job.
38-
void run(const InfinilmModel::Input &args);
69+
void run(const Input &args);
3970

4071
// Reset the internal cache with a new configuration
4172
void reset_cache(const cache::CacheConfig *new_config);
@@ -47,7 +78,7 @@ class RankWorker {
4778
void close();
4879

4980
// Thread-safe accessor for last output produced by RUN.
50-
InfinilmModel::Output get_output();
81+
Output get_output();
5182

5283
std::string info() const;
5384

@@ -73,11 +104,11 @@ class RankWorker {
73104
// Task payloads (protected by mutex)
74105
std::string pending_param_name_;
75106
infinicore::Tensor pending_param_;
76-
InfinilmModel::Input pending_args_;
107+
Input pending_args_;
77108
std::unique_ptr<cache::CacheConfig> pending_cache_config_;
78109

79110
// Output (protected by mutex)
80-
InfinilmModel::Output output_;
111+
Output output_;
81112

82113
// Thread sync
83114
std::thread thread_;

csrc/models/infinilm_model.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class InfinilmModel : public infinicore::nn::Module {
3535
};
3636

3737
struct Output {
38-
/// Output tensor of shape [batch, seq_len, vocab_size].
38+
/// Logits.
3939
infinicore::Tensor logits;
4040
};
4141

csrc/pybind11/engine/engine.hpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,28 @@ inline void bind_infer_engine(py::module &m) {
8484
std::optional<infinicore::Tensor> input_lengths,
8585
std::optional<infinicore::Tensor> input_offsets,
8686
std::optional<infinicore::Tensor> block_tables,
87-
std::optional<infinicore::Tensor> slot_mapping) {
88-
return InferEngine::Input{
87+
std::optional<infinicore::Tensor> slot_mapping,
88+
py::kwargs kwargs) {
89+
auto input{InferEngine::Input{
8990
std::move(input_ids),
9091
std::move(position_ids),
9192
std::move(cache_lengths),
9293
std::move(block_tables),
93-
std::move(slot_mapping)};
94+
std::move(slot_mapping)}};
95+
96+
if (kwargs) {
97+
if (kwargs.contains("temperature")) {
98+
input.temperature = kwargs["temperature"].cast<float>();
99+
}
100+
if (kwargs.contains("top_k")) {
101+
input.top_k = kwargs["top_k"].cast<int>();
102+
}
103+
if (kwargs.contains("top_p")) {
104+
input.top_p = kwargs["top_p"].cast<float>();
105+
}
106+
}
107+
108+
return input;
94109
}),
95110
py::arg("input_ids") = std::nullopt,
96111
py::arg("position_ids") = std::nullopt,
@@ -108,7 +123,7 @@ inline void bind_infer_engine(py::module &m) {
108123
.def_readwrite("slot_mapping", &InferEngine::Input::slot_mapping);
109124

110125
py::class_<InferEngine::Output>(infer_engine, "Output")
111-
.def_readwrite("logits", &InferEngine::Output::logits, "Output tensor");
126+
.def_readwrite("output_ids", &InferEngine::Output::output_ids, "Output tensor");
112127
}
113128

114129
} // namespace infinilm::engine

examples/bench.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import infinicore
22
from transformers import AutoTokenizer
33
from infinilm.modeling_utils import load_model_state_dict_by_file
4-
import infinilm
54
from infinilm.distributed import DistConfig
5+
from infinilm.infer_engine import GenerationConfig, InferEngine
66
import argparse
77
import sys
88
import time
99
import os
1010
import json
1111
from collections import OrderedDict
12+
import numpy as np
1213
from tqdm import tqdm
1314

1415
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../python"))
@@ -205,10 +206,9 @@ def __init__(
205206
# ---------------------------------------------------------------------------- #
206207
# 创建模型,
207208
# ---------------------------------------------------------------------------- #
208-
model = infinilm.AutoLlamaModel.from_pretrained(
209+
model = InferEngine(
209210
model_path,
210211
device=infini_device,
211-
backend="cpp",
212212
distributed_config=DistConfig(tp),
213213
)
214214

@@ -257,14 +257,17 @@ def run(
257257

258258
t1 = time.time()
259259
print("=================== start generate ====================")
260-
self.model.generate(
260+
output_ids = self.model.generate(
261261
input_ids_infini,
262-
max_new_tokens=output_len,
263-
tokenizer=self.tokenizer,
264-
stop_on_eos=False,
262+
GenerationConfig(max_new_tokens=output_len, eos_token_id=[]),
265263
)
266264
t2 = time.time()
267265

266+
numpy_output_ids = np.array(
267+
[output_id.to_numpy()[0] for output_id in output_ids]
268+
)
269+
print(self.tokenizer.decode(numpy_output_ids, skip_special_tokens=True))
270+
268271
print(
269272
f"total_time: {round((t2 - t1) * 1000, 2)} ms",
270273
)

examples/jiuge.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
from transformers import AutoTokenizer
33
from tokenizers import decoders as _dec
44
from infinilm.modeling_utils import load_model_state_dict_by_file
5-
import infinilm
65
from infinilm.distributed import DistConfig
6+
from infinilm.infer_engine import GenerationConfig, InferEngine
77
import argparse
88
import sys
99
import time
1010
import os
11+
import numpy as np
1112

1213
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../python"))
1314

@@ -90,17 +91,15 @@ def test(
9091
model_path,
9192
max_new_tokens=100,
9293
infini_device=infinicore.device("cpu", 0),
93-
backend="python",
9494
tp=1,
9595
):
9696
model_path = os.path.expanduser(model_path)
9797
# ---------------------------------------------------------------------------- #
9898
# 创建模型,
9999
# ---------------------------------------------------------------------------- #
100-
model = infinilm.AutoLlamaModel.from_pretrained(
100+
model = InferEngine(
101101
model_path,
102102
device=infini_device,
103-
backend=backend,
104103
distributed_config=DistConfig(tp),
105104
)
106105

@@ -165,13 +164,18 @@ def test(
165164

166165
t1 = time.time()
167166
print("=================== start generate ====================")
168-
model.generate(
167+
output_ids = model.generate(
169168
input_ids_infini,
170-
max_new_tokens=max_new_tokens,
171-
tokenizer=tokenizer,
169+
GenerationConfig(
170+
max_new_tokens=max_new_tokens, temperature=1, top_k=1, top_p=0.8
171+
),
172+
_measure_and_log_time=True,
172173
)
173174
t2 = time.time()
174175

176+
numpy_output_ids = np.array([output_id.to_numpy()[0] for output_id in output_ids])
177+
print(tokenizer.decode(numpy_output_ids, skip_special_tokens=True))
178+
175179
print(
176180
f"total_time: {round((t2 - t1) * 1000, 2)} ms",
177181
)
@@ -208,13 +212,15 @@ def test(
208212
backend = args.backend
209213
tp = args.tp
210214

215+
if backend != "cpp":
216+
raise ValueError(f"Unsupported backend: {backend}.")
217+
211218
infini_device = infinicore.device(device_str, 0)
212219

213220
test(
214221
prompts,
215222
model_path,
216223
max_new_tokens,
217224
infini_device=infini_device,
218-
backend=backend,
219225
tp=tp,
220226
)

0 commit comments

Comments (0)