11import sqlite3
2- from policyengine_api .constants import REPO , VERSION , COUNTRY_PACKAGE_VERSIONS
2+ from policyengine_api .constants import REPO , COUNTRY_PACKAGE_VERSIONS
33from policyengine_api .utils import hash_object
44from pathlib import Path
55from dotenv import load_dotenv
@@ -41,6 +41,29 @@ def fetchall(self):
4141 return remaining
4242
4343
44+ class _TransactionProxy :
45+ """Execute queries against an existing connection inside a transaction."""
46+
47+ def __init__ (self , connection , local : bool ):
48+ self ._connection = connection
49+ self ._local = local
50+
51+ def query (self , * query ):
52+ if self ._local :
53+ cursor = self ._connection .cursor ()
54+ return cursor .execute (* query )
55+
56+ query = list (query )
57+ main_query = query [0 ].replace ("?" , "%s" )
58+ query [0 ] = main_query
59+ params = query [1 ] if len (query ) > 1 else None
60+ if params is not None :
61+ result = self ._connection .exec_driver_sql (main_query , params )
62+ else :
63+ result = self ._connection .exec_driver_sql (main_query )
64+ return _ResultProxy (result )
65+
66+
4467class PolicyEngineDatabase :
4568 """
4669 A wrapper around the database connection.
@@ -50,6 +73,13 @@ class PolicyEngineDatabase:
5073
5174 household_cache : dict = {}
5275
76+ @staticmethod
77+ def _dict_factory (cursor , row ):
78+ d = {}
79+ for idx , col in enumerate (cursor .description ):
80+ d [col [0 ]] = row [idx ]
81+ return d
82+
5383 def __init__ (
5484 self ,
5585 local : bool = False ,
@@ -91,7 +121,7 @@ def _close_pool(self):
91121 try :
92122 self .pool .dispose ()
93123 self .connector .close ()
94- except :
124+ except Exception :
95125 pass
96126
97127 def _execute_remote (self , query_args ):
@@ -110,17 +140,22 @@ def _execute_remote(self, query_args):
110140 # connection context closing
111141 return _ResultProxy (result )
112142
143+ def _execute_remote_transaction (self , callback ):
144+ with self .pool .connect () as conn :
145+ transaction = conn .begin ()
146+ proxy = _TransactionProxy (conn , local = False )
147+ try :
148+ result = callback (proxy )
149+ transaction .commit ()
150+ return result
151+ except Exception :
152+ transaction .rollback ()
153+ raise
154+
113155 def query (self , * query ):
114156 if self .local :
115157 with sqlite3 .connect (self .db_url ) as conn :
116-
117- def dict_factory (cursor , row ):
118- d = {}
119- for idx , col in enumerate (cursor .description ):
120- d [col [0 ]] = row [idx ]
121- return d
122-
123- conn .row_factory = dict_factory
158+ conn .row_factory = self ._dict_factory
124159 cursor = conn .cursor ()
125160 return cursor .execute (* query )
126161 else :
@@ -134,14 +169,44 @@ def dict_factory(cursor, row):
134169 except (
135170 sqlalchemy .exc .InterfaceError ,
136171 sqlalchemy .exc .OperationalError ,
137- ) as e :
172+ ):
138173 try :
139174 self ._close_pool ()
140175 self ._create_pool ()
141176 return self ._execute_remote (query )
142177 except Exception as e :
143178 raise e
144179
180+ def transaction (self , callback ):
181+ if self .local :
182+ connection = getattr (self , "_connection" , None )
183+ owns_connection = connection is None
184+ if owns_connection :
185+ connection = sqlite3 .connect (self .db_url )
186+ connection .row_factory = self ._dict_factory
187+ try :
188+ connection .execute ("BEGIN IMMEDIATE" )
189+ proxy = _TransactionProxy (connection , local = True )
190+ result = callback (proxy )
191+ connection .commit ()
192+ return result
193+ except Exception :
194+ connection .rollback ()
195+ raise
196+ finally :
197+ if owns_connection :
198+ connection .close ()
199+
200+ try :
201+ return self ._execute_remote_transaction (callback )
202+ except (
203+ sqlalchemy .exc .InterfaceError ,
204+ sqlalchemy .exc .OperationalError ,
205+ ):
206+ self ._close_pool ()
207+ self ._create_pool ()
208+ return self ._execute_remote_transaction (callback )
209+
145210 def initialize (self ):
146211 """
147212 Create the database tables.
@@ -175,7 +240,7 @@ def initialize(self):
175240 range (1 , 1 + len (COUNTRY_PACKAGE_VERSIONS )),
176241 ):
177242 self .query (
178- f "INSERT INTO policy (id, country_id, label, api_version, policy_json, policy_hash) VALUES (?, ?, ?, ?, ?, ?)" ,
243+ "INSERT INTO policy (id, country_id, label, api_version, policy_json, policy_hash) VALUES (?, ?, ?, ?, ?, ?)" ,
179244 (
180245 policy_id ,
181246 country_id ,
0 commit comments