-
Notifications
You must be signed in to change notification settings - Fork 475
Expand file tree
/
Copy pathdecoding_utils.h
More file actions
244 lines (205 loc) · 7.85 KB
/
decoding_utils.h
File metadata and controls
244 lines (205 loc) · 7.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
#pragma once
#include <algorithm>
#include <limits>
#include "ops/tile.h"
#include "storage_view.h"
namespace ctranslate2 {
// Reshapes a flat batch dimension into separate batch and beam dimensions:
// [batch * beam, ...] -> [batch, beam, ...].
// The leading dimension is assumed to be a multiple of beam_size.
inline void split_batch_beam(StorageView& input, dim_t beam_size) {
  Shape target = input.shape();
  target[0] /= beam_size;
  target.insert(target.begin() + 1, beam_size);
  input.reshape(std::move(target));
}
// Collapses the batch and beam dimensions back into a single flat dimension:
// [batch, beam, ...] -> [batch * beam, ...].
inline void merge_batch_beam(StorageView& input) {
  Shape target = input.shape();
  const auto flat_size = target[0] * target[1];
  target.erase(target.begin() + 1);
  target[0] = flat_size;
  input.reshape(std::move(target));
}
// Duplicates each batch entry `repeats` times along the batch dimension:
// [batch, ...] -> [batch * repeats, ...], copies of a row being consecutive.
inline void repeat_batch(StorageView& input, dim_t repeats) {
  // Insert a temporary axis, tile along it, then fold it back into the batch axis.
  input.expand_dims(1);
  ops::Tile tile_op(/*axis=*/1, repeats);
  tile_op(input);
  merge_batch_beam(input);
}
// Returns true if `id` is one of the configured end-of-sequence token ids.
inline bool is_eos(const size_t id, const std::vector<size_t>& end_ids) {
  for (const size_t end_id : end_ids) {
    if (end_id == id)
      return true;
  }
  return false;
}
// Helper class to disable tokens in the model output.
//
// On CPU the disable value is written straight into the logits buffer;
// on GPU a sorted list of unique flat indices is accumulated and flushed
// by apply() (defined out of line).
class DisableTokens {
public:
  DisableTokens(StorageView& logits,
                const float disable_value = std::numeric_limits<float>::lowest());

  // Disables `token_id` for the batch entry `batch_id`.
  void add(dim_t batch_id, dim_t token_id) {
    const auto index = batch_id * _vocabulary_size + token_id;
    if (!_logits_data) {
      // GPU path: keep _flat_indices sorted and duplicate-free.
      const auto pos = std::lower_bound(_flat_indices.begin(), _flat_indices.end(), index);
      if (pos == _flat_indices.end() || *pos != index)
        _flat_indices.insert(pos, index);
    } else {
      // CPU path: overwrite the logit immediately.
      _logits_data[index] = _disable_value;
    }
  }

  // Disables `token_id` in every batch entry.
  void add(dim_t token_id) {
    for (dim_t batch_id = 0; batch_id < _batch_size; ++batch_id)
      add(batch_id, token_id);
  }

  // Flushes any accumulated indices to the logits (out-of-line definition).
  void apply();

private:
  StorageView& _logits;
  float* _logits_data;
  const float _disable_value;
  const dim_t _batch_size;
  const dim_t _vocabulary_size;
  std::vector<int32_t> _flat_indices;
};
// Helper class to bias tokens in the model output.
//
// On CPU the logit is scaled in place; on GPU a sorted list of unique
// (flat index, bias) pairs is accumulated and flushed by apply()
// (defined out of line).
class BiasTokens {
public:
  BiasTokens(StorageView& logits);

  // Multiplies the logit of `token_id` in batch entry `batch_id` by `bias_value`.
  void add(dim_t batch_id, dim_t token_id, float bias_value) {
    const auto index = batch_id * _vocabulary_size + token_id;
    if (!_logits_data) {
      // GPU path: keep _flat_indices sorted by index and duplicate-free;
      // repeated additions for the same index compound their bias values.
      const auto pos = std::lower_bound(
          _flat_indices.begin(), _flat_indices.end(), index,
          [](const auto& entry, const auto& value) { return entry.first < value; });
      if (pos != _flat_indices.end() && pos->first == index)
        pos->second *= bias_value;
      else
        _flat_indices.emplace(pos, index, bias_value);
    } else {
      // CPU path: scale the logit immediately.
      _logits_data[index] *= bias_value;
    }
  }

  // Multiplies the logit of `token_id` by `bias_value` in every batch entry.
  void add(dim_t token_id, float bias_value) {
    for (dim_t batch_id = 0; batch_id < _batch_size; ++batch_id)
      add(batch_id, token_id, bias_value);
  }

  // Flushes any accumulated (index, bias) pairs to the logits (out-of-line).
  void apply();

private:
  StorageView& _logits;
  float* _logits_data;
  const dim_t _batch_size;
  const dim_t _vocabulary_size;
  std::vector<std::pair<int32_t, float>> _flat_indices;
};
// Base class for processing the output logits at each decoding step.
class LogitsProcessor {
public:
  virtual ~LogitsProcessor() = default;

  // Returns true if this processor should be applied before the others.
  virtual bool apply_first() const {
    return false;
  }

  // Processes the logits for decoding step `step`. Implementations may
  // mutate `logits` directly or register updates through `disable_tokens`
  // and `bias_tokens`. `sequences` holds the tokens generated so far,
  // `batch_offset` maps beam-expanded rows back to original batch indices,
  // and `prefix` (optional) holds the forced target prefixes.
  virtual void apply(dim_t step,
                     StorageView& logits,
                     DisableTokens& disable_tokens,
                     BiasTokens& bias_tokens,
                     const StorageView& sequences,
                     const std::vector<dim_t>& batch_offset,
                     const std::vector<std::vector<size_t>>* prefix) = 0;

protected:
  // Maps a (possibly beam-expanded) row index to its original batch index.
  dim_t get_batch_index(const dim_t batch_size,
                        const dim_t batch_id,
                        const std::vector<dim_t>& batch_offset) const {
    const auto beam_width = batch_size / batch_offset.size();
    return batch_offset[batch_id / beam_width];
  }

  // Returns the length of the forced prefix for this row (the first
  // unconstrained decoding position), or 0 when no prefix is set.
  dim_t get_sample_begin(const dim_t batch_size,
                         const dim_t batch_id,
                         const std::vector<dim_t>& batch_offset,
                         const std::vector<std::vector<size_t>>* prefix) const {
    if (!prefix)
      return 0;
    const auto index = get_batch_index(batch_size, batch_id, batch_offset);
    return prefix->at(index).size();
  }
};
// Apply a penalty to the score of previously generated tokens.
class RepetitionPenalty : public LogitsProcessor {
public:
  // `penalty` is the repetition penalty factor; the penalty formula is
  // implemented in the out-of-line apply() definition.
  RepetitionPenalty(const float penalty);

  // See LogitsProcessor::apply.
  void apply(dim_t step,
             StorageView& logits,
             DisableTokens& disable_tokens,
             BiasTokens& bias_tokens,
             const StorageView& sequences,
             const std::vector<dim_t>& batch_offset,
             const std::vector<std::vector<size_t>>* prefix) override;

private:
  const float _penalty;
};
// Prevent repetitions of ngrams with a specific size.
class NoRepeatNgram : public LogitsProcessor {
public:
  // `ngram_size` is the length of the ngrams that must not repeat.
  NoRepeatNgram(const size_t ngram_size);

  // See LogitsProcessor::apply.
  void apply(dim_t step,
             StorageView& logits,
             DisableTokens& disable_tokens,
             BiasTokens& bias_tokens,
             const StorageView& sequences,
             const std::vector<dim_t>& batch_offset,
             const std::vector<std::vector<size_t>>* prefix) override;

private:
  const dim_t _ngram_size;
};
// Disable the generation of some sequences of tokens.
class SuppressSequences : public LogitsProcessor {
public:
  SuppressSequences(std::vector<std::vector<size_t>> sequences);

  // See LogitsProcessor::apply.
  void apply(dim_t step,
             StorageView& logits,
             DisableTokens& disable_tokens,
             BiasTokens& bias_tokens,
             const StorageView& sequences,
             const std::vector<dim_t>& batch_offset,
             const std::vector<std::vector<size_t>>* prefix) override;

private:
  // NOTE(review): single tokens and multi-token sequences are stored
  // separately — presumably the ctor splits them for a faster single-token
  // path; confirm against the out-of-line definitions.
  std::vector<size_t> _ids;
  std::vector<std::vector<size_t>> _sequences;
};
// Bias towards the generation of some sequences of tokens.
class BiasSequences : public LogitsProcessor {
public:
  // Each entry pairs a token sequence with its bias value.
  BiasSequences(std::vector<std::pair<std::vector<size_t>, float>> sequences);

  // See LogitsProcessor::apply.
  void apply(dim_t step,
             StorageView& logits,
             DisableTokens& disable_tokens,
             BiasTokens& bias_tokens,
             const StorageView& sequences,
             const std::vector<dim_t>& batch_offset,
             const std::vector<std::vector<size_t>>* prefix) override;

private:
  // NOTE(review): single tokens and multi-token sequences are stored
  // separately, mirroring SuppressSequences — confirm against the
  // out-of-line definitions.
  std::vector<std::pair<size_t, float>> _ids;
  std::vector<std::pair<std::vector<size_t>, float>> _sequences;
};
// Disable the generation of some tokens.
class SuppressTokens : public LogitsProcessor {
public:
  // `ids` are the token ids that must never be generated.
  SuppressTokens(std::vector<size_t> ids);

  // See LogitsProcessor::apply.
  void apply(dim_t step,
             StorageView& logits,
             DisableTokens& disable_tokens,
             BiasTokens& bias_tokens,
             const StorageView& sequences,
             const std::vector<dim_t>& batch_offset,
             const std::vector<std::vector<size_t>>* prefix) override;

private:
  const std::vector<size_t> _ids;
};
// Disable the generation of some tokens at the first unconstrained decoding step.
class SuppressTokensBegin : public LogitsProcessor {
public:
  // `ids` are the token ids to suppress at the first unconstrained step.
  SuppressTokensBegin(std::vector<size_t> ids);

  // See LogitsProcessor::apply.
  void apply(dim_t step,
             StorageView& logits,
             DisableTokens& disable_tokens,
             BiasTokens& bias_tokens,
             const StorageView& sequences,
             const std::vector<dim_t>& batch_offset,
             const std::vector<std::vector<size_t>>* prefix) override;

private:
  const std::vector<size_t> _ids;
};
}