Skip to content

Commit 27769f8

Browse files
Strict data type (#1794)
* strict data type Signed-off-by: ZePan110 <ze.pan@intel.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * strict data type Signed-off-by: ZePan110 <ze.pan@intel.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: ZePan110 <ze.pan@intel.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6ba45d3 commit 27769f8

2 files changed

Lines changed: 66 additions & 64 deletions

File tree

comps/cores/proto/api_protocol.py

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import shortuuid
99
from fastapi import File, Form, UploadFile
1010
from fastapi.responses import JSONResponse
11-
from pydantic import BaseModel, Field
11+
from pydantic import BaseModel, Field, NonNegativeFloat, PositiveInt
1212

1313

1414
class ServiceCard(BaseModel):
@@ -219,11 +219,11 @@ class RetrievalRequest(BaseModel):
219219
embedding: Union[EmbeddingResponse, List[float]] = None
220220
input: Optional[str] = None # search_type maybe need, like "mmr"
221221
search_type: str = "similarity"
222-
k: int = 4
222+
k: PositiveInt = 4
223223
distance_threshold: Optional[float] = None
224-
fetch_k: int = 20
225-
lambda_mult: float = 0.5
226-
score_threshold: float = 0.2
224+
fetch_k: PositiveInt = 20
225+
lambda_mult: NonNegativeFloat = 0.5
226+
score_threshold: NonNegativeFloat = 0.2
227227

228228
# define
229229
request_type: Literal["retrieval"] = "retrieval"
@@ -256,7 +256,7 @@ class RetrievalResponse(BaseModel):
256256
class RerankingRequest(BaseModel):
257257
input: str
258258
retrieved_docs: Union[List[RetrievalResponseData], List[Dict[str, Any]], List[str]]
259-
top_n: int = 1
259+
top_n: PositiveInt = 1
260260

261261
# define
262262
request_type: Literal["reranking"] = "reranking"
@@ -285,17 +285,19 @@ class ChatCompletionRequest(BaseModel):
285285
logit_bias: Optional[Dict[str, float]] = None
286286
logprobs: Optional[bool] = False
287287
top_logprobs: Optional[int] = 0
288-
max_tokens: Optional[int] = 1024 # use https://platform.openai.com/docs/api-reference/completions/create
289-
n: Optional[int] = 1
288+
max_tokens: Optional[PositiveInt] = 1024 # use https://platform.openai.com/docs/api-reference/completions/create
289+
n: Optional[PositiveInt] = 1
290290
presence_penalty: Optional[float] = 0.0
291291
response_format: Optional[ResponseFormat] = None
292-
seed: Optional[int] = None
292+
seed: Optional[PositiveInt] = None
293293
service_tier: Optional[str] = None
294294
stop: Union[str, List[str], None] = Field(default_factory=list)
295295
stream: Optional[bool] = False
296296
stream_options: Optional[StreamOptions] = Field(default=None)
297-
temperature: Optional[float] = 0.01 # vllm default 0.7
298-
top_p: Optional[float] = None # openai default 1.0, but tgi needs `top_p` must be > 0.0 and < 1.0, set None
297+
temperature: Optional[NonNegativeFloat] = 0.01 # vllm default 0.7
298+
top_p: Optional[NonNegativeFloat] = (
299+
None # openai default 1.0, but tgi needs `top_p` must be > 0.0 and < 1.0, set None
300+
)
299301
tools: Optional[List[ChatCompletionToolsParam]] = None
300302
tool_choice: Optional[Union[Literal["none"], ChatCompletionNamedToolChoiceParam]] = "none"
301303
parallel_tool_calls: Optional[bool] = True
@@ -307,22 +309,22 @@ class ChatCompletionRequest(BaseModel):
307309
# Ordered by official OpenAI API documentation
308310
# default values are same with
309311
# https://platform.openai.com/docs/api-reference/completions/create
310-
best_of: Optional[int] = 1
312+
best_of: Optional[PositiveInt] = 1
311313
suffix: Optional[str] = None
312314

313315
# vllm reference: https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/protocol.py#L130
314-
repetition_penalty: Optional[float] = 1.0
316+
repetition_penalty: Optional[NonNegativeFloat] = 1.0
315317

316318
# tgi reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/generate
317319
# some tgi parameters in use
318320
# default values are same with
319321
# https://github.com/huggingface/text-generation-inference/blob/main/router/src/lib.rs#L190
320322
# max_new_tokens: Optional[int] = 100 # Priority use openai
321-
top_k: Optional[int] = None
323+
top_k: Optional[PositiveInt] = None
322324
# top_p: Optional[float] = None # Priority use openai
323325
typical_p: Optional[float] = None
324326
# repetition_penalty: Optional[float] = None
325-
timeout: Optional[int] = None
327+
timeout: Optional[PositiveInt] = None
326328

327329
# doc: begin-chat-completion-extra-params
328330
echo: Optional[bool] = Field(
@@ -382,16 +384,16 @@ class ChatCompletionRequest(BaseModel):
382384

383385
# retrieval
384386
search_type: str = "similarity"
385-
k: int = 4
387+
k: PositiveInt = 4
386388
distance_threshold: Optional[float] = None
387-
fetch_k: int = 20
388-
lambda_mult: float = 0.5
389-
score_threshold: float = 0.2
389+
fetch_k: PositiveInt = 20
390+
lambda_mult: NonNegativeFloat = 0.5
391+
score_threshold: NonNegativeFloat = 0.2
390392
retrieved_docs: Union[List[RetrievalResponseData], List[Dict[str, Any]]] = Field(default_factory=list)
391393
index_name: Optional[str] = None
392394

393395
# reranking
394-
top_n: int = 1
396+
top_n: PositiveInt = 1
395397
reranked_docs: Union[List[RerankingResponseData], List[Dict[str, Any]]] = Field(default_factory=list)
396398

397399
# define
@@ -416,16 +418,16 @@ class AudioChatCompletionRequest(BaseModel):
416418
]
417419
] = None
418420
model: Optional[str] = "Intel/neural-chat-7b-v3-3"
419-
temperature: Optional[float] = 0.01
420-
top_p: Optional[float] = 0.95
421-
top_k: Optional[int] = 10
422-
n: Optional[int] = 1
423-
max_tokens: Optional[int] = 1024
421+
temperature: Optional[NonNegativeFloat] = 0.01
422+
top_p: Optional[NonNegativeFloat] = 0.95
423+
top_k: Optional[PositiveInt] = 10
424+
n: Optional[PositiveInt] = 1
425+
max_tokens: Optional[PositiveInt] = 1024
424426
stop: Optional[Union[str, List[str]]] = None
425427
stream: Optional[bool] = False
426-
presence_penalty: Optional[float] = 0.0
427-
frequency_penalty: Optional[float] = 0.0
428-
repetition_penalty: Optional[float] = 1.03
428+
presence_penalty: Optional[NonNegativeFloat] = 0.0
429+
frequency_penalty: Optional[NonNegativeFloat] = 0.0
430+
repetition_penalty: Optional[NonNegativeFloat] = 1.03
429431
user: Optional[str] = None
430432

431433

@@ -458,7 +460,7 @@ class AudioSpeechRequest(BaseModel):
458460
model: Optional[str] = "microsoft/speecht5_tts"
459461
voice: Optional[str] = "default"
460462
response_format: Optional[str] = "mp3"
461-
speed: Optional[float] = 1.0
463+
speed: Optional[NonNegativeFloat] = 1.0
462464

463465

464466
class ChatMessage(BaseModel):
@@ -506,18 +508,18 @@ class CompletionRequest(BaseModel):
506508
model: str
507509
prompt: Union[str, List[Any]]
508510
suffix: Optional[str] = None
509-
temperature: Optional[float] = 0.7
510-
n: Optional[int] = 1
511-
max_tokens: Optional[int] = 16
511+
temperature: Optional[NonNegativeFloat] = 0.7
512+
n: Optional[PositiveInt] = 1
513+
max_tokens: Optional[PositiveInt] = 16
512514
stop: Optional[Union[str, List[str]]] = None
513515
stream: Optional[bool] = False
514-
top_p: Optional[float] = 1.0
516+
top_p: Optional[NonNegativeFloat] = 1.0
515517
top_k: Optional[int] = -1
516518
logprobs: Optional[int] = None
517519
echo: Optional[bool] = False
518-
presence_penalty: Optional[float] = 0.0
519-
frequency_penalty: Optional[float] = 0.0
520-
repetition_penalty: Optional[float] = 1.03
520+
presence_penalty: Optional[NonNegativeFloat] = 0.0
521+
frequency_penalty: Optional[NonNegativeFloat] = 0.0
522+
repetition_penalty: Optional[NonNegativeFloat] = 1.03
521523
user: Optional[str] = None
522524
use_beam_search: Optional[bool] = False
523525
best_of: Optional[int] = None
@@ -915,7 +917,7 @@ class FineTuningJobListRequest(BaseModel):
915917
after: Optional[str] = None
916918
"""Identifier for the last job from the previous pagination request."""
917919

918-
limit: Optional[int] = 20
920+
limit: Optional[PositiveInt] = 20
919921
"""Number of fine-tuning jobs to retrieve."""
920922

921923

comps/cores/proto/docarray.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from docarray import BaseDoc, DocList
88
from docarray.documents import AudioDoc
99
from docarray.typing import AudioUrl, ImageUrl
10-
from pydantic import Field, conint, conlist, field_validator
10+
from pydantic import Field, NonNegativeFloat, PositiveInt, conint, conlist, field_validator
1111

1212

1313
class TopologyInfo:
@@ -96,11 +96,11 @@ class EmbedDoc(BaseDoc):
9696
text: Union[str, List[str]]
9797
embedding: Union[conlist(float, min_length=0), List[conlist(float, min_length=0)]]
9898
search_type: str = "similarity"
99-
k: int = 4
99+
k: PositiveInt = 4
100100
distance_threshold: Optional[float] = None
101-
fetch_k: int = 20
102-
lambda_mult: float = 0.5
103-
score_threshold: float = 0.2
101+
fetch_k: PositiveInt = 20
102+
lambda_mult: NonNegativeFloat = 0.5
103+
score_threshold: NonNegativeFloat = 0.2
104104
constraints: Optional[Union[Dict[str, Any], List[Dict[str, Any]], None]] = None
105105
index_name: Optional[str] = None
106106

@@ -135,7 +135,7 @@ class Audio2TextDoc(AudioDoc):
135135
class SearchedDoc(BaseDoc):
136136
retrieved_docs: DocList[TextDoc]
137137
initial_query: str
138-
top_n: int = 1
138+
top_n: PositiveInt = 1
139139

140140
class Config:
141141
json_encoders = {np.ndarray: lambda x: x.tolist()}
@@ -177,14 +177,14 @@ class LLMParamsDoc(BaseDoc):
177177
model: Optional[str] = None # for openai and ollama
178178
query: str
179179
max_tokens: int = 1024
180-
max_new_tokens: int = 1024
181-
top_k: int = 10
182-
top_p: float = 0.95
183-
typical_p: float = 0.95
184-
temperature: float = 0.01
185-
frequency_penalty: float = 0.0
186-
presence_penalty: float = 0.0
187-
repetition_penalty: float = 1.03
180+
max_new_tokens: PositiveInt = 1024
181+
top_k: PositiveInt = 10
182+
top_p: NonNegativeFloat = 0.95
183+
typical_p: NonNegativeFloat = 0.95
184+
temperature: NonNegativeFloat = 0.01
185+
frequency_penalty: NonNegativeFloat = 0.0
186+
presence_penalty: NonNegativeFloat = 0.0
187+
repetition_penalty: NonNegativeFloat = 1.03
188188
stream: bool = True
189189
language: str = "auto" # can be "en", "zh"
190190

@@ -216,14 +216,14 @@ def chat_template_must_contain_variables(cls, v):
216216
class LLMParams(BaseDoc):
217217
model: Optional[str] = None
218218
max_tokens: int = 1024
219-
max_new_tokens: int = 1024
220-
top_k: int = 10
221-
top_p: float = 0.95
222-
typical_p: float = 0.95
223-
temperature: float = 0.01
224-
frequency_penalty: float = 0.0
225-
presence_penalty: float = 0.0
226-
repetition_penalty: float = 1.03
219+
max_new_tokens: PositiveInt = 1024
220+
top_k: PositiveInt = 10
221+
top_p: NonNegativeFloat = 0.95
222+
typical_p: NonNegativeFloat = 0.95
223+
temperature: NonNegativeFloat = 0.01
224+
frequency_penalty: NonNegativeFloat = 0.0
225+
presence_penalty: NonNegativeFloat = 0.0
226+
repetition_penalty: NonNegativeFloat = 1.03
227227
stream: bool = True
228228
language: str = "auto" # can be "en", "zh"
229229
index_name: Optional[str] = None
@@ -241,15 +241,15 @@ class LLMParams(BaseDoc):
241241

242242
class RetrieverParms(BaseDoc):
243243
search_type: str = "similarity"
244-
k: int = 4
244+
k: PositiveInt = 4
245245
distance_threshold: Optional[float] = None
246-
fetch_k: int = 20
247-
lambda_mult: float = 0.5
248-
score_threshold: float = 0.2
246+
fetch_k: PositiveInt = 20
247+
lambda_mult: NonNegativeFloat = 0.5
248+
score_threshold: NonNegativeFloat = 0.2
249249

250250

251251
class RerankerParms(BaseDoc):
252-
top_n: int = 1
252+
top_n: PositiveInt = 1
253253

254254

255255
class RAGASParams(BaseDoc):

0 commit comments

Comments
 (0)