forked from PaddlePaddle/FastDeploy
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsampling_params.py
More file actions
434 lines (401 loc) · 19.3 KB
/
sampling_params.py
File metadata and controls
434 lines (401 loc) · 19.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import random
from dataclasses import dataclass, fields
from enum import Enum
from typing import Any, List, Optional, TypeVar, Union
from fastdeploy import envs
# Generic type variable for the request object accepted by
# SamplingParams.from_generic_request (any object exposing sampling attributes).
T = TypeVar("T")
@dataclass
class SamplingParams:
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion
    API (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support beam search, which is not supported by OpenAI.

    All values are validated in ``__post_init__``; a few are normalized in
    place (``seed`` is filled in, ``reasoning_max_tokens`` is clamped and a
    ``logit_bias`` mapping is cast to ``{int: float}``).

    Args:
        n: Number of output sequences to return for the given prompt.
            Must be an int >= 1.
        best_of: Number of output sequences that are generated from the prompt.
            From these `best_of` sequences, the top `n` sequences are returned.
            `best_of` must be greater than or equal to `n`. By default,
            `best_of` is set to `n`. Warning, this is only supported in V0.
        presence_penalty: Float that penalizes new tokens based on whether they
            appear in the generated text so far. Values > 0 encourage the model
            to use new tokens, while values < 0 encourage the model to repeat
            tokens. Must be in [-2, 2] when set.
        frequency_penalty: Float that penalizes new tokens based on their
            frequency in the generated text so far. Values > 0 encourage the
            model to use new tokens, while values < 0 encourage the model to
            repeat tokens. Must be in [-2, 2] when set.
        repetition_penalty: Float that penalizes new tokens based on whether
            they appear in the prompt and the generated text so far. Values > 1
            encourage the model to use new tokens, while values < 1 encourage
            the model to repeat tokens. Must be > 0 when set.
        temperature: Float that controls the randomness of the sampling. Lower
            values make the model more deterministic, while higher values make
            the model more random. Zero means greedy sampling. Must be
            non-negative when set.
        top_p: Float that controls the cumulative probability of the top tokens
            to consider. Must be in [0, 1]. Set to 1 to consider all tokens.
        top_k: Int that controls the number of top tokens to consider. 0
            disables top-k filtering (-1 is also accepted as "disabled");
            otherwise it must be a positive integer.
        min_p: Float that represents the minimum probability for a token to be
            considered, relative to the probability of the most likely token.
            Must be in [0, 1]. Set to 0 to disable this.
        sampling_threshold: Float in [0.0, 1.0). It must be a ``float``
            instance; an ``int`` such as ``0`` is rejected with TypeError.
        seed: Random seed to use for the generation. When None, it becomes 42
            if ``envs.FD_DETERMINISTIC_MODE`` is enabled, otherwise a random
            int in [0, 922337203685477580]. Must lie in that same range.
        stop: list of strings that stop the generation when they are generated.
            The returned output will not contain the stop strings.
        stop_token_ids: list of tokens that stop the generation when they are
            generated. The returned output will contain the stop tokens unless
            the stop tokens are special tokens.
        stop_seqs_len: Optional int carried alongside the stop settings; not
            validated here (presumably the length of tokenized stop
            sequences; confirm against the consumer).
        bad_words: list of words that are not allowed to be generated.
            More precisely, only the last token of a corresponding
            token sequence is not allowed when the next generated token
            can complete the sequence.
        max_tokens: Maximum number of tokens to generate per output sequence.
            Must be >= 1 when set.
        reasoning_max_tokens: Maximum number of tokens to generate for
            reasoning per output sequence. Clamped down to ``max_tokens``
            when both are set and it exceeds it.
        response_max_tokens: Maximum number of tokens to generate for response
            per output sequence. Not validated yet.
        min_tokens: Minimum number of tokens to generate per output sequence
            before EOS or stop_token_ids can be generated. Must be >= 0 and
            <= ``max_tokens`` when ``max_tokens`` is set.
        logprobs: Number of log probabilities to return per output token.
            When set to None, no probability is returned. If set to a non-None
            value, the result includes the log probabilities of the specified
            number of most likely tokens, as well as the chosen tokens.
            Note that the implementation follows the OpenAI API: The API will
            always return the log probability of the sampled token, so there
            may be up to `logprobs+1` elements in the response.
            When ``envs.FD_USE_GET_SAVE_OUTPUT_V1`` is disabled it must be in
            [0, 20]. Otherwise it must be non-negative or -1.
        prompt_logprobs: Number of log probabilities to return per prompt
            token. It is only accepted when ``envs.FD_USE_GET_SAVE_OUTPUT_V1``
            is enabled, and must then be non-negative or -1.
        temp_scaled_logprobs: Flag used in logits/logprobs post-processing
            (presumably scales logprobs by temperature; confirm).
        top_p_normalized_logprobs: Flag used in logits/logprobs
            post-processing (presumably renormalizes over the top-p set;
            confirm).
        guided_decoding: Optional ``GuidedDecodingParams`` constraint.
        bad_words_token_ids: Optional list of token ids (presumably the
            tokenized form of ``bad_words``).
        logits_processors_args: Extra arguments for logits processors. A
            ``"logit_bias"`` entry must be a dict. It is cast in place to
            ``{int: float}`` when its keys/values have other types.
    """

    n: int = 1
    best_of: Optional[int] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    repetition_penalty: Optional[float] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: int = 0
    min_p: float = 0.0
    sampling_threshold: float = 0.0
    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    stop_token_ids: Optional[List[int]] = None
    stop_seqs_len: Optional[int] = None
    max_tokens: Optional[int] = None
    reasoning_max_tokens: Optional[int] = None
    response_max_tokens: Optional[int] = None
    min_tokens: int = 1
    logprobs: Optional[int] = None
    prompt_logprobs: Optional[int] = None
    # For logits and logprobs post processing
    temp_scaled_logprobs: bool = False
    top_p_normalized_logprobs: bool = False
    bad_words: Optional[List[str]] = None
    guided_decoding: Optional[GuidedDecodingParams] = None
    bad_words_token_ids: Optional[List[int]] = None
    logits_processors_args: Optional[dict[str, Any]] = None

    # Inclusive upper bound for ``seed``; also the range for randomly drawn seeds.
    # Deliberately unannotated so it is a plain class attribute, not a dataclass field.
    _MAX_SEED = 922337203685477580

    @classmethod
    def from_dict(cls, req_dict: dict[str, Any]) -> SamplingParams:
        """Create an instance from a plain dict.

        Keys matching field names are used as-is. Unknown keys are ignored,
        and missing fields fall back to their dataclass defaults.
        """
        known = {f.name: req_dict[f.name] for f in fields(cls) if f.name in req_dict}
        return cls(**known)

    @classmethod
    def from_generic_request(cls, req: T) -> SamplingParams:
        """Create an instance from any request-like object.

        Each field is read from the attribute of the same name on ``req``.
        Attributes that are missing or ``None`` fall back to the class default.
        Two fields are special:

        * ``logprobs``: if ``req`` has a ``top_logprobs`` attribute, that value
          is used only when ``req.logprobs`` is truthy (otherwise ``None``).
          Without ``top_logprobs``, ``req.logprobs`` is used as-is.
        * ``max_tokens``: when ``req`` has ``max_completion_tokens``, a truthy
          value there wins. Otherwise ``req.max_tokens`` is used, falling back
          to the class default when the attribute is absent.
        """
        logprobs_val = None
        if hasattr(req, "top_logprobs"):
            if getattr(req, "logprobs", None):
                logprobs_val = getattr(req, "top_logprobs", None)
        else:
            logprobs_val = getattr(req, "logprobs", None)

        fallback_max_tokens = getattr(req, "max_tokens", cls.max_tokens)
        if hasattr(req, "max_completion_tokens"):
            max_tokens_val = req.max_completion_tokens or fallback_max_tokens
        else:
            max_tokens_val = fallback_max_tokens

        def _attr_or_default(name: str) -> Any:
            # An explicit ``None`` on the request is treated like a missing attribute.
            value = getattr(req, name, None)
            return value if value is not None else getattr(cls, name)

        special = {"max_tokens": max_tokens_val, "logprobs": logprobs_val}
        kwargs = {
            f.name: (special[f.name] if f.name in special else _attr_or_default(f.name))
            for f in fields(cls)
        }
        return cls(**kwargs)

    @classmethod
    def from_optional(
        cls,
        n,
        best_of,
        presence_penalty,
        frequency_penalty,
        repetition_penalty,
        temperature,
        top_p,
        top_k,
        min_p,
        sampling_threshold=None,
        seed=None,
        stop=None,
        stop_token_ids=None,
        max_tokens=None,
        reasoning_max_tokens=None,
        response_max_tokens=None,
        min_tokens=1,
        logprobs=None,
        prompt_logprobs=None,
        bad_words=None,
        guided_decoding=None,
        bad_words_token_ids=None,
        logits_processors_args=None,
    ) -> SamplingParams:
        """Create an instance from possibly-None values.

        ``None`` is replaced by a concrete default for several fields: n=1,
        presence/frequency penalty 0.0, repetition penalty 1.0, temperature
        1.0, top_k 0, min_p 0.0, sampling_threshold 0.0 and max_tokens 8192.
        Note that some of these differ from the dataclass defaults.
        """
        return cls(
            n=1 if n is None else n,
            best_of=best_of,
            presence_penalty=(presence_penalty if presence_penalty is not None else 0.0),
            frequency_penalty=(frequency_penalty if frequency_penalty is not None else 0.0),
            repetition_penalty=(repetition_penalty if repetition_penalty is not None else 1.0),
            temperature=temperature if temperature is not None else 1.0,
            top_p=top_p,
            top_k=top_k if top_k is not None else 0,
            min_p=min_p if min_p is not None else 0.0,
            sampling_threshold=sampling_threshold if sampling_threshold is not None else 0.0,
            seed=seed,
            stop=stop,
            stop_token_ids=stop_token_ids,
            max_tokens=max_tokens if max_tokens is not None else 8192,
            reasoning_max_tokens=reasoning_max_tokens,
            response_max_tokens=response_max_tokens,
            min_tokens=min_tokens,
            logprobs=logprobs,
            prompt_logprobs=prompt_logprobs,
            bad_words=bad_words,
            guided_decoding=guided_decoding,
            bad_words_token_ids=bad_words_token_ids,
            logits_processors_args=logits_processors_args,
        )

    def __post_init__(self):
        """Fill in ``seed`` when it was not given, then validate all fields."""
        if self.seed is None:
            # Deterministic mode: use fixed seed
            if envs.FD_DETERMINISTIC_MODE:
                self.seed = 42
            else:
                self.seed = random.randint(0, self._MAX_SEED)
        self._verify_args()

    def _verify_args(self) -> None:
        """Validate field values, normalizing a few of them in place.

        Raises:
            ValueError: if a value is outside its allowed range (or ``n`` is
                not an int).
            TypeError: if ``top_k``, ``sampling_threshold`` or ``logit_bias``
                has the wrong type, or ``logit_bias`` cannot be cast to
                ``{int: float}``.
        """
        if not isinstance(self.n, int):
            raise ValueError(f"n must be an int, but is of type {type(self.n)}")
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if self.presence_penalty is not None and (not -2.0 <= self.presence_penalty <= 2.0):
            raise ValueError("presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}.")
        if self.frequency_penalty is not None and (not -2.0 <= self.frequency_penalty <= 2.0):
            raise ValueError("frequency_penalty must be in [-2, 2], got " f"{self.frequency_penalty}.")
        if self.repetition_penalty is not None and self.repetition_penalty <= 0.0:
            raise ValueError("repetition_penalty must be greater than zero, got " f"{self.repetition_penalty}.")
        if self.temperature is not None and self.temperature < 0.0:
            raise ValueError(f"temperature must be non-negative, got {self.temperature}.")
        if self.top_p is not None and not 0.0 <= self.top_p <= 1.0:
            raise ValueError(f"top_p must be in [0, 1], got {self.top_p}.")
        # Check the type before the range so a non-int top_k gets a clear
        # TypeError instead of failing (or passing) the numeric comparison.
        if not isinstance(self.top_k, int):
            raise TypeError(f"top_k must be an integer, got {type(self.top_k).__name__}")
        # quietly accept -1 as disabled, but prefer 0
        if self.top_k < -1:
            raise ValueError(f"top_k must be 0 (disable), or at least 1, " f"got {self.top_k}.")
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError(f"min_p must be in [0,1], got {self.min_p}.")
        if not isinstance(self.sampling_threshold, float):
            raise TypeError(f"sampling_threshold must be a float, got {type(self.sampling_threshold).__name__}")
        if not 0.0 <= self.sampling_threshold < 1.0:
            raise ValueError(f"sampling_threshold must be in [0.0, 1.0), got {self.sampling_threshold}.")
        if self.max_tokens is not None and self.max_tokens < 1:
            raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.")
        # Clamp the reasoning budget to the overall budget. Skip when max_tokens
        # is unset (its default), since comparing an int with None raises TypeError.
        if (
            self.reasoning_max_tokens is not None
            and self.max_tokens is not None
            and self.reasoning_max_tokens > self.max_tokens
        ):
            self.reasoning_max_tokens = self.max_tokens
        # response_max_tokens TODO: not validated yet.
        if self.min_tokens < 0:
            raise ValueError(f"min_tokens must be greater than or equal to 0, " f"got {self.min_tokens}.")
        if self.max_tokens is not None and self.min_tokens > self.max_tokens:
            raise ValueError(
                f"min_tokens must be less than or equal to " f"max_tokens={self.max_tokens}, got {self.min_tokens}."
            )
        if not envs.FD_USE_GET_SAVE_OUTPUT_V1:
            # Flag disabled: logprobs limited to [0, 20]; prompt_logprobs unsupported.
            if self.logprobs is not None and (self.logprobs < 0 or self.logprobs > 20):
                raise ValueError("Invalid value for 'top_logprobs': must be between 0 and 20.")
            if self.prompt_logprobs is not None:
                raise ValueError("prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled.")
        else:
            # Flag enabled: -1 is accepted in addition to non-negative values.
            if self.logprobs is not None and self.logprobs < -1:
                raise ValueError(f"logprobs must be a non-negative value or -1, got {self.logprobs}.")
            if self.prompt_logprobs is not None and self.prompt_logprobs < -1:
                raise ValueError(f"prompt_logprobs must be a non-negative value or -1, got {self.prompt_logprobs}.")
        if not 0 <= self.seed <= self._MAX_SEED:
            raise ValueError(f"seed must be in [0, {self._MAX_SEED}], got {self.seed}.")
        # Verify logits processors arguments
        if self.logits_processors_args is not None:
            logit_bias = self.logits_processors_args.get("logit_bias")
            if logit_bias is not None:
                if not isinstance(logit_bias, dict):
                    raise TypeError(f"logit_bias must be a dict, but got {type(logit_bias)}")
                if not all(isinstance(k, int) and isinstance(v, float) for k, v in logit_bias.items()):
                    # try to cast the dict to the correct type first
                    try:
                        cast_logit_bias = {int(k): float(v) for k, v in logit_bias.items()}
                    except (TypeError, ValueError, OverflowError) as err:
                        raise TypeError(
                            "failed to cast logit_bias to the correct {key -> value} type, expected {int -> float}"
                        ) from err
                    self.logits_processors_args["logit_bias"] = cast_logit_bias
@dataclass()
class BeamSearchParams:
    """Settings for beam-search text generation.

    ``beam_width`` and ``max_tokens`` are required. Every other field has the
    default shown below. No validation is performed on any value.
    """

    # Number of beams (presumably kept per decoding step; not validated here).
    beam_width: int
    # Token budget for the generation; not validated here.
    max_tokens: int
    ignore_eos: bool = False
    temperature: float = 0.0
    length_penalty: float = 1.0
    include_stop_str_in_output: bool = False
@dataclass
class GuidedDecodingParams:
    """Constraint for guided (structured) decoding.

    At most one of the constraint fields may be set. Any field that is not
    ``None`` counts as set, so ``json_object=False`` is a set field.
    """

    json: Optional[Union[str, dict]] = None
    regex: Optional[str] = None
    choice: Optional[List[str]] = None
    grammar: Optional[str] = None
    json_object: Optional[bool] = None
    structural_tag: Optional[str] = None

    def to_dict(self):
        """Return the set constraint(s) keyed by their request-side names.

        Fields whose value is ``None`` are left out. Note that
        ``structural_tag`` keeps its name without a ``guided_`` prefix.
        """
        named_values = (
            ("guided_json", self.json),
            ("guided_regex", self.regex),
            ("guided_choice", self.choice),
            ("guided_grammar", self.grammar),
            ("structural_tag", self.structural_tag),
            ("guided_json_object", self.json_object),
        )
        return {key: val for key, val in named_values if val is not None}

    def __post_init__(self):
        """Reject an instance that sets more than one kind of constraint."""
        candidates = (
            self.json,
            self.regex,
            self.choice,
            self.grammar,
            self.json_object,
            self.structural_tag,
        )
        active = sum(1 for candidate in candidates if candidate is not None)
        if active <= 1:
            return
        raise ValueError(
            "You can only use one kind of guided decoding "
            "('json', 'json_object', 'regex', 'choice', 'grammar', 'structural_tag')."
        )
class RequestOutputKind(Enum):
    """How much output each intermediate RequestOutput carries."""

    CUMULATIVE = 0  # each RequestOutput holds the entire output so far
    DELTA = 1  # each RequestOutput holds only the new delta
    FINAL_ONLY = 2  # no intermediate RequestOutput is returned