-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
179 lines (145 loc) · 5.68 KB
/
app.py
File metadata and controls
179 lines (145 loc) · 5.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
from smolagents import CodeAgent, HfApiModel,load_tool,tool, LiteLLMModel
import datetime
import requests
import pytz
import yaml
from tools.final_answer import FinalAnswerTool
import pandas as pd
import duckdb
from pydantic import BaseModel, Field
from Gradio_UI import GradioUI
# Prompt template for generate_sql_query(); filled in with str.format.
SQL_GENERATION_PROMPT = """
Generate an SQL query based on a prompt. Do not reply with anything besides the SQL query.
Keep the dates in the results.
The prompt is {prompt}
The columns of the table are {columns}
The table name is {table_name}
"""


def generate_sql_query(prompt: str, columns: list, table_name: str) -> str:
    """Ask the shared LLM to write an SQL query answering ``prompt``.

    This is a plain helper, not an agent tool: it is not decorated with
    ``@tool`` and is called directly by ``lookup_sales_data``.

    Args:
        prompt: the original prompt
        columns: the columns of the table (interpolated into the prompt as-is)
        table_name: the name of the table
    Returns:
        The raw text of the model's reply. It may still contain Markdown
        code fences, which the caller is responsible for removing.
    """
    request = {
        "role": "user",
        "content": SQL_GENERATION_PROMPT.format(
            prompt=prompt,
            columns=columns,
            table_name=table_name,
        ),
    }
    # `model` is the module-level LLM, resolved at call time.
    return model([request], stop_sequences=["END"]).content
# Parquet file holding the sales transactions; relative to the working directory.
TRANSACTION_DATA_FILE_PATH = 'data/Store_Sales_Price_Elasticity_Promotions_Data.parquet'


# smolagents' @tool builds the agent-facing description from this docstring;
# keep the Args entries in sync with the parameters.
@tool
def lookup_sales_data(prompt: str) -> str:
    """A tool that looks up sales data based on the input prompt.

    Args:
        prompt: the original prompt
    Returns:
        The query result rendered as a plain-text table, or an error message
        if the data file cannot be loaded or the generated query fails.
    """
    table_name = "sales"
    try:
        df = pd.read_parquet(TRANSACTION_DATA_FILE_PATH)
        # `FROM df` is resolved by DuckDB's replacement scan against the local
        # DataFrame above. The table lives on DuckDB's default in-memory
        # connection, which persists between calls, so it must be replaced:
        # a plain CREATE TABLE fails with "already exists" on every call
        # after the first one.
        duckdb.sql(f"CREATE OR REPLACE TABLE {table_name} AS SELECT * FROM df")
    except Exception as e:
        return f"Error reading data file: {str(e)}"

    try:
        sql_query = generate_sql_query(prompt, df.columns.to_list(), table_name)
        # The model may wrap its answer in Markdown fences; drop them, then trim.
        sql_query = sql_query.replace("```sql", "").replace("```", "").strip()
        # SECURITY(review): this SQL is LLM-generated from user input and runs
        # unrestricted on DuckDB's default connection, which can read and write
        # local files (e.g. read_csv, COPY ... TO). Consider a read-only or
        # sandboxed connection.
        result = duckdb.sql(sql_query).df()
    except Exception as e:
        # Report query failures as such instead of blaming the data file.
        return f"Error running SQL query: {str(e)}"
    return result.to_string()
# Prompt template for analyze_sales_data(); filled in with str.format.
DATA_ANALYSIS_PROMPT = """
Analyze the following data: {data}
Your job is to answer the following question: {prompt}
"""


# smolagents' @tool builds the agent-facing description from this docstring;
# keep the Args entries in sync with the parameters.
@tool
def analyze_sales_data(prompt: str, data: str) -> str:
    """Analyze sales data to extract insights according to what is asked in the prompt

    Args:
        prompt: The unchanged prompt that the user provided initially. Keep the initial prompt
        data: The lookup_sales_data tool's output.
    Returns:
        The model's analysis as text, or a fixed notice if the reply is empty.
    """
    question = DATA_ANALYSIS_PROMPT.format(data=data, prompt=prompt)
    reply = model([{"role": "user", "content": question}], stop_sequences=["END"])
    # An empty or missing reply falls back to a fixed message.
    return reply.content or "No analysis could be generated"
# Prompt template for extract_chart_config(); {data} and {visualization_goal}
# are filled in with str.format.
CHART_CONFIGURATION_PROMPT = """
Generate a chart configuration based on this data: {data}
The goal is to show: {visualization_goal}
"""
# Schema of the chart configuration requested from the model; extract_chart_config()
# passes this class as `response_format`. It is documented with `#` comments rather
# than a class docstring because pydantic uses a class docstring as the schema
# description, which would change what is sent to the model.
# Every field is required (`...` as the Field default).
class VisualizationConfig(BaseModel):
    # Kind of chart, e.g. "line" (the value extract_chart_config falls back to).
    chart_type: str = Field(..., description="Type of chart to generate")
    # Column names (from the supplied data) to plot on each axis.
    x_axis: str = Field(..., description="Name of the x-axis column")
    y_axis: str = Field(..., description="Name of the y-axis column")
    title: str = Field(..., description="Title of the chart")
def extract_chart_config(data: str, visualization_goal: str) -> dict:
    """Ask the shared LLM for a chart configuration describing how to plot ``data``.

    The model is called with ``response_format=VisualizationConfig`` so that it
    replies with JSON matching that schema. The reply arrives as text in
    ``response.content``; it is parsed and validated here. If it cannot be
    parsed or does not match the schema, a generic line-chart configuration
    titled with ``visualization_goal`` is returned instead.

    Args:
        data: String containing the data to visualize
        visualization_goal: Description of what the visualization should show
    Returns:
        Dictionary with keys ``chart_type``, ``x_axis``, ``y_axis``, ``title``
        and ``data`` (the input data, passed through unchanged).
    """
    import json

    formatted_prompt = CHART_CONFIGURATION_PROMPT.format(data=data,
                                                         visualization_goal=visualization_goal)
    messages = [{"role": "user", "content": formatted_prompt}]
    response = model(messages, stop_sequences=["END"], response_format=VisualizationConfig)
    content = response.content
    try:
        if isinstance(content, VisualizationConfig):
            config = content
        else:
            # `content` is the reply text, not a parsed model: accessing
            # `content.chart_type` directly always failed and silently forced
            # the fallback. Parse the JSON, tolerating Markdown fences.
            raw = str(content).replace("```json", "").replace("```", "").strip()
            config = VisualizationConfig(**json.loads(raw))
    except (ValueError, TypeError):
        # ValueError covers json.JSONDecodeError and pydantic's ValidationError;
        # TypeError covers a JSON reply that is not an object.
        return {
            "chart_type": "line",
            "x_axis": "date",
            "y_axis": "value",
            "title": visualization_goal,
            "data": data
        }
    return {
        "chart_type": config.chart_type,
        "x_axis": config.x_axis,
        "y_axis": config.y_axis,
        "title": config.title,
        "data": data
    }
# Prompt template for create_chart(); {config} receives the config's str() form.
CREATE_CHART_PROMPT = """
Write python code to create a chart based on the following configuration.
Only return the code, no other text.
config: {config}
"""


def create_chart(config: dict) -> str:
    """Ask the shared LLM to write Python plotting code for ``config``.

    Args:
        config: chart configuration, e.g. the dict built by extract_chart_config
    Returns:
        The generated code with any ```python / ``` fences removed and
        surrounding whitespace trimmed. The code is returned, not executed.
    """
    reply = model(
        [{"role": "user", "content": CREATE_CHART_PROMPT.format(config=config)}],
        stop_sequences=["END"],
    )
    source = reply.content
    # Remove the language-tagged fence first, then any bare fences.
    for fence in ("```python", "```"):
        source = source.replace(fence, "")
    return source.strip()
# smolagents' @tool builds the agent-facing description from this docstring;
# keep the Args entries in sync with the parameters.
@tool
def generate_visualization(data: str, visualization_goal: str) -> str:
    """Generate a visualization based on the data and goal

    Args:
        data: String containing the data to visualize
        visualization_goal: Description of what the visualization should show
    Returns:
        Python source code that draws the chart; the code is not executed here.
    """
    # Two LLM round-trips: first pick the chart layout, then write code for it.
    return create_chart(extract_chart_config(data, visualization_goal))
final_answer = FinalAnswerTool()

# Shared LLM: it drives the agent and is also called directly by the helper
# functions defined earlier in this file (they look up `model` at call time).
model = LiteLLMModel(
    model_id='anthropic/claude-3-5-sonnet-latest',
    temperature=0.5,
    custom_role_conversions=None,
)

# Explicit encoding so prompts.yaml decodes identically on every platform;
# without it open() falls back to the locale's preferred encoding.
with open("prompts.yaml", 'r', encoding="utf-8") as stream:
    prompt_templates = yaml.safe_load(stream)

agent = CodeAgent(
    model=model,
    tools=[final_answer, lookup_sales_data, analyze_sales_data, generate_visualization],
    max_steps=6,
    verbosity_level=1,
    grammar=None,
    planning_interval=None,
    name=None,
    description=None,
    prompt_templates=prompt_templates
)

# Starts the Gradio web UI at import/run time.
GradioUI(agent).launch()