-
Notifications
You must be signed in to change notification settings - Fork 114
Expand file tree
/
Copy pathtable.py
More file actions
67 lines (55 loc) · 2.07 KB
/
table.py
File metadata and controls
67 lines (55 loc) · 2.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# pytest: ollama, e2e
from io import StringIO
import pandas
import mellea
from mellea.stdlib.components.mify import mify
def _read_pipe_table(text: str) -> pandas.DataFrame:
    """Parse a ``|``-delimited table (markdown-style or plain) into a DataFrame.

    Every cell is read as text so values such as ``$250`` or ``1.50`` round-trip
    unchanged. The empty ``Unnamed: N`` columns created by leading/trailing ``|``
    characters are dropped, and surrounding whitespace is stripped from column
    names and string cells. A markdown separator row (``| --- | --- |``) is kept
    as an ordinary data row; callers that need to ignore it must filter it out.
    """
    table_df = pandas.read_csv(
        StringIO(text),
        sep="|",
        skipinitialspace=True,
        header=0,
        index_col=False,
        dtype=str,
    )
    # Markdown rows start and end with "|", which yields empty "Unnamed: N" columns.
    table_df = table_df.drop(table_df.filter(regex="Unname").columns, axis=1)
    # Padding used to align markdown columns survives parsing; strip it.
    table_df.columns = table_df.columns.str.strip()
    return table_df.map(lambda x: x.strip() if isinstance(x, str) else x)


# A toy "database" whose only state is a pipe-delimited sales table. mify is told to
# include only the ``table`` field and to render it verbatim via the template.
# NOTE(review): the public methods are presumably offered to the model as tools by
# m.transform — confirm against mellea's mify documentation.
@mify(fields_include={"table"}, template="{{ table }}")
class MyCompanyDatabase:
    table: str = """| Store | Sales |
| ---------- | ------- |
| Northeast | $250 |
| Southeast | $80 |
| Midwest | $420 |"""

    def __init__(self, *, table: str | None = None):
        """Create the database, optionally replacing the default sample table.

        Args:
            table: A ``|``-delimited table to use instead of the class-level
                sample. ``update_sales`` expects ``Store`` and ``Sales`` columns.
                When ``None``, the class-level sample table is used.
        """
        if table is not None:
            self.table = table

    def update_sales(self, store: str, amount: str):
        """Update the sales for a specific store.

        Args:
            store: Name of the store whose row should change. Matching ignores
                case and surrounding whitespace.
            amount: New value written into that row's ``Sales`` cell.

        Returns:
            This database. ``table`` is rewritten as ``|``-separated text with a
            header row and no leading/trailing pipes. If no store matches, the
            values are unchanged.
        """
        table_df = _read_pipe_table(self.table)
        # Model-generated tool arguments often differ in capitalisation from the
        # table (e.g. "northeast" vs "Northeast"), so compare case-insensitively.
        wanted = str(store).strip().casefold()
        matches = table_df["Store"].str.casefold() == wanted
        table_df.loc[matches, "Sales"] = amount
        self.table = table_df.to_csv(sep="|", index=False, header=True)
        return self

    def transpose(self) -> str:
        """Transpose the table.

        The values of the first column (e.g. the store names) become the header
        of the result. The remaining column names (e.g. ``Sales``) label its rows.
        ``self.table`` is not modified.

        Returns:
            The transposed table as ``|``-separated text.
        """
        table_df = _read_pipe_table(self.table)
        # A markdown header separator row ("---", ":---:") is layout, not data.
        is_separator = table_df.apply(
            lambda col: col.str.fullmatch(r":?-+:?", na=False)
        ).all(axis=1)
        table_df = table_df.loc[~is_separator]
        key_column = table_df.columns[0]
        transposed = table_df.set_index(key_column).transpose()
        # Write the index so the original column names survive as row labels.
        # Passing no buffer makes to_csv return the text instead of None.
        return transposed.to_csv(sep="|", index=True, index_label=key_column, header=True)
if __name__ == "__main__":
    m = mellea.start_session()
    db = MyCompanyDatabase()

    # Ask a question about the current table; .value holds the model's answer text.
    print(m.query(db, "What were sales for the Northeast branch this month?").value)

    # Ask the model to change the table, then show what came back and the new table.
    result = m.transform(db, "Update the northeast sales to 1250.")
    print(type(result))
    print(db.table)

    # Re-ask the same question so the effect of the update is visible.
    print(m.query(db, "What were sales for the Northeast branch this month?").value)

    # A transform whose result is printed directly. The transpose method does not
    # modify db.table.
    result = m.transform(db, "Transpose the table.")
    print(result)