import os
import urllib.parse
import warnings

from dotenv import load_dotenv
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain.chains import APIChain
from langchain.tools import StructuredTool
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_core.messages import AIMessage
from langchain_core.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from sqlalchemy import create_engine
from sqlalchemy.exc import SAWarning

from API_DOCS import API_Movie_DOC

# Load environment variables (API key, database credentials) from a .env file.
load_dotenv()
# Silence SQLAlchemy warnings.
warnings.filterwarnings("ignore", category=SAWarning)

OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
# Shared model used by the SQL toolkit and the movies API chain.
LLM = ChatOpenAI(model="gpt-4o", temperature=0)


class QueryingPostgresSchema(BaseModel):
    query: str = Field(..., description="Query to run on the PostgreSQL database")


def QueryingPostgresTool(query: str):
    # Connection settings are read from the environment (see load_dotenv above).
    db_config = {
        "dbname": os.environ.get("DB_NAME_V2"),
        "user": os.environ.get("DB_USER_V2"),
        "password": os.environ.get("DB_PASSWORD_V2"),
        "host": os.environ.get("DB_HOST_V2", "localhost"),
        "port": os.environ.get("DB_PORT", "5432"),
    }
    connection_string = (
        f"postgresql+psycopg2://{db_config['user']}:{db_config['password']}"
        f"@{db_config['host']}:{db_config['port']}/{db_config['dbname']}"
    )
    engine = create_engine(connection_string)
    try:
        db = SQLDatabase(engine)
    except Exception as e:
        error_message = f"Error connecting to the database: {e}"
        print(error_message)
        raise SystemExit(error_message)

    # The toolkit supplies the SQL tools and the schema context used by the prompt.
    toolkit = SQLDatabaseToolkit(db=db, llm=LLM)
    context = toolkit.get_context()
    tools = toolkit.get_tools()

    SQL_FUNCTIONS_SUFFIX = """
You are provided with the following tools for interacting with a SQL database. Each tool serves a specific purpose, and it's important to know when and how to use them correctly. Below is a description of each tool, along with guidelines on when to use each tool.
1. **ListSQLDatabaseTool (sql_db_list_tables)**:
- **Description**: This tool retrieves a comma-separated list of all table names from the database. It does not take any input and simply returns the full list of tables.
- **When to use**: Use this tool at the beginning of a query session to list all available tables in the database. This helps you understand what data is available in the database and helps identify which tables might be relevant to the user's query.
2. **InfoSQLDatabaseTool (sql_db_schema)**:
- **Description**: This tool retrieves the schema (i.e., column names and data types) for specific tables in the database. You must provide a comma-separated list of table names to this tool. It returns the schema for the specified tables along with some sample rows of data.
- **When to use**: After identifying which tables might be relevant to a user's query, use this tool to get detailed schema information for those specific tables. It is **crucial** that you carefully identify all tables that could possibly be relevant to the user's query. Be **very accurate** in this step, ensuring that no potentially relevant table is overlooked. Include every possible relevant table, even if there is only a slight chance that it may contain useful data for the query. Missing a relevant table at this step could result in incomplete or inaccurate query results.
3. **QuerySQLDataBaseTool (sql_db_query)**:
- **Description**: This tool executes a SQL query against the database and returns the result. You must provide a syntactically correct SQL query as input. If the query is incorrect or fails, it will return an error message instead of throwing an exception.
- **When to use**: Use this tool when you are ready to execute a SQL query and retrieve actual data from the database. Always make sure the query is correct before using this tool to avoid errors. If an error occurs, you should revise the query and try again.
4. **QuerySQLCheckerTool (sql_db_query_checker)**:
- **Description**: This tool uses a language model (LLM) to check the correctness of a SQL query before executing it. You provide the SQL query as input, and the tool will analyze it to ensure it is valid for execution. It checks for syntax errors, dialect issues, and correctness.
- **When to use**: Before executing any SQL query using the `sql_db_query` tool, use this tool to validate the query. This is especially useful when you're unsure if the query is correct or if the query might result in an error. It acts as a safety check before running the actual query.
---
### Guidelines for Use:
- **Start with `sql_db_list_tables`**: Always begin by listing all available tables in the database. This will give you an overview of the data sources you can query.
- **Be very accurate with `sql_db_schema`**: Once you have identified which tables might be relevant to the user’s query, use `sql_db_schema` to fetch the schema of those tables. In this step, it is essential to be **very precise** when deciding which tables are relevant. Include **all possible relevant tables** and do not ignore any that might contain useful data. Ensure that the query will have access to all relevant information, even from tables that might seem only marginally related to the query.
- **Use flexible search patterns**: When generating SQL queries, always incorporate flexible search patterns (e.g., LIKE, ILIKE, SOUNDEX, or the % wildcard), especially when users ask for data but are unsure of its exact formatting, casing, or content. This ensures that searches are not limited to exact matches and capture a wider range of potential results. For text-based columns, use ILIKE or LIKE with the % wildcard for partial matching. For phonetic similarity, consider SOUNDEX, which in PostgreSQL requires the fuzzystrmatch extension.
- **Check the query with `sql_db_query_checker`**: Before executing the final SQL query, run it through the `sql_db_query_checker` to ensure it is correct and valid.
- **Execute the query with `sql_db_query`**: After confirming the query is correct, use `sql_db_query` to retrieve the data and provide the final result to the user.
Always ensure that the SQL query is crafted based on the schema information, and double-check for errors using the query checker tool before execution.
"""
    # Prompt layout: the user's question, the tool guidelines, then the agent scratchpad.
    messages = [
        HumanMessagePromptTemplate.from_template("{input}"),
        AIMessage(content=SQL_FUNCTIONS_SUFFIX),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
    prompt = ChatPromptTemplate.from_messages(messages)
    prompt = prompt.partial(**context)

    llm_agent = ChatOpenAI(model="gpt-4-turbo", temperature=0)
    agent = create_openai_tools_agent(llm_agent, tools, prompt)
    agent_executor = AgentExecutor(
        agent=agent,
        tools=tools,
        verbose=True,
        max_iterations=50,
    )
    response = agent_executor.invoke({"input": query})

    # fine_tuned_model = "ft:gpt-3.5-turbo-0125:arkleapp::9UBqNNUb"
    # fine_tuned_model = "ft:gpt-3.5-turbo-0125:inno::A9999bfVr"
    # fine_tuned_model = "ft:gpt-4o-2024-08-06:inno::ADXYooqTR"
    # db_chain = SQLDatabaseChain.from_llm(
    #     llm=ChatOpenAI(temperature=0, model=fine_tuned_model, openai_api_key=get_env_variable("OPENAI_API_KEY")), db=db,
    #     verbose=False)
    # QUERY_TEMPLATE = """
    # Given an input question, create a syntactically correct PostgreSQL query to run. Ensure to use the correct table name by verifying against the database schema if necessary.
    # If the initial query does not return results, try alternative table names or check for potential typos.
    # Use the following format:
    #
    # Question: {question}
    # SQLQuery: SQL Query to run
    # SQLResult: Result of the SQLQuery
    # Answer: Final answer here
    # """
    #
    # try:
    #     # question = QUERY_TEMPLATE.format(question=query[:100])
    #     question = QUERY_TEMPLATE.format(question=query)
    #
    #     result = db_chain.invoke(question)
    #     # result = db_chain.invoke({"query": query})
    #
    #     return result
    # except Exception as e:
    #     return f"An error occurred: {e}"
    return response


def get_querying_postgres_tool():
    return StructuredTool(
        name="QueryingPostgresTool",
        description="Use this tool for retrieving data from PostgreSQL.",
        func=QueryingPostgresTool,
        args_schema=QueryingPostgresSchema,
    )
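
# Sketch, not part of the original module: the connection string in
# QueryingPostgresTool interpolates the raw user name and password, so a
# character such as "@", ":" or "/" in either value would corrupt the URL.
# Assuming the credentials are stored unencoded in the environment, a
# hypothetical helper (build_postgres_url is an illustrative name) could
# percent-encode them with urllib.parse.quote_plus before building the URL.

```python
from urllib.parse import quote_plus


def build_postgres_url(user: str, password: str, host: str, port: str, dbname: str) -> str:
    """Return a psycopg2-style SQLAlchemy URL with percent-encoded credentials."""
    return (
        f"postgresql+psycopg2://{quote_plus(user)}:{quote_plus(password)}"
        f"@{host}:{port}/{dbname}"
    )


if __name__ == "__main__":
    # Only the credentials are encoded; host, port and dbname are passed through.
    print(build_postgres_url("app", "p@ss:w/rd", "localhost", "5432", "movies"))
```

# The resulting string could be passed to create_engine in place of the
# hand-built connection_string; whether that is appropriate depends on how
# the deployment stores its credentials.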


class MoviesApiToolSchema(BaseModel):
    query: str = Field(..., description="Query about movies")


def MoviesApiTool(query: str):
    base_url = "https://api.themoviedb.org/3"
    encoded_query = urllib.parse.quote(query)
    url = f"{base_url}/search/movie?api_key=your_api_key&query={encoded_query}&include_adult=false&language=en-US"
    chain = APIChain.from_llm_and_api_docs(
        LLM,
        doc=url,
        api_docs=API_Movie_DOC,
        headers={},
        verbose=True,
        limit_to_domains=["https://api.themoviedb.org"],
    )
    try:
        answer = chain.run(query)
        print('Result:', answer)
        # Return the answer so the calling agent receives the tool output.
        return answer
    except Exception as e:
        print('Error:', e)
        # Return the error as text so the agent does not receive None.
        return f"An error occurred: {e}"


def get_custom_movies_api_tool():
    return StructuredTool(
        name="Movies",
        description="Use this tool to answer questions about movies",
        func=MoviesApiTool,
        args_schema=MoviesApiToolSchema,
    )
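
# Sketch, not part of the original module: MoviesApiTool assembles the search
# URL by hand with a literal "your_api_key" placeholder. One possible
# alternative, shown under assumptions, builds the query string with
# urllib.parse.urlencode and reads the key from the environment. The helper
# name build_movie_search_url and the variable TMDB_API_KEY are hypothetical;
# the endpoint path and parameters are copied from the function above.

```python
import os
from urllib.parse import urlencode

TMDB_BASE_URL = "https://api.themoviedb.org/3"


def build_movie_search_url(query: str, api_key: str) -> str:
    """Build the movie search URL, letting urlencode escape every parameter."""
    params = {
        "api_key": api_key,
        "query": query,
        "include_adult": "false",
        "language": "en-US",
    }
    return f"{TMDB_BASE_URL}/search/movie?{urlencode(params)}"


if __name__ == "__main__":
    # TMDB_API_KEY is an assumed variable name; fall back to the placeholder.
    key = os.environ.get("TMDB_API_KEY", "your_api_key")
    print(build_movie_search_url("The Matrix", key))
```

# urlencode escapes spaces as "+" and reserved characters such as "&", so a
# title like "Tom & Jerry" cannot inject an extra query parameter.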