-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathexecute.py
More file actions
169 lines (146 loc) · 5.98 KB
/
execute.py
File metadata and controls
169 lines (146 loc) · 5.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
"""This module contains the AgentExecute class that is responsible for executing the SQL code and returning the output."""
import re
import logging
from typing import Optional
from collections.abc import AsyncIterable
from semantic_kernel.kernel import Kernel
from semantic_kernel.agents import ChatCompletionAgent
from semantic_kernel.connectors.ai.prompt_execution_settings import (
PromptExecutionSettings,
)
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.utils.author_role import AuthorRole
from semantic_kernel.contents.text_content import TextContent
from src.mysql.execution_env import SqlEnv
from src.utils.constants import Constants
# Module-level logger, named after this module so records can be filtered per module.
logger: logging.Logger = logging.getLogger(name=__name__)
def _normalize_sql(text: str) -> str:
    """Lower-case ``text`` and collapse every whitespace run into one space.

    Used by the security guardrails so keyword checks cannot be bypassed
    through letter case, newlines, tabs or repeated spaces.
    """
    return " ".join(text.lower().split())


class SQLExecuteAgent(ChatCompletionAgent):
    """
    Agent base implementation for executing SQL code and returning the output.

    On each invocation the agent:

    1. Reads the last message of the chat history.
    2. Extracts the SQL wrapped in ``execute[...]`` that follows the action
       identifier.
    3. Applies the security guardrails.
    4. Runs the query in ``env``.
    5. Appends an observation message to the history.
    """

    # SQL execution environment used to run queries. It must be assigned before
    # ``invoke`` reaches a query; ``AgentExecute`` sets it after construction.
    env: Optional[SqlEnv] = None

    def __init__(
        self,
        service_id: str | None = None,
        kernel: Kernel | None = None,
        name: str | None = None,
        id: str | None = None,
        description: str | None = None,
        instructions: str | None = None,
        execution_settings: PromptExecutionSettings | None = None,
    ):
        """
        Initialize the SQLExecuteAgent.

        Args:
            service_id (str | None): The service id of the agent.
            kernel (Kernel | None): The kernel instance.
            name (str | None): The name of the agent.
            id (str | None): The id of the agent.
            description (str | None): The description of the agent.
            instructions (str | None): The instructions for the agent.
            execution_settings (PromptExecutionSettings | None): The execution settings for the agent.
        """
        super().__init__(
            service_id=service_id,
            kernel=kernel,
            name=name,
            id=id,
            description=description,
            instructions=instructions,
            execution_settings=execution_settings,
        )

    def sql_parser_react(self, action: str) -> tuple[str, bool]:
        """
        Parse the SQL code from the action.

        Args:
            action (str): The action to parse. It is either
                ``Constants.action_submit`` or text containing ``execute[<sql>]``.

        Returns:
            tuple[str, bool]: The parsed action and a flag telling whether it
            is executable. For ``execute[...]`` the SQL inside the brackets is
            returned, cut at the first ``;``. When no ``execute[...]`` is found,
            the action is returned unchanged with ``False``.
        """
        if action == Constants.action_submit:
            return action, True
        # Greedy + DOTALL: spans from the first "execute[" to the last "]",
        # including newlines, so nested brackets in the SQL are kept.
        match = re.search(r"execute\[(.*)\]", action, re.DOTALL)
        if match is None:
            return action, False
        # Only the first statement is executed; anything after ";" is dropped.
        return match.group(1).split(";", 1)[0], True

    def _build_observation(self, action_parsed: str | None, is_code: bool) -> object:
        """Return the observation for a parsed action after applying the guardrails.

        The checks are mutually exclusive. Only an action that is code and
        passes every guardrail is sent to ``self.env.step``.

        Args:
            action_parsed (str | None): Output of ``sql_parser_react``, or None
                when no action identifier was found.
            is_code (bool): Whether ``action_parsed`` is executable.

        Returns:
            object: An error string, or the first element returned by
            ``self.env.step``.
        """
        if not is_code or action_parsed is None:
            return f"{Constants.sql_error_message}: Your last `execute` action did not contain SQL code"

        normalized = _normalize_sql(action_parsed)

        # Security Guardrail 01: block SHOW DATABASES regardless of case/whitespace.
        if _normalize_sql(Constants.sql_show_database) in normalized:
            return f"{Constants.sql_error_message}: SHOW DATABASES is not allowed in this environment."

        # Security Guardrail 02: Check for SQL Data Manipulation related keywords.
        # The trailing space appended to the query lets a keyword at the very
        # end still match.
        padded = normalized + " "
        if any(
            _normalize_sql(keyword) + " " in padded
            for keyword in Constants.sql_data_manipulation_commands
        ):
            return f"{Constants.sql_error_message}: SQL Data Manipulation Language (DML) is not allowed in this environment."

        observation, _, _, _ = self.env.step(action_parsed)
        return observation

    async def invoke(self, history: ChatHistory) -> AsyncIterable[ChatMessageContent]:
        """
        Execute the SQL code and return the output.

        The last message is split on the action identifier. The text after its
        first occurrence (up to a second occurrence, if any) is parsed as the
        action. The resulting observation is appended to ``history`` and
        yielded.

        Args:
            history (ChatHistory): The chat history.

        Yields:
            AsyncIterable[ChatMessageContent]: The output message.
        """
        chat = self._setup_agent_chat_history(history)
        logger.info(
            "[%s] Invoked %s with message count: %d.",
            type(self).__name__,
            "code_executor",
            len(chat),
        )

        # An empty history, or a last message without text, is treated as "no action".
        message = (chat[-1].content if len(chat) > 0 else "") or ""
        parts = message.strip().split(f"{Constants.action_identifier} ")
        if len(parts) > 1:
            action_parsed, is_code = self.sql_parser_react(parts[1])
        else:
            action_parsed, is_code = None, False

        observation = self._build_observation(action_parsed, is_code)

        # Limit observation size due to context window thresholds for API call
        if isinstance(observation, str):
            observation = observation[:350]
        elif isinstance(observation, list):
            observation = observation[:25]

        code_output = f"{Constants.observation_identifier}{observation}"
        output_message = ChatMessageContent(
            role=AuthorRole.ASSISTANT,
            items=[TextContent(text=code_output)],
            name=self.name,
        )
        history.add_message(output_message)
        yield output_message
class AgentExecute:
    """
    Wrapper that builds a ``SQLExecuteAgent`` bound to an SQL execution environment.
    """

    # Used as both the agent's name and its service id.
    name = "executor"

    def __init__(self, sql_executor_env: SqlEnv, kernel: Kernel | None = None):
        """
        Create the wrapped agent and attach the execution environment to it.

        Args:
            sql_executor_env (SqlEnv): Environment the agent runs SQL queries against.
            kernel (Kernel | None): Optional kernel passed through to the agent.
        """
        executor_agent = SQLExecuteAgent(
            kernel=kernel,
            name=self.name,
            service_id=self.name,
            description="A Code Executor Agent that executes the SQL query and returns the output.",
        )
        # SQLExecuteAgent.__init__ does not take an environment, so it is attached afterwards.
        executor_agent.env = sql_executor_env
        self.agent = executor_agent

    def get_agent(self) -> SQLExecuteAgent:
        """
        Return the wrapped agent.

        Returns:
            SQLExecuteAgent: The agent built in ``__init__``.
        """
        return self.agent