Skip to content

Commit f234c9f

Browse files
authored
Merge pull request #36 from Serverless-Devs/feat-embedding-api
refactor(model): integrate ModelAPI into model proxy and service classes
2 parents 36550e9 + 70876de commit f234c9f

File tree

7 files changed

+324
-159
lines changed

7 files changed

+324
-159
lines changed

agentrun/model/__model_proxy_async_template.py

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pydash
1010

1111
from agentrun.model.api.data import BaseInfo, ModelDataAPI
12+
from agentrun.model.api.model_api import ModelAPI
1213
from agentrun.utils.config import Config
1314
from agentrun.utils.model import Status
1415
from agentrun.utils.resource import ResourceBase
@@ -30,6 +31,7 @@ class ModelProxy(
3031
ModelProxyImmutableProps,
3132
ModelProxyMutableProps,
3233
ModelProxySystemProps,
34+
ModelAPI,
3335
ResourceBase,
3436
):
3537
"""模型服务"""
@@ -230,41 +232,3 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo:
230232
)
231233

232234
return self._data_client.model_info()
233-
234-
def completions(
235-
self,
236-
messages: list,
237-
model: Optional[str] = None,
238-
stream: bool = False,
239-
config: Optional[Config] = None,
240-
**kwargs,
241-
):
242-
self.model_info(config)
243-
assert self._data_client
244-
245-
return self._data_client.completions(
246-
**kwargs,
247-
messages=messages,
248-
model=model,
249-
stream=stream,
250-
config=config,
251-
)
252-
253-
def responses(
254-
self,
255-
messages: list,
256-
model: Optional[str] = None,
257-
stream: bool = False,
258-
config: Optional[Config] = None,
259-
**kwargs,
260-
):
261-
self.model_info(config)
262-
assert self._data_client
263-
264-
return self._data_client.responses(
265-
**kwargs,
266-
messages=messages,
267-
model=model,
268-
stream=stream,
269-
config=config,
270-
)

agentrun/model/__model_service_async_template.py

Lines changed: 3 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
from typing import List, Optional
88

9-
from agentrun.model.api.data import BaseInfo, ModelCompletionAPI
9+
from agentrun.model.api.data import BaseInfo
10+
from agentrun.model.api.model_api import ModelAPI
1011
from agentrun.utils.config import Config
1112
from agentrun.utils.model import PageableInput
1213
from agentrun.utils.resource import ResourceBase
@@ -27,6 +28,7 @@ class ModelService(
2728
ModelServiceImmutableProps,
2829
ModelServiceMutableProps,
2930
ModelServicesSystemProps,
31+
ModelAPI,
3032
ResourceBase,
3133
):
3234
"""模型服务"""
@@ -230,38 +232,3 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo:
230232
model=default_model,
231233
headers=cfg.get_headers(),
232234
)
233-
234-
def completions(
235-
self,
236-
messages: list,
237-
model: Optional[str] = None,
238-
stream: bool = False,
239-
**kwargs,
240-
):
241-
info = self.model_info(config=kwargs.get("config"))
242-
243-
m = ModelCompletionAPI(
244-
api_key=info.api_key or "",
245-
base_url=info.base_url or "",
246-
model=model or info.model or self.model_service_name or "",
247-
)
248-
249-
return m.completions(**kwargs, messages=messages, stream=stream)
250-
251-
def responses(
252-
self,
253-
messages: list,
254-
model: Optional[str] = None,
255-
stream: bool = False,
256-
**kwargs,
257-
):
258-
info = self.model_info(config=kwargs.get("config"))
259-
260-
m = ModelCompletionAPI(
261-
api_key=info.api_key or "",
262-
base_url=info.base_url or "",
263-
model=model or info.model or self.model_service_name or "",
264-
provider=(self.provider or "openai").lower(),
265-
)
266-
267-
return m.responses(**kwargs, messages=messages, stream=stream)

agentrun/model/api/model_api.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Optional, TYPE_CHECKING, Union
3+
4+
from .data import BaseInfo
5+
6+
if TYPE_CHECKING:
7+
from litellm import ResponseInputParam
8+
9+
10+
class ModelAPI(ABC):
    """Mixin exposing LLM invocation helpers backed by litellm.

    Subclasses supply connection details (api key, base url, default model,
    provider) via :meth:`model_info`; each helper lazily imports litellm so
    the dependency is only required when a call is actually made.
    """

    @abstractmethod
    def model_info(self) -> BaseInfo:
        """Return credentials and defaults for the underlying model service."""
        ...

    def completions(
        self,
        **kwargs,
    ):
        """
        Deprecated. Use completion() instead.
        """
        import warnings

        warnings.warn(
            "completions() is deprecated, use completion() instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.completion(**kwargs)

    def completion(
        self,
        messages=None,
        model: Optional[str] = None,
        custom_llm_provider: Optional[str] = None,
        **kwargs,
    ):
        """Synchronous chat completion.

        Args:
            messages: Chat messages; defaults to an empty list.
            model: Overrides the model name from :meth:`model_info`.
            custom_llm_provider: Overrides the provider; falls back to
                ``model_info().provider`` then ``"openai"``.
            **kwargs: Forwarded verbatim to ``litellm.completion``.
        """
        from litellm import completion

        info = self.model_info()
        return completion(
            **kwargs,
            api_key=info.api_key,
            base_url=info.base_url,
            model=model or info.model or "",
            custom_llm_provider=custom_llm_provider
            or info.provider
            or "openai",
            # Avoid a shared mutable default argument.
            messages=[] if messages is None else messages,
        )

    async def acompletion(
        self,
        messages=None,
        model: Optional[str] = None,
        custom_llm_provider: Optional[str] = None,
        **kwargs,
    ):
        """Async chat completion. See :meth:`completion` for parameters."""
        from litellm import acompletion

        info = self.model_info()
        return await acompletion(
            **kwargs,
            api_key=info.api_key,
            base_url=info.base_url,
            model=model or info.model or "",
            custom_llm_provider=custom_llm_provider
            or info.provider
            or "openai",
            messages=[] if messages is None else messages,
        )

    def responses(
        self,
        input: Union[str, "ResponseInputParam"],
        model: Optional[str] = None,
        custom_llm_provider: Optional[str] = None,
        **kwargs,
    ):
        """Synchronous Responses-API call.

        Args:
            input: Prompt string or structured response input.
            model: Overrides the model name from :meth:`model_info`.
            custom_llm_provider: Overrides the provider (default "openai").
            **kwargs: Forwarded verbatim to ``litellm.responses``.
        """
        from litellm import responses

        info = self.model_info()
        return responses(
            **kwargs,
            api_key=info.api_key,
            base_url=info.base_url,
            model=model or info.model or "",
            custom_llm_provider=custom_llm_provider
            or info.provider
            or "openai",
            input=input,
        )

    async def aresponses(
        self,
        input: Union[str, "ResponseInputParam"],
        model: Optional[str] = None,
        custom_llm_provider: Optional[str] = None,
        **kwargs,
    ):
        """Async Responses-API call. See :meth:`responses` for parameters."""
        from litellm import aresponses

        info = self.model_info()
        return await aresponses(
            **kwargs,
            api_key=info.api_key,
            base_url=info.base_url,
            model=model or info.model or "",
            custom_llm_provider=custom_llm_provider
            or info.provider
            or "openai",
            input=input,
        )

    def embedding(
        self,
        input=None,
        model: Optional[str] = None,
        custom_llm_provider: Optional[str] = None,
        **kwargs,
    ):
        """Synchronous embedding call.

        Args:
            input: Text(s) to embed; defaults to an empty list.
            model: Overrides the model name from :meth:`model_info`.
            custom_llm_provider: Overrides the provider (default "openai").
            **kwargs: Forwarded verbatim to ``litellm.embedding``.
        """
        from litellm import embedding

        info = self.model_info()
        return embedding(
            **kwargs,
            api_key=info.api_key,
            # litellm.embedding takes ``api_base`` rather than ``base_url``.
            api_base=info.base_url,
            model=model or info.model or "",
            custom_llm_provider=custom_llm_provider
            or info.provider
            or "openai",
            input=[] if input is None else input,
        )

    async def aembedding(
        self,
        input=None,
        model: Optional[str] = None,
        custom_llm_provider: Optional[str] = None,
        **kwargs,
    ):
        """Async embedding call. See :meth:`embedding` for parameters.

        Declared ``async def`` for consistency with :meth:`acompletion` /
        :meth:`aresponses` (the original returned the un-awaited coroutine
        from a sync ``def``; callers still simply ``await`` the result).
        """
        from litellm import aembedding

        info = self.model_info()
        return await aembedding(
            **kwargs,
            api_key=info.api_key,
            api_base=info.base_url,
            model=model or info.model or "",
            custom_llm_provider=custom_llm_provider
            or info.provider
            or "openai",
            input=[] if input is None else input,
        )

agentrun/model/model_proxy.py

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import pydash
2020

2121
from agentrun.model.api.data import BaseInfo, ModelDataAPI
22+
from agentrun.model.api.model_api import ModelAPI
2223
from agentrun.utils.config import Config
2324
from agentrun.utils.model import Status
2425
from agentrun.utils.resource import ResourceBase
@@ -40,6 +41,7 @@ class ModelProxy(
4041
ModelProxyImmutableProps,
4142
ModelProxyMutableProps,
4243
ModelProxySystemProps,
44+
ModelAPI,
4345
ResourceBase,
4446
):
4547
"""模型服务"""
@@ -399,41 +401,3 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo:
399401
)
400402

401403
return self._data_client.model_info()
402-
403-
def completions(
404-
self,
405-
messages: list,
406-
model: Optional[str] = None,
407-
stream: bool = False,
408-
config: Optional[Config] = None,
409-
**kwargs,
410-
):
411-
self.model_info(config)
412-
assert self._data_client
413-
414-
return self._data_client.completions(
415-
**kwargs,
416-
messages=messages,
417-
model=model,
418-
stream=stream,
419-
config=config,
420-
)
421-
422-
def responses(
423-
self,
424-
messages: list,
425-
model: Optional[str] = None,
426-
stream: bool = False,
427-
config: Optional[Config] = None,
428-
**kwargs,
429-
):
430-
self.model_info(config)
431-
assert self._data_client
432-
433-
return self._data_client.responses(
434-
**kwargs,
435-
messages=messages,
436-
model=model,
437-
stream=stream,
438-
config=config,
439-
)

agentrun/model/model_service.py

Lines changed: 3 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
from typing import List, Optional
1818

19-
from agentrun.model.api.data import BaseInfo, ModelCompletionAPI
19+
from agentrun.model.api.data import BaseInfo
20+
from agentrun.model.api.model_api import ModelAPI
2021
from agentrun.utils.config import Config
2122
from agentrun.utils.model import PageableInput
2223
from agentrun.utils.resource import ResourceBase
@@ -37,6 +38,7 @@ class ModelService(
3738
ModelServiceImmutableProps,
3839
ModelServiceMutableProps,
3940
ModelServicesSystemProps,
41+
ModelAPI,
4042
ResourceBase,
4143
):
4244
"""模型服务"""
@@ -401,38 +403,3 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo:
401403
model=default_model,
402404
headers=cfg.get_headers(),
403405
)
404-
405-
def completions(
406-
self,
407-
messages: list,
408-
model: Optional[str] = None,
409-
stream: bool = False,
410-
**kwargs,
411-
):
412-
info = self.model_info(config=kwargs.get("config"))
413-
414-
m = ModelCompletionAPI(
415-
api_key=info.api_key or "",
416-
base_url=info.base_url or "",
417-
model=model or info.model or self.model_service_name or "",
418-
)
419-
420-
return m.completions(**kwargs, messages=messages, stream=stream)
421-
422-
def responses(
423-
self,
424-
messages: list,
425-
model: Optional[str] = None,
426-
stream: bool = False,
427-
**kwargs,
428-
):
429-
info = self.model_info(config=kwargs.get("config"))
430-
431-
m = ModelCompletionAPI(
432-
api_key=info.api_key or "",
433-
base_url=info.base_url or "",
434-
model=model or info.model or self.model_service_name or "",
435-
provider=(self.provider or "openai").lower(),
436-
)
437-
438-
return m.responses(**kwargs, messages=messages, stream=stream)

0 commit comments

Comments
 (0)