Skip to content

Commit 0687d6e

Browse files
authored
Merge pull request #8 from Zipstack/enhance-report-details
Added total cost and tokens for embedings and extraction LLM in detailed report
2 parents d05e5c4 + 58790de commit 0687d6e

1 file changed

Lines changed: 86 additions & 9 deletions

File tree

main.py

Lines changed: 86 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import sqlite3
66
import sys
77
import time
8+
import textwrap
89
from dataclasses import dataclass
910
from datetime import datetime
1011
from functools import partial
@@ -51,6 +52,10 @@ def init_db():
5152
time_taken REAL,
5253
status_code INTEGER,
5354
status_api_endpoint TEXT,
55+
total_embedding_cost REAL,
56+
total_embedding_tokens INTEGER,
57+
total_llm_cost REAL,
58+
total_llm_tokens INTEGER,
5459
updated_at TEXT,
5560
created_at TEXT
5661
)"""
@@ -97,6 +102,15 @@ def update_db(
97102
status_code,
98103
status_api_endpoint,
99104
):
105+
106+
total_embedding_cost = None
107+
total_embedding_tokens = None
108+
total_llm_cost = None
109+
total_llm_tokens = None
110+
111+
if result is not None:
112+
total_embedding_cost, total_llm_cost, total_embedding_tokens, total_llm_tokens = calculate_cost_and_tokens(result)
113+
100114
conn = sqlite3.connect(DB_NAME)
101115
conn.set_trace_callback(
102116
lambda x: (
@@ -109,16 +123,20 @@ def update_db(
109123
now = datetime.now().isoformat()
110124
c.execute(
111125
"""
112-
INSERT OR REPLACE INTO file_status (file_name, execution_status, result, time_taken, status_code, status_api_endpoint, updated_at, created_at)
113-
VALUES (?, ?, ?, ?, ?, ?, ?, COALESCE((SELECT created_at FROM file_status WHERE file_name = ?), ?))
114-
""",
126+
INSERT OR REPLACE INTO file_status (file_name, execution_status, result, time_taken, status_code, status_api_endpoint, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, updated_at, created_at)
127+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, COALESCE((SELECT created_at FROM file_status WHERE file_name = ?), ?))
128+
""",
115129
(
116130
file_name,
117131
execution_status,
118132
json.dumps(result),
119133
time_taken,
120134
status_code,
121135
status_api_endpoint,
136+
total_embedding_cost,
137+
total_embedding_tokens,
138+
total_llm_cost,
139+
total_llm_tokens,
122140
now,
123141
file_name,
124142
now,
@@ -127,10 +145,57 @@ def update_db(
127145
conn.commit()
128146
conn.close()
129147

148+
# Calculate total cost and tokens for detailed report
149+
def calculate_cost_and_tokens(result):
150+
151+
total_embedding_cost = None
152+
total_embedding_tokens = None
153+
total_llm_cost = None
154+
total_llm_tokens = None
155+
156+
# Extract 'extraction_result' from the result
157+
extraction_result = result.get("extraction_result", [])
158+
159+
if not extraction_result:
160+
return total_embedding_cost, total_llm_cost, total_embedding_tokens, total_llm_tokens
161+
162+
extraction_data = extraction_result[0].get("result", "")
163+
164+
# If extraction_data is a string, attempt to parse it as JSON
165+
if isinstance(extraction_data, str):
166+
try:
167+
extraction_data = json.loads(extraction_data) if extraction_data else {}
168+
except json.JSONDecodeError:
169+
logger.warning("Failed to decode JSON for extraction data; defaulting to empty dictionary.")
170+
extraction_data = {}
171+
172+
173+
metadata = extraction_data.get("metadata", None)
174+
embedding_llm = metadata.get("embedding") if metadata else None
175+
extraction_llm = metadata.get("extraction_llm") if metadata else None
176+
177+
#Process embedding costs and tokens if embedding_llm list exists and is not empty
178+
if embedding_llm:
179+
total_embedding_cost = 0.0
180+
total_embedding_tokens = 0
181+
for item in embedding_llm:
182+
total_embedding_cost += float(item.get("cost_in_dollars", "0"))
183+
total_embedding_tokens += item.get("embedding_tokens", 0)
184+
185+
#Process embedding costs and tokens if extraction_llm list exists and is not empty
186+
if extraction_llm:
187+
total_llm_cost = 0.0
188+
total_llm_tokens = 0
189+
for item in extraction_llm:
190+
total_llm_cost += float(item.get("cost_in_dollars", "0"))
191+
total_llm_tokens += item.get("total_tokens", 0)
192+
193+
return total_embedding_cost, total_llm_cost, total_embedding_tokens, total_llm_tokens
194+
130195

131196
# Print final summary with count of each status and average time using a single SQL query
132197
def print_summary():
133-
conn = sqlite3.connect("file_processing.db")
198+
conn = sqlite3.connect(DB_NAME)
134199
c = conn.cursor()
135200

136201
# Fetch count and average time for each status
@@ -153,13 +218,13 @@ def print_summary():
153218

154219

155220
def print_report():
156-
conn = sqlite3.connect("file_processing.db")
221+
conn = sqlite3.connect(DB_NAME)
157222
c = conn.cursor()
158223

159-
# Fetch count and average time for each status
224+
# Fetch required fields, including total_cost and total_tokens
160225
c.execute(
161226
"""
162-
SELECT file_name, execution_status, time_taken
227+
SELECT file_name, execution_status, time_taken, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens
163228
FROM file_status
164229
"""
165230
)
@@ -170,8 +235,20 @@ def print_report():
170235
print("\nDetailed Report:")
171236
if report_data:
172237
# Tabulate the data with column headers
173-
headers = ["File Name", "Execution Status", "Time Elapsed (seconds)"]
174-
print(tabulate(report_data, headers=headers, tablefmt="pretty"))
238+
headers = ["File Name", "Execution Status", "Time Elapsed (seconds)", "Total Embedding Cost", "Total Embedding Tokens", "Total LLM Cost", "Total LLM Tokens"]
239+
240+
# Wrap text in each column to a specific width (e.g., 30 characters for file names and 20 for others) and return None if the value is NULL
241+
formatted_data = []
242+
for row in report_data:
243+
formatted_row = [
244+
"None" if cell is None else
245+
textwrap.fill(str(cell), width=30) if isinstance(cell, str) else
246+
f"{cell:.8f}" if isinstance(cell, float) else cell
247+
for cell in row
248+
]
249+
formatted_data.append(formatted_row)
250+
251+
print(tabulate(formatted_data, headers=headers, tablefmt="pretty"))
175252
else:
176253
print("No records found in the database.")
177254

0 commit comments

Comments
 (0)