Skip to content

Commit adc6aeb

Browse files
authored
Merge pull request #13 from Zipstack/feat/support-to-configure-db-path
feat: Added support to configure DB path, `result` in CSV report
2 parents ae62f8e + ee85994 commit adc6aeb

3 files changed

Lines changed: 62 additions & 49 deletions

File tree

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
*.db
2-
.venv/
2+
*.csv
3+
.mypy_cache/
4+
.venv/
5+
.python-version

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ This will display detailed usage information.
5959
- `-t`, `--api_timeout`: Timeout (in seconds) for API requests (default: 10).
6060
- `-i`, `--poll_interval`: Interval (in seconds) between API status polls (default: 5).
6161
- `-p`, `--parallel_call_count`: Number of parallel API calls (default: 10).
62+
- `--csv_report`: Path to export the detailed report as a CSV file.
63+
- `--db_path`: Path where the SQLite DB file is stored (default: './file_processing.db')
6264
- `--retry_failed`: Retry processing of failed files.
6365
- `--retry_pending`: Retry processing of pending files by making new requests.
6466
- `--skip_pending`: Skip processing of pending files.
@@ -67,7 +69,6 @@ This will display detailed usage information.
6769
- `--print_report`: Print a detailed report of all processed files at the end.
6870
- `--exclude_metadata`: Exclude metadata on tokens consumed and the context passed to LLMs for prompt studio exported tools in the result for each file.
6971
- `--no_verify`: Disable SSL certificate verification. (By default, SSL verification is enabled.)
70-
- `--csv_report`: Path to export the detailed report as a CSV file.
7172

7273
## Usage Examples
7374

main.py

Lines changed: 56 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
from tqdm import tqdm
1717
from unstract.api_deployments.client import APIDeploymentsClient
1818

19-
DB_NAME = "file_processing.db"
20-
global_arguments = None
2119
logger = logging.getLogger(__name__)
2220

2321

@@ -29,6 +27,7 @@ class Arguments:
2927
api_timeout: int = 10
3028
poll_interval: int = 5
3129
input_folder_path: str = ""
30+
db_path: str = ""
3231
parallel_call_count: int = 5
3332
retry_failed: bool = False
3433
retry_pending: bool = False
@@ -42,8 +41,8 @@ class Arguments:
4241

4342

4443
# Initialize SQLite DB
45-
def init_db():
46-
conn = sqlite3.connect(DB_NAME)
44+
def init_db(args: Arguments):
45+
conn = sqlite3.connect(args.db_path)
4746
c = conn.cursor()
4847

4948
# Create the table if it doesn't exist
@@ -89,7 +88,7 @@ def init_db():
8988

9089
# Check if the file is already processed
9190
def skip_file_processing(file_name, args: Arguments):
92-
conn = sqlite3.connect(DB_NAME)
91+
conn = sqlite3.connect(args.db_path)
9392
c = conn.cursor()
9493
c.execute(
9594
"SELECT execution_status FROM file_status WHERE file_name = ?", (file_name,)
@@ -124,6 +123,7 @@ def update_db(
124123
time_taken,
125124
status_code,
126125
status_api_endpoint,
126+
args: Arguments
127127
):
128128

129129
total_embedding_cost = None
@@ -138,7 +138,7 @@ def update_db(
138138
if execution_status == "ERROR":
139139
error_message = extract_error_message(result)
140140

141-
conn = sqlite3.connect(DB_NAME)
141+
conn = sqlite3.connect(args.db_path)
142142
conn.set_trace_callback(
143143
lambda x: (
144144
logger.debug(f"[{file_name}] Executing statement: {x}")
@@ -232,8 +232,8 @@ def extract_error_message(result):
232232
return result.get("error", "No error message found")
233233

234234
# Print final summary with count of each status and average time using a single SQL query
235-
def print_summary():
236-
conn = sqlite3.connect(DB_NAME)
235+
def print_summary(args: Arguments):
236+
conn = sqlite3.connect(args.db_path)
237237
c = conn.cursor()
238238

239239
# Fetch count and average time for each status
@@ -255,8 +255,8 @@ def print_summary():
255255
print(f"Status '{status}': {count}")
256256

257257

258-
def print_report():
259-
conn = sqlite3.connect(DB_NAME)
258+
def print_report(args: Arguments):
259+
conn = sqlite3.connect(args.db_path)
260260
c = conn.cursor()
261261

262262
# Fetch required fields, including total_cost and total_tokens
@@ -318,36 +318,36 @@ def print_report():
318318

319319
print("\nNote: For more detailed error messages, use the CSV report argument.")
320320

321-
def export_report_to_csv(output_path):
322-
conn = sqlite3.connect(DB_NAME)
321+
def export_report_to_csv(args: Arguments):
322+
conn = sqlite3.connect(args.db_path)
323323
c = conn.cursor()
324324

325325
c.execute(
326326
"""
327-
SELECT file_name, execution_status, time_taken, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, error_message
327+
SELECT file_name, execution_status, result, time_taken, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, error_message
328328
FROM file_status
329329
"""
330330
)
331331
report_data = c.fetchall()
332332
conn.close()
333333

334334
if not report_data:
335-
print("No data available to export.")
335+
print("No data available to export as CSV.")
336336
return
337337

338338
# Define the headers
339339
headers = [
340-
"File Name", "Execution Status", "Time Elapsed (seconds)",
340+
"File Name", "Execution Status", "Result", "Time Elapsed (seconds)",
341341
"Total Embedding Cost", "Total Embedding Tokens",
342342
"Total LLM Cost", "Total LLM Tokens", "Error Message"
343343
]
344344

345345
try:
346-
with open(output_path, 'w', newline='') as csvfile:
346+
with open(args.csv_report, 'w', newline='') as csvfile:
347347
writer = csv.writer(csvfile)
348348
writer.writerow(headers) # Write headers
349349
writer.writerows(report_data) # Write data rows
350-
print(f"CSV successfully exported to {output_path}")
350+
print(f"CSV successfully exported to '{args.csv_report}'")
351351
except Exception as e:
352352
print(f"Error exporting to CSV: {e}")
353353

@@ -357,7 +357,7 @@ def get_status_endpoint(file_path, client, args: Arguments):
357357
status_endpoint = None
358358

359359
# If retry_pending is True, check if the status API endpoint is available
360-
conn = sqlite3.connect(DB_NAME)
360+
conn = sqlite3.connect(args.db_path)
361361
c = conn.cursor()
362362
c.execute(
363363
"SELECT status_api_endpoint FROM file_status WHERE file_name = ? AND execution_status NOT IN ('COMPLETED', 'ERROR')",
@@ -382,7 +382,7 @@ def get_status_endpoint(file_path, client, args: Arguments):
382382

383383
# Fresh API call to process the file
384384
execution_status = "STARTING"
385-
update_db(file_path, execution_status, None, None, None, None)
385+
update_db(file_path, execution_status, None, None, None, None, args=args)
386386
response = client.structure_file(file_paths=[file_path])
387387
logger.debug(f"[{file_path}] Response of initial API call: {response}")
388388
status_endpoint = response.get(
@@ -397,6 +397,7 @@ def get_status_endpoint(file_path, client, args: Arguments):
397397
None,
398398
status_code,
399399
status_endpoint,
400+
args=args
400401
)
401402
return status_endpoint, execution_status, response
402403

@@ -436,7 +437,7 @@ def process_file(
436437
execution_status = response.get("execution_status")
437438
status_code = response.get("status_code") # Default to 200 if not provided
438439
update_db(
439-
file_path, execution_status, None, None, status_code, status_endpoint
440+
file_path, execution_status, None, None, status_code, status_endpoint, args=args
440441
)
441442

442443
result = response
@@ -456,7 +457,7 @@ def process_file(
456457
end_time = time.time()
457458
time_taken = round(end_time - start_time, 2)
458459
update_db(
459-
file_path, execution_status, result, time_taken, status_code, status_endpoint
460+
file_path, execution_status, result, time_taken, status_code, status_endpoint, args=args
460461
)
461462
logger.info(f"[{file_path}]: Processing completed: {execution_status}")
462463

@@ -501,14 +502,14 @@ def load_folder(args: Arguments):
501502

502503

503504
def main():
504-
parser = argparse.ArgumentParser(description="Process files using the API.")
505+
parser = argparse.ArgumentParser(description="Process files using Unstract's API deployment")
505506
parser.add_argument(
506507
"-e",
507508
"--api_endpoint",
508509
dest="api_endpoint",
509510
type=str,
510511
required=True,
511-
help="API Endpoint to use for processing the files.",
512+
help="API Endpoint to use for processing the files",
512513
)
513514
parser.add_argument(
514515
"-k",
@@ -524,55 +525,68 @@ def main():
524525
dest="api_timeout",
525526
type=int,
526527
default=10,
527-
help="Time in seconds to wait before switching to async mode.",
528+
help="Time in seconds to wait before switching to async mode (default: 10)",
528529
)
529530
parser.add_argument(
530531
"-i",
531532
"--poll_interval",
532533
dest="poll_interval",
533534
type=int,
534535
default=5,
535-
help="Time in seconds the process will sleep between polls in async mode.",
536+
help="Time in seconds the process will sleep between polls in async mode (default: 5)",
536537
)
537538
parser.add_argument(
538539
"-f",
539540
"--input_folder_path",
540541
dest="input_folder_path",
541542
type=str,
542543
required=True,
543-
help="Path where the files to process are present.",
544+
help="Path where the files to process are present",
544545
)
545546
parser.add_argument(
546547
"-p",
547548
"--parallel_call_count",
548549
dest="parallel_call_count",
549550
type=int,
550551
default=5,
551-
help="Number of calls to be made in parallel.",
552+
help="Number of calls to be made in parallel (default: 5)",
553+
)
554+
parser.add_argument(
555+
"--db_path",
556+
dest="db_path",
557+
type=str,
558+
default="file_processing.db",
559+
help="Path where the SQLite DB file is stored (default: './file_processing.db')",
560+
)
561+
parser.add_argument(
562+
'--csv_report',
563+
dest="csv_report",
564+
type=str,
565+
help='Path to export the detailed report as a CSV file',
552566
)
553567
parser.add_argument(
554568
"--retry_failed",
555569
dest="retry_failed",
556570
action="store_true",
557-
help="Retry processing of failed files.",
571+
help="Retry processing of failed files (default: False)",
558572
)
559573
parser.add_argument(
560574
"--retry_pending",
561575
dest="retry_pending",
562576
action="store_true",
563-
help="Retry processing of pending files as new request (Without this it will try to fetch the results using status API).",
577+
help="Retry processing of pending files as new request (Without this it will try to fetch the results using status API) (default: False)",
564578
)
565579
parser.add_argument(
566580
"--skip_pending",
567581
dest="skip_pending",
568582
action="store_true",
569-
help="Skip processing of pending files (Over rides --retry-pending).",
583+
help="Skip processing of pending files (overrides --retry_pending) (default: False)",
570584
)
571585
parser.add_argument(
572586
"--skip_unprocessed",
573587
dest="skip_unprocessed",
574588
action="store_true",
575-
help="Skip unprocessed files while retry processing of failed files.",
589+
help="Skip unprocessed files while retry processing of failed files (default: False)",
576590
)
577591
parser.add_argument(
578592
"--log_level",
@@ -586,52 +600,47 @@ def main():
586600
"--print_report",
587601
dest="print_report",
588602
action="store_true",
589-
help="Print a detailed report of all file processed.",
603+
help="Print a detailed report of all files processed (default: False)",
590604
)
591-
592605
parser.add_argument(
593606
"--exclude_metadata",
594607
dest="include_metadata",
595608
action="store_false",
596-
help="Exclude metadata on tokens consumed and the context passed to LLMs for prompt studio exported tools in the result for each file.",
609+
help="Exclude metadata on tokens consumed and the context passed to LLMs for prompt studio exported tools in the result for each file (default: False)",
597610
)
598-
599611
parser.add_argument(
600612
"--no_verify",
601613
dest="verify",
602614
action="store_false",
603-
help="Disable SSL certificate verification.",
604-
)
605-
606-
parser.add_argument(
607-
'--csv_report',
608-
dest="csv_report",
609-
type=str,
610-
help='Path to export the detailed report as a CSV file',
615+
help="Disable SSL certificate verification (default: False)",
611616
)
612617

613618
args = Arguments(**vars(parser.parse_args()))
614619

615620
ch = logging.StreamHandler(sys.stdout)
616621
ch.setLevel(args.log_level)
622+
formatter = logging.Formatter(
623+
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
624+
)
625+
ch.setFormatter(formatter)
617626
logging.basicConfig(level=args.log_level, handlers=[ch])
618627

619628
logger.warning(f"Running with params: {args}")
620629

621-
init_db() # Initialize DB
630+
init_db(args=args) # Initialize DB
622631

623632
load_folder(args=args)
624633

625-
print_summary() # Print summary at the end
634+
print_summary(args=args) # Print summary at the end
626635
if args.print_report:
627-
print_report()
636+
print_report(args=args)
628637
logger.warning(
629638
"Elapsed time calculation of a file which was resumed"
630639
" from pending state will not be correct"
631640
)
632641

633642
if args.csv_report:
634-
export_report_to_csv(args.csv_report)
643+
export_report_to_csv(args=args)
635644

636645

637646
if __name__ == "__main__":

0 commit comments

Comments
 (0)