11"""SQLite-specific utility functions for use in `graphnet.data`."""
22
3+ import os .path
4+ from typing import List
5+
36import pandas as pd
47import sqlalchemy
58import sqlite3
69
710
8- def run_sql_code (database : str , code : str ) -> None :
11+ def database_exists (database_path : str ) -> bool :
12+ """Check whether database exists at `database_path`."""
13+ assert database_path .endswith (
14+ ".db"
15+ ), "Provided database path does not end in `.db`."
16+ return os .path .exists (database_path )
17+
18+
19+ def database_table_exists (database_path : str , table_name : str ) -> bool :
20+ """Check whether `table_name` exists in database at `database_path`."""
21+ if not database_exists (database_path ):
22+ return False
23+ query = f"SELECT name FROM sqlite_master WHERE type='table' AND name='{ table_name } ';"
24+ with sqlite3 .connect (database_path ) as conn :
25+ result = pd .read_sql (query , conn )
26+ return len (result ) == 1
27+
28+
29+ def run_sql_code (database_path : str , code : str ) -> None :
930 """Execute SQLite code.
1031
1132 Args:
12- database : Path to databases
33+ database_path : Path to databases
1334 code: SQLite code
1435 """
15- conn = sqlite3 .connect (database )
36+ conn = sqlite3 .connect (database_path )
1637 c = conn .cursor ()
1738 c .executescript (code )
1839 c .close ()
1940
2041
21- def save_to_sql (df : pd .DataFrame , table_name : str , database : str ) -> None :
42+ def save_to_sql (df : pd .DataFrame , table_name : str , database_path : str ) -> None :
2243 """Save a dataframe `df` to a table `table_name` in SQLite `database`.
2344
2445 Table must exist already.
2546
2647 Args:
2748 df: Dataframe with data to be stored in sqlite table
2849 table_name: Name of table. Must exist already
29- database : Path to SQLite database
50+ database_path : Path to SQLite database
3051 """
31- engine = sqlalchemy .create_engine ("sqlite:///" + database )
52+ engine = sqlalchemy .create_engine ("sqlite:///" + database_path )
3253 df .to_sql (table_name , con = engine , index = False , if_exists = "append" )
3354 engine .dispose ()
3455
3556
36- def attach_index (database : str , table_name : str ) -> None :
37- """Attaches the table index.
57+ def attach_index (
58+ database_path : str , table_name : str , index_column : str = "event_no"
59+ ) -> None :
60+ """Attach the table (i.e., event) index.
3861
3962 Important for query times!
4063 """
4164 code = (
4265 "PRAGMA foreign_keys=off;\n "
4366 "BEGIN TRANSACTION;\n "
44- f"CREATE INDEX event_no_{ table_name } ON { table_name } (event_no);\n "
67+ f"CREATE INDEX { index_column } _{ table_name } "
68+ f"ON { table_name } ({ index_column } );\n "
4569 "COMMIT TRANSACTION;\n "
4670 "PRAGMA foreign_keys=on;"
4771 )
48- run_sql_code (database , code )
72+ run_sql_code (database_path , code )
4973
5074
5175def create_table (
52- df : pd . DataFrame ,
76+ columns : List [ str ] ,
5377 table_name : str ,
5478 database_path : str ,
55- is_pulse_map : bool = False ,
79+ * ,
80+ index_column : str = "event_no" ,
81+ default_type : str = "NOT NULL" ,
82+ integer_primary_key : bool = True ,
5683) -> None :
5784 """Create a table.
5885
5986 Args:
60- df: Data to be saved to table
87+ columns: Column names to be created in table.
6188 table_name: Name of the table.
6289 database_path: Path to the database.
63- is_pulse_map: Whether or not this is a pulse map table.
90+ index_column: Name of the index column.
91+ default_type: The type used for all non-index columns.
92+ integer_primary_key: Whether or not to create the `index_column` with
93+ the `INTEGER PRIMARY KEY` type. Such a column is required to have
94+ unique, integer values for each row. This is appropriate when the
95+ table has one row per event, e.g., event-level MC truth. It is not
96+ appropriate for pulse map series, particle-level MC truth, and
97+ other such data that is expected to have more that one row per
98+ event (i.e., with the same index).
6499 """
65- query_columns = list ()
66- for column in df .columns :
67- if column == "event_no" :
68- if not is_pulse_map :
100+ # Prepare column names and types
101+ query_columns = []
102+ for column in columns :
103+ type_ = default_type
104+ if column == index_column :
105+ if integer_primary_key :
69106 type_ = "INTEGER PRIMARY KEY NOT NULL"
70107 else :
71108 type_ = "NOT NULL"
72- else :
73- type_ = "NOT NULL"
109+
74110 query_columns .append (f"{ column } { type_ } " )
111+
75112 query_columns_string = ", " .join (query_columns )
76113
114+ # Run SQL code
77115 code = (
78116 "PRAGMA foreign_keys=off;\n "
79117 f"CREATE TABLE { table_name } ({ query_columns_string } );\n "
@@ -83,3 +121,29 @@ def create_table(
83121 database_path ,
84122 code ,
85123 )
124+
125+ # Attaching index to all non-truth-like tables (e.g., pulse maps).
126+ if not integer_primary_key :
127+ attach_index (database_path , table_name )
128+
129+
130+ def create_table_and_save_to_sql (
131+ df : pd .DataFrame ,
132+ table_name : str ,
133+ database_path : str ,
134+ * ,
135+ index_column : str = "event_no" ,
136+ default_type : str = "NOT NULL" ,
137+ integer_primary_key : bool = True ,
138+ ) -> None :
139+ """Create table if it doesn't exist and save dataframe to it."""
140+ if not database_table_exists (database_path , table_name ):
141+ create_table (
142+ df .columns ,
143+ table_name ,
144+ database_path ,
145+ index_column = index_column ,
146+ default_type = default_type ,
147+ integer_primary_key = integer_primary_key ,
148+ )
149+ save_to_sql (df , table_name = table_name , database_path = database_path )
0 commit comments