-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathopenapi.py
More file actions
158 lines (111 loc) · 4.12 KB
/
Copy pathopenapi.py
File metadata and controls
158 lines (111 loc) · 4.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel
from cozeloop.internal.httpclient import Client, BaseResponse
# Endpoint path that OpenAPIClient POSTs MPullPromptRequest bodies to in order
# to fetch several prompts in one call.
MPULL_PROMPT_PATH = "/v1/loop/prompts/mget"
# Maximum number of PromptQuery items sent per request; OpenAPIClient.mpull_prompt
# splits longer query lists into batches of at most this size.
# NOTE(review): presumably mirrors a server-side limit — confirm before changing.
MAX_PROMPT_QUERY_BATCH_SIZE = 25
class TemplateType(str, Enum):
    """Syntax of a prompt template (the type of ``PromptTemplate.template_type``).

    Mixing in ``str`` makes every member a real string that compares equal to
    its raw value (e.g. ``TemplateType.NORMAL == "normal"``).
    """
    NORMAL = "normal"
    # NOTE(review): presumably templates rendered with Jinja2 syntax — confirm
    # against the renderer.
    JINJA2 = "jinja2"
class Role(str, Enum):
    """Role attached to a template message (the type of ``Message.role``).

    Members are ``str`` instances equal to their wire values.
    """
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"
    # NOTE(review): presumably marks a slot to be expanded from a placeholder
    # variable (cf. VariableType.PLACEHOLDER) rather than a real speaker — confirm.
    PLACEHOLDER = "placeholder"
class ToolType(str, Enum):
    """Kind of a tool (the type of ``Tool.type``); only functions are defined."""
    FUNCTION = "function"
class VariableType(str, Enum):
    """Declared type of a template variable (the type of ``VariableDef.type``).

    Members are ``str`` instances equal to the exact wire strings below.
    NOTE(review): the ``array<T>`` values presumably denote lists whose
    elements are of scalar type T — confirm against the server schema.
    """
    STRING = "string"
    # NOTE(review): presumably a variable that expands into messages
    # (cf. Role.PLACEHOLDER) — confirm.
    PLACEHOLDER = "placeholder"
    BOOLEAN = "boolean"
    INTEGER = "integer"
    FLOAT = "float"
    OBJECT = "object"
    ARRAY_STRING = "array<string>"
    ARRAY_BOOLEAN = "array<boolean>"
    ARRAY_INTEGER = "array<integer>"
    ARRAY_FLOAT = "array<float>"
    ARRAY_OBJECT = "array<object>"
class ToolChoiceType(str, Enum):
    """Tool-selection mode (the type of ``ToolCallConfig.tool_choice``)."""
    # NOTE(review): presumably lets the model decide whether to call a tool — confirm.
    AUTO = "auto"
    # NOTE(review): presumably disables tool calls — confirm.
    NONE = "none"
class Message(BaseModel):
    """One message of a prompt template: a role and optional text content."""
    role: Role
    # Optional; defaults to None when the message carries no text.
    content: Optional[str] = None
class VariableDef(BaseModel):
    """Declaration of a variable used by a prompt template.

    All three fields are required.
    """
    # Variable name.
    # NOTE(review): presumably the name referenced from message content — confirm.
    key: str
    # Description text; required (not Optional) in this model.
    desc: str
    # Field name matches the wire format, so it intentionally shadows the builtin.
    type: VariableType
class Function(BaseModel):
    """Function definition carried by a ``Tool``."""
    name: str
    description: Optional[str] = None
    # Kept as a raw string and not parsed here.
    # NOTE(review): presumably a JSON-Schema document serialized to text — confirm.
    parameters: Optional[str] = None
class Tool(BaseModel):
    """A tool declared on a ``Prompt``; ``function`` carries its definition when set."""
    type: ToolType
    function: Optional[Function] = None
class ToolCallConfig(BaseModel):
    """Tool-calling configuration of a ``Prompt``; ``tool_choice`` is required."""
    tool_choice: ToolChoiceType
class LLMConfig(BaseModel):
    """Model invocation parameters attached to a ``Prompt``.

    Every field is optional and defaults to None.
    NOTE(review): None presumably means "not configured; use the model's own
    default" — confirm with the consumer of this config.
    """
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    # NOTE(review): presumably requests JSON-formatted model output — confirm.
    json_mode: Optional[bool] = None
class PromptTemplate(BaseModel):
    """Template part of a ``Prompt``: its syntax, messages and variable declarations."""
    # Required; selects the template syntax.
    template_type: TemplateType
    # Both lists are optional and default to None (not an empty list).
    messages: Optional[List[Message]] = None
    variable_defs: Optional[List[VariableDef]] = None
class Prompt(BaseModel):
    """A prompt as carried in ``PromptResult.prompt`` of an MPull response.

    The identity fields (workspace, key, version) are required; the template,
    tools, tool-call config and LLM config are optional and default to None.
    """
    workspace_id: str
    prompt_key: str
    version: str
    prompt_template: Optional[PromptTemplate] = None
    tools: Optional[List[Tool]] = None
    tool_call_config: Optional[ToolCallConfig] = None
    llm_config: Optional[LLMConfig] = None
class PromptQuery(BaseModel):
    """One lookup in an MPull request, identified by ``prompt_key``.

    ``version`` and ``label`` are both optional and default to None.
    NOTE(review): how the server resolves a query that sets neither (or both)
    is not visible here — presumably neither means "latest"; confirm.
    """
    prompt_key: str
    version: Optional[str] = None
    label: Optional[str] = None
class MPullPromptRequest(BaseModel):
    """Request body that ``OpenAPIClient`` POSTs to ``MPULL_PROMPT_PATH``."""
    workspace_id: str
    # OpenAPIClient.mpull_prompt sends at most MAX_PROMPT_QUERY_BATCH_SIZE
    # queries per request.
    queries: List[PromptQuery]
class PromptResult(BaseModel):
    """Result for one ``PromptQuery``: the query it answers and the matched prompt."""
    query: PromptQuery
    # NOTE(review): presumably None when no prompt matched the query — confirm.
    prompt: Optional[Prompt] = None
class PromptResultData(BaseModel):
    """Payload of an MPull response, holding the list of ``PromptResult`` items."""
    # May be None; OpenAPIClient treats a None list as "no results".
    items: Optional[List[PromptResult]] = None
class MPullPromptResponse(BaseResponse):
    """MPull response: the fields inherited from ``BaseResponse`` plus an optional payload."""
    data: Optional[PromptResultData] = None
class OpenAPIClient:
    """Client that fetches prompts in batches through ``MPULL_PROMPT_PATH``."""

    def __init__(self, http_client: Client):
        """Wrap *http_client*, which performs the actual POST requests."""
        self.http_client = http_client

    def mpull_prompt(self, workspace_id: str, queries: List[PromptQuery]) -> List[PromptResult]:
        """Fetch the prompts matching *queries* within *workspace_id*.

        Queries are sorted by ``(prompt_key, version)`` and sent in batches of
        at most ``MAX_PROMPT_QUERY_BATCH_SIZE`` per request. The returned
        results therefore follow that sorted order, not the order of *queries*.

        Args:
            workspace_id: Workspace the prompts belong to.
            queries: Prompt lookups to perform; may be empty.

        Returns:
            The concatenated result items of every batch. The list is empty
            when *queries* is empty or no batch returned any items.
        """
        # ``version`` is Optional. Mapping None to "" keeps two queries that
        # share a prompt_key from comparing None with a str, which raises
        # TypeError in Python 3.
        sorted_queries = sorted(queries, key=lambda q: (q.prompt_key, q.version or ""))
        all_prompts: List[PromptResult] = []
        # One loop handles both cases: a short list is a single iteration, and
        # an empty list performs no requests at all.
        for start in range(0, len(sorted_queries), MAX_PROMPT_QUERY_BATCH_SIZE):
            batch_queries = sorted_queries[start:start + MAX_PROMPT_QUERY_BATCH_SIZE]
            batch_results = self._do_mpull_prompt(workspace_id, batch_queries)
            if batch_results is not None:
                all_prompts.extend(batch_results)
        return all_prompts

    def _do_mpull_prompt(self, workspace_id: str, queries: List[PromptQuery]) -> Optional[List[PromptResult]]:
        """Send one MPull request for *queries* and return its result items.

        Returns:
            None when *queries* is empty (no request is sent) or when the
            response has no ``data``; otherwise ``data.items``, which may
            itself be None.
        """
        if not queries:
            return None
        request = MPullPromptRequest(workspace_id=workspace_id, queries=queries)
        response = self.http_client.post(MPULL_PROMPT_PATH, MPullPromptResponse, request)
        # NOTE(review): ``post`` already receives the response type. Validating
        # again keeps this correct whether it returns a model instance or raw
        # data; confirm which, and drop this step if it is redundant.
        real_resp = MPullPromptResponse.model_validate(response)
        if real_resp.data is None:
            return None
        return real_resp.data.items