-
Notifications
You must be signed in to change notification settings - Fork 475
Expand file tree
/
Copy pathdecoding.h
More file actions
176 lines (156 loc) · 5.9 KB
/
decoding.h
File metadata and controls
176 lines (156 loc) · 5.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
#pragma once
#include <functional>
#include <optional>
#include "ctranslate2/decoding_utils.h"
#include "ctranslate2/devices.h"
#include "ctranslate2/layers/decoder.h"
#include "ctranslate2/sampling.h"
#include "ctranslate2/storage_view.h"
namespace ctranslate2 {
// Output record of a decoding run. SearchStrategy::search() and decode() in
// this header both return a std::vector of these (presumably one element per
// batch entry -- confirm against the implementation).
struct DecodingResult {
// Generated sequences, each a list of size_t ids (presumably vocabulary ids).
std::vector<std::vector<size_t>> hypotheses;
// Scores; presumably parallel to `hypotheses` and only filled when scores
// are requested -- TODO confirm.
std::vector<float> scores;
// Attention values, nested three levels deep.
// NOTE(review): likely indexed [hypothesis][target step][source position] -- confirm.
std::vector<std::vector<std::vector<float>>> attention;
// Logits stored as StorageView objects, nested two levels deep (presumably
// [hypothesis][step]; only filled when requested -- TODO confirm).
std::vector<std::vector<StorageView>> logits_vocab;
};
// Per-step record passed by value to the step callback
// (std::function<bool(DecodingStepResult)>, see GreedySearch and
// DecodingOptions in this header).
//
// Every scalar member carries a default member initializer so that a
// default-constructed instance never holds indeterminate values. The type
// remains an aggregate, so existing member-wise assignment and aggregate
// initialization keep working.
struct DecodingStepResult {
  // Decoding step index (presumably counted from the start of generation).
  size_t step = 0;
  // Index of the batch entry this record belongs to.
  size_t batch_id = 0;
  // Id of the token reported by this record (presumably a vocabulary id).
  size_t token_id = 0;
  // Index of the hypothesis this record belongs to.
  size_t hypothesis_id = 0;
  // Optional score; empty by default (presumably set only when scores are requested).
  std::optional<float> score;
  // Optional logits; empty by default (presumably set only when logits are requested).
  std::optional<StorageView> logits;
  // Defaults to false; presumably set on the final record of a hypothesis -- TODO confirm.
  bool is_last = false;
};
// Abstract interface for a decoding search algorithm. BeamSearch and
// GreedySearch in this header derive from it and override search().
class SearchStrategy {
public:
// Virtual so implementations can be destroyed through a SearchStrategy pointer.
virtual ~SearchStrategy() = default;
// Runs the search with the given decoder and decoder state and returns a
// vector of DecodingResult.
//
// Parameter notes (meanings are inferred from names and types only --
// confirm against the implementations):
//   decoder, state            - decoder and its state, taken by non-const reference.
//   sampler                   - presumably selects the next token from the scores.
//   start_ids                 - presumably the initial input id of each batch entry.
//   end_ids                   - presumably the ids that terminate a hypothesis.
//   start_step                - presumably the step index at which decoding begins.
//   max_length, min_length    - presumably bounds on the generated length.
//   return_*                  - flags selecting which optional outputs are produced.
//   num_hypotheses            - presumably the number of hypotheses returned per entry.
//   include_eos_in_hypotheses - presumably whether the end id is kept in the output.
//   logits_processors         - shared LogitsProcessor instances; empty by default.
//   prefix_ids                - optional pointer; nullptr by default.
//
// The default arguments below are repeated verbatim by the overrides in
// BeamSearch and GreedySearch. C++ picks default arguments from the static
// type of the call expression, so the three declarations must stay in sync.
// NOTE(review): return_logits_vocab defaults to true here, whereas
// DecodingOptions::return_logits_vocab defaults to false -- confirm that the
// difference is intentional.
virtual std::vector<DecodingResult>
search(layers::Decoder& decoder,
layers::DecoderState& state,
const Sampler& sampler,
const std::vector<size_t>& start_ids,
const std::vector<size_t>& end_ids,
const dim_t start_step,
const dim_t max_length,
const dim_t min_length,
const bool return_scores = false,
const bool return_attention = false,
const bool return_logits_vocab = true,
const bool return_prefix = true,
const size_t num_hypotheses = 1,
const bool include_eos_in_hypotheses = true,
const std::vector<std::shared_ptr<LogitsProcessor>>& logits_processors = {},
const std::vector<std::vector<size_t>>* prefix_ids = nullptr) const = 0;
};
// SearchStrategy implementation configured by a beam size and several
// penalty/bias coefficients. Only the declaration is visible here; the
// algorithm itself lives in the implementation file.
class BeamSearch : public SearchStrategy {
public:
// Only beam_size is required; the remaining coefficients default to 0,
// except patience which defaults to 1.
// NOTE(review): no member named after `patience` exists; presumably it is
// combined with beam_size to compute _max_candidates -- confirm in the
// constructor definition.
// NOTE(review): the constructor is not `explicit`, so a dim_t converts
// implicitly to a BeamSearch -- confirm this is intended.
BeamSearch(const dim_t beam_size,
const float length_penalty = 0,
const float coverage_penalty = 0,
const float prefix_bias_beta = 0,
const float patience = 1);
// Override of SearchStrategy::search(); the parameter list and default
// arguments mirror the base declaration.
std::vector<DecodingResult>
search(layers::Decoder& decoder,
layers::DecoderState& state,
const Sampler& sampler,
const std::vector<size_t>& start_ids,
const std::vector<size_t>& end_ids,
const dim_t start_step,
const dim_t max_length,
const dim_t min_length,
const bool return_scores = false,
const bool return_attention = false,
const bool return_logits_vocab = true,
const bool return_prefix = true,
const size_t num_hypotheses = 1,
const bool include_eos_in_hypotheses = true,
const std::vector<std::shared_ptr<LogitsProcessor>>& logits_processors = {},
const std::vector<std::vector<size_t>>* prefix_ids = nullptr) const override;
private:
// Configuration is immutable after construction (all members are const).
const dim_t _beam_size;
const float _length_penalty;
const float _coverage_penalty;
const float _prefix_bias_beta;
// Presumably the candidate budget derived from beam_size and patience -- TODO confirm.
const size_t _max_candidates;
};
// Helper that produces log probabilities from logits while taking a set of
// prefix ids and a bias coefficient into account. The exact biasing rule is
// defined in the implementation file and is not visible in this header.
class BiasedDecoder {
public:
// prefix_ids is received by const reference; the _prefix_ids member is a
// value, so the object keeps its own copy of whatever it is initialized with.
BiasedDecoder(const float prefix_bias_beta,
const std::vector<std::vector<size_t>>& prefix_ids);
// Reads `logits` (const) and writes through `log_probs` (non-const
// reference, presumably the output). Not a const member function, so it may
// modify this object's state (e.g. _spare_beam).
// NOTE(review): batch_offset presumably maps the current batch positions
// back to original batch indices, and beams_diverged_from_prefix presumably
// flags, per batch entry and beam, whether the beam left the prefix --
// confirm against the implementation.
void
decode(const dim_t cur_batch_size,
const size_t step,
const std::vector<dim_t>& batch_offset,
const std::vector<std::vector<bool>>& beams_diverged_from_prefix,
const StorageView& logits,
StorageView& log_probs);
private:
// Scratch StorageView; its use is not visible from this header.
StorageView _spare_beam;
const float _prefix_bias_beta;
std::vector<std::vector<size_t>> _prefix_ids;
};
// SearchStrategy implementation for greedy decoding, with an optional
// per-step callback. Only the declaration is visible here; the algorithm
// itself lives in the implementation file.
class GreedySearch : public SearchStrategy {
public:
  // Penalties are only applied to return scores consistent with the beam search.
  //
  // `callback` is optional (nullptr by default) and receives a
  // DecodingStepResult by value. The meaning of its bool return value is not
  // visible here (presumably a request to stop decoding -- TODO confirm).
  GreedySearch(const float length_penalty = 0,
               const float coverage_penalty = 0,
               std::function<bool(DecodingStepResult)> callback = nullptr);

  // Override of SearchStrategy::search(). Parameter names and default
  // arguments mirror the base declaration so that the interface reads the same
  // whichever static type a call goes through.
  std::vector<DecodingResult>
  search(layers::Decoder& decoder,
         layers::DecoderState& state,
         const Sampler& sampler,
         const std::vector<size_t>& start_ids,
         const std::vector<size_t>& end_ids,
         const dim_t start_step,
         const dim_t max_length,
         const dim_t min_length,
         const bool return_scores = false,
         const bool return_attention = false,
         const bool return_logits_vocab = true,
         const bool return_prefix = true,
         const size_t num_hypotheses = 1,
         const bool include_eos_in_hypotheses = true,
         const std::vector<std::shared_ptr<LogitsProcessor>>& logits_processors = {},
         const std::vector<std::vector<size_t>>* prefix_ids = nullptr) const override;

private:
  // Configuration is immutable after construction (all members are const).
  const float _length_penalty;
  const float _coverage_penalty;
  const std::function<bool(DecodingStepResult)> _callback;
};
struct DecodingOptions {
size_t beam_size = 1;
float patience = 1;
float length_penalty = 0;
float coverage_penalty = 0;
float repetition_penalty = 1;
size_t no_repeat_ngram_size = 0;
float prefix_bias_beta = 0;
dim_t start_step = 0;
size_t max_length = 256;
size_t min_length = 0;
size_t sampling_topk = 1;
float sampling_topp = 1;
float sampling_temperature = 1;
size_t num_hypotheses = 1;
bool include_eos_in_hypotheses = true;
bool return_scores = false;
bool return_attention = false;
bool return_logits_vocab = false;
bool return_alternatives = false;
bool return_prefix = true;
float min_alternative_expansion_prob = 0;
std::vector<size_t> disable_ids;
std::vector<size_t> disable_ids_begin;
std::vector<std::vector<size_t>> disable_sequences;
std::vector<std::pair<std::vector<size_t>, float>> sequence_bias;
std::vector<std::shared_ptr<LogitsProcessor>> logits_processors;
std::function<bool(DecodingStepResult)> callback = nullptr;
};
// Free-function decoding entry point: takes the decoder and its state by
// non-const reference together with a DecodingOptions bundle (default
// constructed when omitted) and returns a vector of DecodingResult.
// start_tokens, end_ids and options are all taken by value, so the callee
// works on its own copies and callers may std::move arguments in.
// NOTE(review): start_tokens presumably holds one starting sequence per batch
// entry, and the search strategy is presumably chosen from the options --
// confirm against the implementation.
std::vector<DecodingResult>
decode(layers::Decoder& decoder,
layers::DecoderState& state,
std::vector<std::vector<size_t>> start_tokens,
std::vector<size_t> end_ids,
DecodingOptions options = DecodingOptions());
}