-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path__data_async_template.py
More file actions
65 lines (53 loc) · 1.96 KB
/
__data_async_template.py
File metadata and controls
65 lines (53 loc) · 1.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from typing import Iterable, Optional, TypedDict
from openai.types.chat import ChatCompletionMessageParam
from typing_extensions import Required, Unpack
from agentrun.utils.config import Config
from agentrun.utils.data_api import DataAPI, ResourceType
class InvokeArgs(TypedDict, total=False):
    """Keyword arguments accepted by ``AgentRuntimeDataAPI.invoke_openai_async``.

    ``total=False`` makes unmarked keys optional, so ``config`` may be omitted.
    ``messages`` and ``stream`` stay mandatory through their explicit
    ``Required`` markers (which only have an effect in a non-total TypedDict).
    """

    # Chat messages forwarded to the chat-completions endpoint.
    messages: Required[Iterable[ChatCompletionMessageParam]]
    # Whether to request a streaming response.
    stream: Required[bool]
    # Optional per-call configuration, merged with the client's own config.
    config: Optional[Config]
class AgentRuntimeDataAPI(DataAPI):
    """Data-plane client for invoking one endpoint of an agent runtime.

    Requests are addressed under the namespace
    ``agent-runtimes/<runtime>/endpoints/<endpoint>/invocations``.
    """

    def __init__(
        self,
        agent_runtime_name: str,
        agent_runtime_endpoint_name: str = "Default",
        config: Optional[Config] = None,
    ):
        """Bind the client to a single runtime endpoint.

        Args:
            agent_runtime_name: Name of the agent runtime; passed to
                ``DataAPI`` as the resource name and embedded in the namespace.
            agent_runtime_endpoint_name: Endpoint of the runtime to invoke.
            config: Optional base configuration handed to ``DataAPI``.
        """
        super().__init__(
            resource_name=agent_runtime_name,
            resource_type=ResourceType.Runtime,
            namespace=f"agent-runtimes/{agent_runtime_name}/endpoints/{agent_runtime_endpoint_name}/invocations",
            config=config,
        )

    async def invoke_openai_async(
        self,
        **kwargs: Unpack[InvokeArgs],
    ):
        """Send a chat-completions request to the runtime's OpenAI-compatible API.

        Keyword Args:
            messages: Chat messages to send (defaults to an empty list).
            stream: Whether to request a streaming response (defaults to False).
            config: Per-call configuration merged with ``self.config`` via
                ``Config.with_configs``.

        Returns:
            The awaited result of ``AsyncOpenAI.chat.completions.create``:
            a completion object, or an async stream when ``stream`` is true.
        """
        messages = kwargs.get("messages", [])
        stream = kwargs.get("stream", False)
        config = kwargs.get("config", None)

        cfg = Config.with_configs(self.config, config)
        api_base = self.with_path("openai/v1", config=cfg)

        # Sign the actual request URL (OpenAI client will POST to base + /chat/completions)
        chat_completions_url = api_base.rstrip("/") + "/chat/completions"
        _, headers, _ = self.auth(
            url=chat_completions_url,
            headers=cfg.get_headers(),
            config=cfg,
            method="POST",
        )

        from httpx import AsyncClient
        from openai import AsyncOpenAI

        # The signed headers become httpx client defaults, so the request the
        # OpenAI client sends carries them.
        client = AsyncOpenAI(
            api_key="",
            base_url=api_base,
            http_client=AsyncClient(headers=headers),
        )

        timeout = cfg.get_timeout()
        keep_client_open = False
        try:
            # Await here: returning the bare coroutine from an ``async def``
            # would force callers to await twice to get the response.
            response = await client.chat.completions.create(
                model=self.resource_name,
                messages=messages,
                timeout=timeout,
                stream=stream,
            )
            # A stream keeps reading from the open connection while the caller
            # iterates it, so its client must outlive this call.
            keep_client_open = bool(stream)
            return response
        finally:
            # A non-streaming response is fully read once awaited (and a failed
            # request has nothing left to read), so release the per-call client
            # instead of leaking its connection pool.
            if not keep_client_open:
                await client.close()