forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtext_token_generator.h
More file actions
209 lines (179 loc) · 6.72 KB
/
text_token_generator.h
File metadata and controls
209 lines (179 loc) · 6.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
// Generate tokens in a loop.
#pragma once
#include <atomic>
#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/tensor/tensor.h>
#include <pytorch/tokenizers/tokenizer.h>
namespace executorch {
namespace extension {
namespace llm {
/**
 * Drives the autoregressive decode loop for a text LLM.
 *
 * Each iteration runs one model step through a TextDecoderRunner, samples the
 * next token from the returned logits, decodes it to text with the tokenizer
 * and forwards that text to a caller-supplied callback. Generation ends when
 * the token budget is exhausted, an end-of-sequence token is sampled (unless
 * EOS checking is disabled via set_ignore_eos), or stop() is called.
 */
class ET_EXPERIMENTAL TextTokenGenerator {
 public:
  /**
   * @param tokenizer Decodes (previous, current) token pairs into text.
   *     Not owned.
   * @param text_decoder_runner Runs one model step and samples a token from
   *     the resulting logits. Not owned.
   * @param use_kv_cache When true, every step feeds only the most recent
   *     token (shape {1, 1}); when false, every step feeds the whole token
   *     sequence accumulated so far.
   * @param eos_ids Token ids that end generation; ownership is transferred.
   *     A null pointer means no token is treated as end-of-sequence.
   * @param stats Notified before and after each sampling call. Not owned;
   *     must be non-null.
   */
  TextTokenGenerator(
      ::tokenizers::Tokenizer* tokenizer,
      TextDecoderRunner* text_decoder_runner,
      bool use_kv_cache,
      std::unique_ptr<std::unordered_set<uint64_t>>&& eos_ids,
      Stats* stats)
      : tokenizer_(tokenizer),
        text_decoder_runner_(text_decoder_runner),
        eos_ids_(std::move(eos_ids)),
        use_kv_cache_(use_kv_cache),
        stats_(stats) {}

  /**
   * When true, sampled EOS tokens no longer terminate generate(); only the
   * token budget or stop() does.
   */
  void set_ignore_eos(bool ignore_eos) {
    ignore_eos_ = ignore_eos;
  }

  virtual ~TextTokenGenerator() = default;

  /**
   * Token generation loop.
   * @param tokens The first token generated by prefill, if using kv cache. Else
   * the prompt tokens + the first token generated by prefill. Must be
   * non-empty (checked; aborts otherwise).
   * @param start_pos The start position of the new tokens, based on how many
   * prompt tokens is prefilled.
   * @param max_new_tokens Maximum number of new tokens to generate. Values
   * <= 0 generate nothing.
   * @param temperature controls the randomness of predictions by scaling the
   * logits before applying softmax. A higher temperature results in more
   * random predictions, while a lower temperature results in more deterministic
   * predictions.
   * @param token_callback what to do after a token is generated. Receives the
   * decoded text of each new token; may be left empty.
   * @return how many tokens are generated, or the error from a failed model
   * step / tensor resize, or Error::InvalidArgument if the tokenizer fails to
   * decode a token.
   *
   * Note: any stop() request made before this call is discarded, because the
   * stop flag is reset on entry.
   */
  inline ::executorch::runtime::Result<int64_t> generate(
      std::vector<uint64_t> tokens,
      int64_t start_pos,
      int32_t max_new_tokens,
      float temperature = 0.0f,
      const std::function<void(const std::string&)>& token_callback = {}) {
    using SizesType = executorch::aten::SizesType;
    ET_CHECK_MSG(
        !tokens.empty(), "Token generation loop shouldn't take empty tokens");

    // A negative budget would wrap around to a huge size_t in reserve() below
    // (and in the max-shape computation); treat it as "generate nothing",
    // which is what the kv-cache path already did.
    const int32_t budget = max_new_tokens < 0 ? 0 : max_new_tokens;

    int64_t pos = start_pos; // position in the sequence
    // Token after prefill.
    uint64_t cur_token = tokens.back();
    uint64_t prev_token = cur_token;
    // Captured before `tokens` may be moved from below.
    const size_t num_input_tokens = tokens.size();

    // Backing storage for the model-input tensor created by from_blob().
    std::vector<uint64_t> token_data;
    std::vector<SizesType> token_shape;
    if (use_kv_cache_) {
      // hard code these to size 1 as kv cache is locked to static size right
      // now.
      token_data = {cur_token};
      token_shape = {1, 1};
    } else {
      // `tokens` is a by-value parameter that is not used again; move rather
      // than copy it.
      token_data = std::move(tokens);
      token_shape = {1, static_cast<SizesType>(num_input_tokens)};
      // Prevent reallocation that would invalidate from_blob's data pointer:
      // the loop appends at most `budget` tokens.
      token_data.reserve(num_input_tokens + static_cast<size_t>(budget));
    }

    // Create tensor wrapper. For non-kv-cache, use max capacity shape so
    // numel_bound_ is large enough for subsequent resize_tensor_ptr calls,
    // then resize down to the actual token count.
    auto max_shape = use_kv_cache_
        ? token_shape
        : std::vector<SizesType>{
              1,
              static_cast<SizesType>(
                  num_input_tokens + static_cast<size_t>(budget))};
    auto tokens_managed = from_blob(
        token_data.data(), max_shape, executorch::aten::ScalarType::Long);
    if (!use_kv_cache_) {
      ET_CHECK_OK_OR_RETURN_ERROR(
          resize_tensor_ptr(tokens_managed, token_shape));
    }

    should_stop_.store(false, std::memory_order_relaxed);
    // Generate our tokens
    while (pos < start_pos + budget) {
      // Run the model
      auto logits_res = text_decoder_runner_->step(tokens_managed, pos);
      ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
      executorch::aten::Tensor& logits_tensor = logits_res.get();

      prev_token = cur_token;
      stats_->on_sampling_begin();
      cur_token =
          text_decoder_runner_->logits_to_token(logits_tensor, temperature);
      stats_->on_sampling_end();
      pos++;

      if (use_kv_cache_) {
        // update the token tensor. token_data will not be empty.
        // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
        token_data[0] = cur_token;
      } else {
        // Append in place; capacity was reserved above, so the tensor's data
        // pointer stays valid.
        token_data.push_back(cur_token);
        ET_CHECK_OK_OR_RETURN_ERROR(resize_tensor_ptr(
            tokens_managed, {1, static_cast<SizesType>(token_data.size())}));
      }

      // print the token as string, decode it with the Tokenizer object
      auto decode_result = tokenizer_->decode(prev_token, cur_token);
      if (!decode_result.ok()) {
        ET_LOG(
            Error,
            "Tokenizers error code %d",
            static_cast<int>(decode_result.error()));
        return ::executorch::runtime::Error::InvalidArgument;
      }
      // The callback defaults to an empty std::function; invoking an empty
      // one throws std::bad_function_call (or aborts without exceptions).
      if (token_callback) {
        token_callback(*decode_result);
      }

      if (should_stop_.load(std::memory_order_relaxed)) {
        break;
      }

      // data-dependent terminating condition: stop on any configured EOS id.
      if (!ignore_eos_ && eos_ids_ &&
          eos_ids_->find(cur_token) != eos_ids_->end()) {
        printf("\n");
        ET_LOG(Info, "\nReached to the end of generation");
        break;
      }
    }
    return pos - start_pos;
  }

  /**
   * Stop the generation loop. Sets an atomic flag that generate() checks
   * after each token, so it is intended to be callable while generate() runs.
   */
  inline void stop() {
    should_stop_.store(true, std::memory_order_relaxed);
  }

  /**
   * Load the necessary resources for TextTokenGenerator.
   * This method should be called before using the generate() method.
   * @return The error reported by TextDecoderRunner::load().
   */
  ::executorch::runtime::Error load() {
    return text_decoder_runner_->load();
  }

  /**
   * Check if the TextTokenGenerator has been successfully loaded.
   * @return True if both the tokenizer and the decoder runner's method are
   * loaded, false otherwise.
   */
  inline bool is_loaded() const {
    return tokenizer_->is_loaded() && text_decoder_runner_->is_method_loaded();
  }

 private:
  /**
   * Note: TextTokenGenerator does not own the tokenizer_ and
   * text_decoder_runner_. The lifecycle of these objects should be managed
   * externally, likely in the Runner. This class assumes that the provided
   * pointers remain valid for the duration of its use.
   */
  ::tokenizers::Tokenizer* tokenizer_;
  TextDecoderRunner* text_decoder_runner_;
  // Owned set of end-of-sequence token ids; may be null (no EOS stopping).
  std::unique_ptr<std::unordered_set<uint64_t>> eos_ids_;
  bool use_kv_cache_;
  bool ignore_eos_ = false;
  // state machine: set by stop(), reset at the start of generate().
  std::atomic<bool> should_stop_{false};
  // stats (not owned)
  Stats* stats_;
};
} // namespace llm
} // namespace extension
} // namespace executorch
namespace torch {
namespace executor {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
// Legacy alias: makes torch::executor::TextTokenGenerator name the same class
// as ::executorch::extension::llm::TextTokenGenerator, so code still using
// the old namespace keeps compiling.
using ::executorch::extension::llm::TextTokenGenerator;
} // namespace executor
} // namespace torch