generator.cc

#include "module.h"
#include <ctranslate2/generator.h>
#include "replica_pool.h"
namespace ctranslate2 {
namespace python {
class GeneratorWrapper : public ReplicaPoolHelper<Generator> {
public:
using ReplicaPoolHelper::ReplicaPoolHelper;
void set_alignment_heads(const std::vector<std::pair<dim_t, dim_t>>& alignment_heads) {
_pool->for_each_replica([&](models::SequenceGeneratorReplica& replica) {
replica.set_alignment_heads(alignment_heads);
});
}
      std::variant<std::vector<GenerationResult>,
                   std::vector<AsyncResult<GenerationResult>>>
      generate_batch(const BatchTokens& tokens,
                     size_t max_batch_size,
                     const std::string& batch_type_str,
                     bool asynchronous,
                     size_t beam_size,
                     float patience,
                     size_t num_hypotheses,
                     float length_penalty,
                     float repetition_penalty,
                     size_t no_repeat_ngram_size,
                     bool disable_unk,
                     const std::optional<std::vector<std::vector<std::string>>>& suppress_sequences,
                     const std::optional<EndToken>& end_token,
                     bool return_end_token,
                     size_t max_length,
                     size_t min_length,
                     const std::optional<std::vector<std::string>>& static_prompt,
                     bool cache_static_prompt,
                     bool include_prompt_in_result,
                     bool return_scores,
                     bool return_attention,
                     bool return_logits_vocab,
                     bool return_alternatives,
                     float min_alternative_expansion_prob,
                     size_t sampling_topk,
                     float sampling_topp,
                     float sampling_temperature,
                     std::function<bool(GenerationStepResult)> callback) {
        if (tokens.empty())
          return {};

        BatchType batch_type = str_to_batch_type(batch_type_str);

        GenerationOptions options;
        options.beam_size = beam_size;
        options.patience = patience;
        options.length_penalty = length_penalty;
        options.repetition_penalty = repetition_penalty;
        options.no_repeat_ngram_size = no_repeat_ngram_size;
        options.disable_unk = disable_unk;
        options.sampling_topk = sampling_topk;
        options.sampling_topp = sampling_topp;
        options.sampling_temperature = sampling_temperature;
        options.max_length = max_length;
        options.min_length = min_length;
        options.num_hypotheses = num_hypotheses;
        options.return_end_token = return_end_token;
        options.return_scores = return_scores;
        options.return_attention = return_attention;
        options.return_logits_vocab = return_logits_vocab;
        options.return_alternatives = return_alternatives;
        options.cache_static_prompt = cache_static_prompt;
        options.include_prompt_in_result = include_prompt_in_result;
        options.min_alternative_expansion_prob = min_alternative_expansion_prob;
        options.callback = std::move(callback);
        if (suppress_sequences)
          options.suppress_sequences = suppress_sequences.value();
        if (end_token)
          options.end_token = end_token.value();
        if (static_prompt)
          options.static_prompt = static_prompt.value();

        std::shared_lock lock(_mutex);
        assert_model_is_ready();
        auto futures = _pool->generate_batch_async(tokens, options, max_batch_size, batch_type);
        return maybe_wait_on_futures(std::move(futures), asynchronous);
      }
      std::variant<std::vector<ScoringResult>,
                   std::vector<AsyncResult<ScoringResult>>>
      score_batch(const BatchTokens& tokens,
                  size_t max_batch_size,
                  const std::string& batch_type_str,
                  size_t max_input_length,
                  bool asynchronous) {
        const auto batch_type = str_to_batch_type(batch_type_str);
        ScoringOptions options;
        options.max_input_length = max_input_length;

        std::shared_lock lock(_mutex);
        assert_model_is_ready();
        auto futures = _pool->score_batch_async(tokens, options, max_batch_size, batch_type);
        return maybe_wait_on_futures(std::move(futures), asynchronous);
      }
      StorageView
      forward_batch(const std::variant<BatchTokens, BatchIds, StorageView>& inputs,
                    const std::optional<StorageView>& lengths,
                    const bool return_log_probs) {
        std::future<StorageView> future;

        switch (inputs.index()) {
        case 0:
          future = _pool->forward_batch_async(std::get<BatchTokens>(inputs), return_log_probs);
          break;
        case 1:
          future = _pool->forward_batch_async(std::get<BatchIds>(inputs), return_log_probs);
          break;
        case 2:
          if (!lengths)
            throw std::invalid_argument("lengths vector is required when passing a dense input");
          future = _pool->forward_batch_async(std::get<StorageView>(inputs),
                                              lengths.value(),
                                              return_log_probs);
          break;
        }

        return future.get();
      }
    };
    void register_generator(py::module& m) {
      py::class_<GeneratorWrapper>(
        m, "Generator",
        R"pbdoc(
            A text generator.

            Example:

                >>> generator = ctranslate2.Generator("model/", device="cpu")
                >>> generator.generate_batch([["<s>"]], max_length=50, sampling_topk=20)
        )pbdoc")
        .def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, bool, bool, py::object>(),
             py::arg("model_path"),
             py::arg("device")="cpu",
             py::kw_only(),
             py::arg("device_index")=0,
             py::arg("compute_type")="default",
             py::arg("inter_threads")=1,
             py::arg("intra_threads")=0,
             py::arg("max_queued_batches")=0,
             py::arg("flash_attention")=false,
             py::arg("tensor_parallel")=false,
             py::arg("files")=py::none(),
             R"pbdoc(
                 Initializes the generator.

                 Arguments:
                   model_path: Path to the CTranslate2 model directory.
                   device: Device to use (possible values are: cpu, cuda, auto).
                   device_index: Device ID or list of device IDs on which to place this generator.
                   compute_type: Model computation type or a dictionary mapping a device name
                     to the computation type (possible values are: default, auto, int8, int8_float32,
                     int8_float16, int8_bfloat16, int16, float16, bfloat16, float32).
                   inter_threads: Maximum number of parallel generations.
                   intra_threads: Number of OpenMP threads per generator (0 to use a default value).
                   max_queued_batches: Maximum number of batches in the queue (-1 for unlimited,
                     0 for an automatic value). When the queue is full, future requests will block
                     until a free slot is available.
                   flash_attention: Run the model with Flash Attention 2 in the self-attention layers.
                   tensor_parallel: Run the model in tensor parallel mode.
                   files: Load model files from memory. This argument is a dictionary mapping
                     file names to file contents as file-like or bytes objects. If this is set,
                     :obj:`model_path` acts as an identifier for this model.
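
                 Example:
                   An illustrative sketch of loading a model from in-memory files (the file
                   names below are hypothetical; the actual set of files depends on the
                   converted model):

                       >>> files = {"model.bin": open("model/model.bin", "rb"), "config.json": open("model/config.json", "rb")}
                       >>> generator = ctranslate2.Generator("my-model", device="cpu", files=files)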
)pbdoc")
.def_property_readonly("device", &GeneratorWrapper::device,
"Device this generator is running on.")
.def_property_readonly("device_index", &GeneratorWrapper::device_index,
"List of device IDs where this generator is running on.")
.def_property_readonly("compute_type", &GeneratorWrapper::compute_type,
"Computation type used by the model.")
.def_property_readonly("num_generators", &GeneratorWrapper::num_replicas,
"Number of generators backing this instance.")
.def_property_readonly("num_queued_batches", &GeneratorWrapper::num_queued_batches,
"Number of batches waiting to be processed.")
.def_property_readonly("tensor_parallel", &GeneratorWrapper::tensor_parallel,
"Run model with tensor parallel mode.")
.def_property_readonly("num_active_batches", &GeneratorWrapper::num_active_batches,
"Number of batches waiting to be processed or currently processed.")
.def("set_alignment_heads", &GeneratorWrapper::set_alignment_heads,
py::arg("alignment_heads"),
R"pbdoc(
Configure which attention heads to collect when ``return_attention=True``.
By default, only head 0 of the last layer is returned (averaged).
Use this method to select specific (layer, head) pairs. The attention
from the selected heads will be concatenated in the output.
Arguments:
alignment_heads: List of (layer_index, head_index) pairs to collect.
Example:
>>> generator.set_alignment_heads([(31, 0), (31, 3), (33, 7)])
)pbdoc")
.def("generate_batch", &GeneratorWrapper::generate_batch,
py::arg("start_tokens"),
py::kw_only(),
py::arg("max_batch_size")=0,
py::arg("batch_type")="examples",
py::arg("asynchronous")=false,
py::arg("beam_size")=1,
py::arg("patience")=1,
py::arg("num_hypotheses")=1,
py::arg("length_penalty")=1,
py::arg("repetition_penalty")=1,
py::arg("no_repeat_ngram_size")=0,
py::arg("disable_unk")=false,
py::arg("suppress_sequences")=py::none(),
py::arg("end_token")=py::none(),
py::arg("return_end_token")=false,
py::arg("max_length")=512,
py::arg("min_length")=0,
py::arg("static_prompt")=py::none(),
py::arg("cache_static_prompt")=true,
py::arg("include_prompt_in_result")=true,
py::arg("return_scores")=false,
py::arg("return_attention")=false,
py::arg("return_logits_vocab")=false,
py::arg("return_alternatives")=false,
py::arg("min_alternative_expansion_prob")=0,
py::arg("sampling_topk")=1,
py::arg("sampling_topp")=1,
py::arg("sampling_temperature")=1,
py::arg("callback")=nullptr,
py::call_guard<py::gil_scoped_release>(),
R"pbdoc(
Generates from a batch of start tokens.
Note:
The way the start tokens are forwarded in the decoder depends on the argument
:obj:`include_prompt_in_result`:
* If :obj:`include_prompt_in_result` is ``True`` (the default), the decoding loop
is constrained to generate the start tokens that are then included in the result.
* If :obj:`include_prompt_in_result` is ``False``, the start tokens are forwarded
in the decoder at once to initialize its state (i.e. the KV cache for
Transformer models). For variable-length inputs, only the tokens up to the
minimum length in the batch are forwarded at once. The remaining tokens are
generated in the decoding loop with constrained decoding.
Consider setting ``include_prompt_in_result=False`` to increase the performance
for long inputs.
Arguments:
start_tokens: Batch of start tokens. If the decoder starts from a special
start token like ``<s>``, this token should be added to this input.
max_batch_size: The maximum batch size. If the number of inputs is greater than :obj:`max_batch_size`,
the inputs are sorted by length and split by chunks of :obj:`max_batch_size` examples
(or tokens when :obj:`batch_type`="tokens") so that the number of padding positions
is minimized.
batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens".
asynchronous: Run the generation asynchronously.
beam_size: Beam size (1 for greedy search).
patience: Beam search patience factor, as described in
https://arxiv.org/abs/2204.05424. The decoding will continue until
beam_size*patience hypotheses are finished.
num_hypotheses: Number of hypotheses to return.
length_penalty: Exponential penalty applied to the length during beam search.
repetition_penalty: Penalty applied to the score of previously generated tokens
(set > 1 to penalize).
no_repeat_ngram_size: Prevent repetitions of ngrams with this size
(set 0 to disable).
disable_unk: Disable the generation of the unknown token.
suppress_sequences: Disable the generation of some sequences of tokens.
end_token: Stop the decoding on one of these tokens (defaults to the model EOS token).
return_end_token: Include the end token in the results.
max_length: Maximum generation length.
min_length: Minimum generation length.
static_prompt: If the model expects a static prompt (a.k.a. system prompt)
it can be set here to simplify the inputs and optionally cache the model
state for this prompt to accelerate future generations.
cache_static_prompt: Cache the model state after the static prompt and
reuse it for future generations using the same static prompt.
include_prompt_in_result: Include the :obj:`start_tokens` in the result.
return_scores: Include the scores in the output.
return_attention: Include the attention matrices in the output.
return_logits_vocab: Include log probs for each token in the output
return_alternatives: Return alternatives at the first unconstrained decoding position.
min_alternative_expansion_prob: Minimum initial probability to expand an alternative.
sampling_topk: Randomly sample predictions from the top K candidates.
sampling_topp: Keep the most probable tokens whose cumulative probability exceeds
this value.
sampling_temperature: Sampling temperature to generate more random samples.
callback: Optional function that is called for each generated token when
:obj:`beam_size` is 1. If the callback function returns ``True``, the
decoding will stop for this batch index.
Returns:
A list of generation results.
See Also:
`GenerationOptions <https://github.com/OpenNMT/CTranslate2/blob/master/include/ctranslate2/generation.h>`_ structure in the C++ library.
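
                 Example:
                   A minimal greedy call, shown for illustration only (it assumes ``generator``
                   was created as in the class-level example, that ``<s>`` is the model's start
                   token, and that the generated tokens are read from ``results[0].sequences``):

                       >>> results = generator.generate_batch([["<s>"]], max_length=30, sampling_topk=10)
                       >>> results[0].sequences[0]  # tokens of the first (and only) hypothesis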
)pbdoc")
.def("score_batch", &GeneratorWrapper::score_batch,
py::arg("tokens"),
py::kw_only(),
py::arg("max_batch_size")=0,
py::arg("batch_type")="examples",
py::arg("max_input_length")=1024,
py::arg("asynchronous")=false,
py::call_guard<py::gil_scoped_release>(),
R"pbdoc(
Scores a batch of tokens.
Arguments:
tokens: Batch of tokens to score. If the model expects special start or end tokens,
they should also be added to this input.
max_batch_size: The maximum batch size. If the number of inputs is greater than
:obj:`max_batch_size`, the inputs are sorted by length and split by chunks of
:obj:`max_batch_size` examples so that the number of padding positions is
minimized.
batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens".
max_input_length: Truncate inputs after this many tokens (0 to disable).
asynchronous: Run the scoring asynchronously.
Returns:
A list of scoring results.
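
                 Example:
                   An illustrative sketch (it assumes ``generator`` was created as in the
                   class-level example and that the scoring result exposes token-level log
                   probabilities through a ``log_probs`` attribute):

                       >>> results = generator.score_batch([["<s>", "Hello", "world", "</s>"]])
                       >>> results[0].log_probs  # one log probability per scored token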
)pbdoc")
.def("forward_batch", &GeneratorWrapper::forward_batch,
py::arg("inputs"),
py::arg("lengths")=py::none(),
py::kw_only(),
py::arg("return_log_probs")=false,
py::call_guard<py::gil_scoped_release>(),
R"pbdoc(
Forwards a batch of sequences in the generator.
Arguments:
inputs: A batch of sequences either as string tokens or token IDs.
This argument can also be a dense int32 array with shape
``[batch_size, max_length]`` (e.g. created from a Numpy array or PyTorch tensor).
lengths: The length of each sequence as a int32 array with shape
``[batch_size]``. Required when :obj:`inputs` is a dense array.
return_log_probs: If ``True``, the method returns the log probabilties instead
of the unscaled logits.
Returns:
The output logits, or the output log probabilities if :obj:`return_log_probs`
is enabled.
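
                 Example:
                   An illustrative sketch (it assumes ``generator`` runs on CPU and that the
                   returned ``StorageView`` can be converted with ``numpy.asarray``):

                       >>> import numpy as np
                       >>> output = generator.forward_batch([["<s>", "Hello"]], return_log_probs=True)
                       >>> np.asarray(output).shape  # (batch_size, max_length, vocabulary_size)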
)pbdoc")
.def("unload_model", &GeneratorWrapper::unload_model,
py::arg("to_cpu")=false,
py::call_guard<py::gil_scoped_release>(),
R"pbdoc(
Unloads the model attached to this generator but keep enough runtime context
to quickly resume generator on the initial device. The model is not guaranteed
to be unloaded if generations are running concurrently.
Arguments:
to_cpu: If ``True``, the model is moved to the CPU memory and not fully unloaded.
)pbdoc")
.def("load_model", &GeneratorWrapper::load_model,
py::arg("keep_cache")=false,
py::call_guard<py::gil_scoped_release>(),
R"pbdoc(
Loads the model back to the initial device.
Arguments:
keep_cache: If ``True``, the model cache in the CPU memory is not deleted if it exists.
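
                 Example:
                   An illustrative sketch of temporarily moving the model to CPU memory and
                   restoring it (assumes ``generator`` was created as in the class-level example):

                       >>> generator.unload_model(to_cpu=True)  # free device memory, keep weights in CPU memory
                       >>> generator.load_model()               # move the model back to the initial device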
)pbdoc")
.def_property_readonly("model_is_loaded", &GeneratorWrapper::model_is_loaded,
"Whether the model is loaded on the initial device and ready to be used.")
;
}
}
}