Skip to content

Commit f64a29a

Browse files
authored
Merge pull request #366 from asogaard/sqlite-utility-methods
SQLite utility methods
2 parents f12f7ee + 6aef937 commit f64a29a

9 files changed

Lines changed: 249 additions & 223 deletions

File tree

src/graphnet/data/pipeline.py

Lines changed: 7 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import torch
1414
from torch.utils.data import DataLoader
1515

16-
from graphnet.data.sqlite.sqlite_utilities import run_sql_code, save_to_sql
16+
from graphnet.data.sqlite.sqlite_utilities import create_table_and_save_to_sql
1717
from graphnet.training.utils import get_predictions, make_dataloader
1818

1919
from graphnet.utilities.logging import get_logger
@@ -97,7 +97,7 @@ def __call__(
9797
df = self._inference(device, dataloader)
9898
truth = self._get_truth(database, event_batches[i].tolist())
9999
retro = self._get_retro(database, event_batches[i].tolist())
100-
self._append_to_pipeline(outdir, truth, retro, df, i)
100+
self._append_to_pipeline(outdir, truth, retro, df)
101101
i += 1
102102
else:
103103
logger.info(outdir)
@@ -210,44 +210,12 @@ def _append_to_pipeline(
210210
truth: pd.DataFrame,
211211
retro: pd.DataFrame,
212212
df: pd.DataFrame,
213-
i: int,
214213
) -> None:
215214
os.makedirs(outdir, exist_ok=True)
216215
pipeline_database = outdir + "/%s.db" % self._pipeline_name
217-
if i == 0:
218-
# Only setup table schemes if its the first time appending
219-
self._create_table(pipeline_database, "reconstruction", df)
220-
self._create_table(pipeline_database, "truth", truth)
221-
save_to_sql(df, "reconstruction", pipeline_database)
222-
save_to_sql(truth, "truth", pipeline_database)
216+
create_table_and_save_to_sql(df, "reconstruction", pipeline_database)
217+
create_table_and_save_to_sql(truth, "truth", pipeline_database)
223218
if isinstance(retro, pd.DataFrame):
224-
if i == 0:
225-
self._create_table(pipeline_database, "retro", retro)
226-
save_to_sql(retro, self._retro_table_name, pipeline_database)
227-
228-
# @FIXME: Duplicate.
229-
def _create_table(
230-
self, pipeline_database: str, table_name: str, df: pd.DataFrame
231-
) -> None:
232-
"""Create a table.
233-
234-
Args:
235-
pipeline_database: Path to the pipeline database.
236-
table_name: Name of the table in pipeline database.
237-
df: DataFrame of combined predictions.
238-
"""
239-
query_columns_list = list()
240-
for column in df.columns:
241-
if column == "event_no":
242-
type_ = "INTEGER PRIMARY KEY NOT NULL"
243-
else:
244-
type_ = "FLOAT"
245-
query_columns_list.append(f"{column} {type_}")
246-
query_columns = ", ".join(query_columns_list)
247-
248-
code = (
249-
"PRAGMA foreign_keys=off;\n"
250-
f"CREATE TABLE {table_name} ({query_columns});\n"
251-
"PRAGMA foreign_keys=on;"
252-
)
253-
run_sql_code(pipeline_database, code)
219+
create_table_and_save_to_sql(
220+
retro, self._retro_table_name, pipeline_database
221+
)

src/graphnet/data/sqlite/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from graphnet.utilities.imports import has_torch_package
44

55
from .sqlite_dataconverter import SQLiteDataConverter
6-
from .sqlite_utilities import run_sql_code, save_to_sql, create_table
6+
from .sqlite_utilities import create_table_and_save_to_sql
77

88
if has_torch_package():
99
from .sqlite_dataset import SQLiteDataset

src/graphnet/data/sqlite/sqlite_dataconverter.py

Lines changed: 28 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
from tqdm import tqdm
1111

1212
from graphnet.data.dataconverter import DataConverter # type: ignore[attr-defined]
13-
from graphnet.data.sqlite.sqlite_utilities import run_sql_code, save_to_sql
13+
from graphnet.data.sqlite.sqlite_utilities import (
14+
create_table,
15+
create_table_and_save_to_sql,
16+
)
1417

1518

1619
class SQLiteDataConverter(DataConverter):
@@ -51,7 +54,15 @@ def save_data(self, data: List[OrderedDict], output_file: str) -> None:
5154
saved_any = False
5255
for table, df in dataframe.items():
5356
if len(df) > 0:
54-
save_to_sql(df, table, output_file)
57+
create_table_and_save_to_sql(
58+
df,
59+
table,
60+
output_file,
61+
default_type="FLOAT",
62+
integer_primary_key=not (
63+
is_pulse_map(table) or is_mc_tree(table)
64+
),
65+
)
5566
saved_any = True
5667

5768
if saved_any:
@@ -92,12 +103,14 @@ def merge_files(
92103
input_files, table_name
93104
)
94105
if len(column_names) > 1:
95-
is_pulse_map = is_pulsemap_check(table_name)
96-
self._create_table(
97-
output_file,
98-
table_name,
106+
create_table(
99107
column_names,
100-
is_pulse_map=is_pulse_map,
108+
table_name,
109+
output_file,
110+
default_type="FLOAT",
111+
integer_primary_key=not (
112+
is_pulse_map(table_name) or is_mc_tree(table_name)
113+
),
101114
)
102115

103116
# Merge temporary databases into newly created one
@@ -157,60 +170,6 @@ def any_pulsemap_is_non_empty(self, data_dict: Dict[str, Dict]) -> bool:
157170
pulsemap_dicts = [data_dict[pulsemap] for pulsemap in self._pulsemaps]
158171
return any(d["dom_x"] for d in pulsemap_dicts)
159172

160-
def _attach_index(self, database: str, table_name: str) -> None:
161-
"""Attach the table index.
162-
163-
Important for query times!
164-
"""
165-
code = (
166-
"PRAGMA foreign_keys=off;\n"
167-
"BEGIN TRANSACTION;\n"
168-
f"CREATE INDEX event_no_{table_name} ON {table_name} (event_no);\n"
169-
"COMMIT TRANSACTION;\n"
170-
"PRAGMA foreign_keys=on;"
171-
)
172-
run_sql_code(database, code)
173-
174-
def _create_table(
175-
self,
176-
database: str,
177-
table_name: str,
178-
columns: List[str],
179-
is_pulse_map: bool = False,
180-
) -> None:
181-
"""Create a table.
182-
183-
Args:
184-
database: Path to the database.
185-
table_name: Name of the table.
186-
columns: The names of the columns of the table.
187-
is_pulse_map: Whether or not this is a pulse map table.
188-
"""
189-
query_columns = list()
190-
for column in columns:
191-
if column == "event_no":
192-
if not is_pulse_map:
193-
type_ = "INTEGER PRIMARY KEY NOT NULL"
194-
else:
195-
type_ = "NOT NULL"
196-
else:
197-
type_ = "FLOAT"
198-
query_columns.append(f"{column} {type_}")
199-
query_columns_string = ", ".join(query_columns)
200-
201-
code = (
202-
"PRAGMA foreign_keys=off;\n"
203-
f"CREATE TABLE {table_name} ({query_columns_string});\n"
204-
"PRAGMA foreign_keys=on;"
205-
)
206-
run_sql_code(database, code)
207-
208-
if is_pulse_map:
209-
self.debug(table_name)
210-
self.debug("Attaching indices")
211-
self._attach_index(database, table_name)
212-
return
213-
214173
def _submit_to_database(
215174
self, database: str, key: str, data: pd.DataFrame
216175
) -> None:
@@ -280,9 +239,11 @@ def construct_dataframe(extraction: Dict[str, Any]) -> pd.DataFrame:
280239
return out
281240

282241

283-
def is_pulsemap_check(table_name: str) -> bool:
284-
"""Check whether `table_name` corresponds to a pulsemap."""
285-
if "pulse" in table_name.lower():
286-
return True
287-
else:
288-
return False
242+
def is_pulse_map(table_name: str) -> bool:
243+
"""Check whether `table_name` corresponds to a pulse map."""
244+
return "pulse" in table_name.lower() or "series" in table_name.lower()
245+
246+
247+
def is_mc_tree(table_name: str) -> bool:
248+
"""Check whether `table_name` corresponds to an MC tree."""
249+
return "I3MCTree" in table_name
Lines changed: 84 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,117 @@
11
"""SQLite-specific utility functions for use in `graphnet.data`."""
22

3+
import os.path
4+
from typing import List
5+
36
import pandas as pd
47
import sqlalchemy
58
import sqlite3
69

710

8-
def run_sql_code(database: str, code: str) -> None:
11+
def database_exists(database_path: str) -> bool:
12+
"""Check whether database exists at `database_path`."""
13+
assert database_path.endswith(
14+
".db"
15+
), "Provided database path does not end in `.db`."
16+
return os.path.exists(database_path)
17+
18+
19+
def database_table_exists(database_path: str, table_name: str) -> bool:
20+
"""Check whether `table_name` exists in database at `database_path`."""
21+
if not database_exists(database_path):
22+
return False
23+
query = f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}';"
24+
with sqlite3.connect(database_path) as conn:
25+
result = pd.read_sql(query, conn)
26+
return len(result) == 1
27+
28+
29+
def run_sql_code(database_path: str, code: str) -> None:
930
"""Execute SQLite code.
1031
1132
Args:
12-
database: Path to databases
33+
database_path: Path to databases
1334
code: SQLite code
1435
"""
15-
conn = sqlite3.connect(database)
36+
conn = sqlite3.connect(database_path)
1637
c = conn.cursor()
1738
c.executescript(code)
1839
c.close()
1940

2041

21-
def save_to_sql(df: pd.DataFrame, table_name: str, database: str) -> None:
42+
def save_to_sql(df: pd.DataFrame, table_name: str, database_path: str) -> None:
2243
"""Save a dataframe `df` to a table `table_name` in SQLite `database`.
2344
2445
Table must exist already.
2546
2647
Args:
2748
df: Dataframe with data to be stored in sqlite table
2849
table_name: Name of table. Must exist already
29-
database: Path to SQLite database
50+
database_path: Path to SQLite database
3051
"""
31-
engine = sqlalchemy.create_engine("sqlite:///" + database)
52+
engine = sqlalchemy.create_engine("sqlite:///" + database_path)
3253
df.to_sql(table_name, con=engine, index=False, if_exists="append")
3354
engine.dispose()
3455

3556

36-
def attach_index(database: str, table_name: str) -> None:
37-
"""Attaches the table index.
57+
def attach_index(
58+
database_path: str, table_name: str, index_column: str = "event_no"
59+
) -> None:
60+
"""Attach the table (i.e., event) index.
3861
3962
Important for query times!
4063
"""
4164
code = (
4265
"PRAGMA foreign_keys=off;\n"
4366
"BEGIN TRANSACTION;\n"
44-
f"CREATE INDEX event_no_{table_name} ON {table_name} (event_no);\n"
67+
f"CREATE INDEX {index_column}_{table_name} "
68+
f"ON {table_name} ({index_column});\n"
4569
"COMMIT TRANSACTION;\n"
4670
"PRAGMA foreign_keys=on;"
4771
)
48-
run_sql_code(database, code)
72+
run_sql_code(database_path, code)
4973

5074

5175
def create_table(
52-
df: pd.DataFrame,
76+
columns: List[str],
5377
table_name: str,
5478
database_path: str,
55-
is_pulse_map: bool = False,
79+
*,
80+
index_column: str = "event_no",
81+
default_type: str = "NOT NULL",
82+
integer_primary_key: bool = True,
5683
) -> None:
5784
"""Create a table.
5885
5986
Args:
60-
df: Data to be saved to table
87+
columns: Column names to be created in table.
6188
table_name: Name of the table.
6289
database_path: Path to the database.
63-
is_pulse_map: Whether or not this is a pulse map table.
90+
index_column: Name of the index column.
91+
default_type: The type used for all non-index columns.
92+
integer_primary_key: Whether or not to create the `index_column` with
93+
the `INTEGER PRIMARY KEY` type. Such a column is required to have
94+
unique, integer values for each row. This is appropriate when the
95+
table has one row per event, e.g., event-level MC truth. It is not
96+
appropriate for pulse map series, particle-level MC truth, and
97+
other such data that is expected to have more that one row per
98+
event (i.e., with the same index).
6499
"""
65-
query_columns = list()
66-
for column in df.columns:
67-
if column == "event_no":
68-
if not is_pulse_map:
100+
# Prepare column names and types
101+
query_columns = []
102+
for column in columns:
103+
type_ = default_type
104+
if column == index_column:
105+
if integer_primary_key:
69106
type_ = "INTEGER PRIMARY KEY NOT NULL"
70107
else:
71108
type_ = "NOT NULL"
72-
else:
73-
type_ = "NOT NULL"
109+
74110
query_columns.append(f"{column} {type_}")
111+
75112
query_columns_string = ", ".join(query_columns)
76113

114+
# Run SQL code
77115
code = (
78116
"PRAGMA foreign_keys=off;\n"
79117
f"CREATE TABLE {table_name} ({query_columns_string});\n"
@@ -83,3 +121,29 @@ def create_table(
83121
database_path,
84122
code,
85123
)
124+
125+
# Attaching index to all non-truth-like tables (e.g., pulse maps).
126+
if not integer_primary_key:
127+
attach_index(database_path, table_name)
128+
129+
130+
def create_table_and_save_to_sql(
131+
df: pd.DataFrame,
132+
table_name: str,
133+
database_path: str,
134+
*,
135+
index_column: str = "event_no",
136+
default_type: str = "NOT NULL",
137+
integer_primary_key: bool = True,
138+
) -> None:
139+
"""Create table if it doesn't exist and save dataframe to it."""
140+
if not database_table_exists(database_path, table_name):
141+
create_table(
142+
df.columns,
143+
table_name,
144+
database_path,
145+
index_column=index_column,
146+
default_type=default_type,
147+
integer_primary_key=integer_primary_key,
148+
)
149+
save_to_sql(df, table_name=table_name, database_path=database_path)

0 commit comments

Comments
 (0)