77import pymysql as client
88import logging
99from getpass import getpass
10- from pymysql import err
1110
1211from .settings import config
13- from .errors import DataJointError , server_error_codes , is_connection_error
12+ from . import errors
1413from .dependencies import Dependencies
1514
15+ # client errors to catch
16+ client_errors = (client .err .InterfaceError , client .err .DatabaseError )
17+
18+
19+ def translate_query_error (client_error , query ):
20+ """
21+ Take client error and original query and return the corresponding DataJoint exception.
22+ :param client_error: the exception raised by the client interface
23+ :param query: sql query with placeholders
24+ :return: an instance of the corresponding subclass of datajoint.errors.DataJointError
25+ """
26+ # Loss of connection errors
27+ if isinstance (client_error , client .err .InterfaceError ) and client_error .args [0 ] == "(0, '')" :
28+ return errors .LostConnectionError ('Server connection lost due to an interface error.' , * client_error .args [1 :])
29+ disconnect_codes = {
30+ 2006 : "Connection timed out" ,
31+ 2013 : "Server connection lost" }
32+ if isinstance (client_error , client .err .OperationalError ) and client_error .args [0 ] in disconnect_codes :
33+ return errors .LostConnectionError (disconnect_codes [client_error .args [0 ]], * client_error .args [1 :])
34+ # Access errors
35+ if isinstance (client_error , client .err .OperationalError ) and client_error .args [0 ] in (1044 , 1142 ):
36+ return errors .AccessError ('Insufficient privileges.' , client_error .args [1 ], query )
37+ # Integrity errors
38+ if isinstance (client_error , client .err .IntegrityError ) and client_error .args [0 ] == 1062 :
39+ return errors .DuplicateError (* client_error .args [1 :])
40+ if isinstance (client_error , client .err .IntegrityError ) and client_error .args [0 ] == 1452 :
41+ return errors .IntegrityError (* client_error .args [1 :])
42+ # Syntax Errors
43+ if isinstance (client_error , client .err .ProgrammingError ) and client_error .args [0 ] == 1064 :
44+ return errors .QuerySyntaxError (client_error .args [1 ], query )
45+ # Existence Errors
46+ if isinstance (client_error , client .err .ProgrammingError ) and client_error .args [0 ] == 1146 :
47+ return errors .MissingTableError (client_error .args [1 ], query )
48+ if isinstance (client_error , client .err .InternalError ) and client_error .args [0 ] == 1364 :
49+ return errors .MissingAttributeError (* client_error .args [1 :])
50+ raise client_error
51+
1652
1753logger = logging .getLogger (__name__ )
1854
@@ -60,6 +96,7 @@ class Connection:
6096 :param init_fun: connection initialization function (SQL)
6197 :param use_tls: TLS encryption option
6298 """
99+
63100 def __init__ (self , host , user , password , port = None , init_fun = None , use_tls = None ):
64101 if ':' in host :
65102 # the port in the hostname overrides the port argument
@@ -79,7 +116,7 @@ def __init__(self, host, user, password, port=None, init_fun=None, use_tls=None)
79116 logger .info ("Connected {user}@{host}:{port}" .format (** self .conn_info ))
80117 self .connection_id = self .query ('SELECT connection_id()' ).fetchone ()[0 ]
81118 else :
82- raise DataJointError ('Connection failed.' )
119+ raise errors . ConnectionError ('Connection failed.' )
83120 self ._in_transaction = False
84121 self .schemas = dict ()
85122 self .dependencies = Dependencies (self )
@@ -103,16 +140,16 @@ def connect(self):
103140 self ._conn = client .connect (
104141 init_command = self .init_fun ,
105142 sql_mode = "NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO,"
106- "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION" ,
143+ "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION" ,
107144 charset = config ['connection.charset' ],
108145 ** self .conn_info )
109- except err .InternalError :
146+ except client . err .InternalError :
110147 if ssl_input is None :
111148 self .conn_info .pop ('ssl' )
112149 self ._conn = client .connect (
113150 init_command = self .init_fun ,
114151 sql_mode = "NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO,"
115- "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION" ,
152+ "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION" ,
116153 charset = config ['connection.charset' ],
117154 ** self .conn_info )
118155 self .conn_info ['ssl_input' ] = ssl_input
@@ -141,50 +178,46 @@ def is_connected(self):
141178 return False
142179 return True
143180
144- def query (self , query , args = (), as_dict = False , suppress_warnings = True , reconnect = None ):
181+ @staticmethod
182+ def __execute_query (cursor , query , args , cursor_class , suppress_warnings ):
183+ try :
184+ with warnings .catch_warnings ():
185+ if suppress_warnings :
186+ # suppress all warnings arising from underlying SQL library
187+ warnings .simplefilter ("ignore" )
188+ cursor .execute (query , args )
189+ except client_errors as err :
190+ raise translate_query_error (err , query )
191+
192+ def query (self , query , args = (), * , as_dict = False , suppress_warnings = True , reconnect = None ):
145193 """
146194 Execute the specified query and return the tuple generator (cursor).
147-
148- :param query: mysql query
195+ :param query: SQL query
149196 :param args: additional arguments for the client.cursor
150197 :param as_dict: If as_dict is set to True, the returned cursor objects returns
151198 query results as dictionary.
152199 :param suppress_warnings: If True, suppress all warnings arising from underlying query library
200+ :param reconnect: when None, get from config, when True, attempt to reconnect if disconnected
153201 """
154202 if reconnect is None :
155203 reconnect = config ['database.reconnect' ]
156-
157- cursor = client .cursors .DictCursor if as_dict else client .cursors .Cursor
158- cur = self ._conn .cursor (cursor = cursor )
159-
160204 logger .debug ("Executing SQL:" + query [0 :300 ])
205+ cursor_class = client .cursors .DictCursor if as_dict else client .cursors .Cursor
206+ cursor = self ._conn .cursor (cursor = cursor_class )
161207 try :
162- with warnings .catch_warnings ():
163- if suppress_warnings :
164- # suppress all warnings arising from underlying SQL library
165- warnings .simplefilter ("ignore" )
166- cur .execute (query , args )
167- except (err .InterfaceError , err .OperationalError ) as e :
168- if is_connection_error (e ) and reconnect :
169- warnings .warn ("Mysql server has gone away. Reconnecting to the server." )
170- self .connect ()
171- if self ._in_transaction :
172- self .cancel_transaction ()
173- raise DataJointError ("Connection was lost during a transaction." ) from None
174- else :
175- logger .debug ("Re-executing SQL" )
176- cur = self .query (query , args = args , as_dict = as_dict , suppress_warnings = suppress_warnings , reconnect = False )
177- else :
178- logger .debug ("Caught InterfaceError/OperationalError." )
208+ self .__execute_query (cursor , query , args , cursor_class , suppress_warnings )
209+ except errors .LostConnectionError :
210+ if not reconnect :
179211 raise
180- except err .ProgrammingError as e :
181- if e .args [0 ] == server_error_codes ['parse error' ]:
182- raise DataJointError ("\n " .join ((
183- "Error in query:" , query ,
184- "Please check spelling, syntax, and existence of tables and attributes." ,
185- "When restricting a relation by a condition in a string, enclose attributes in backquotes."
186- ))) from None
187- return cur
212+ warnings .warn ("MySQL server has gone away. Reconnecting to the server." )
213+ self .connect ()
214+ if self ._in_transaction :
215+ self .cancel_transaction ()
216+ raise errors .LostConnectionError ("Connection was lost during a transaction." ) from None
217+ logger .debug ("Re-executing" )
218+ cursor = self ._conn .cursor (cursor = cursor_class )
219+ self .__execute_query (cursor , query , args , cursor_class , suppress_warnings )
220+ return cursor
188221
189222 def get_user (self ):
190223 """
@@ -204,11 +237,9 @@ def in_transaction(self):
204237 def start_transaction (self ):
205238 """
206239 Starts a transaction error.
207-
208- :raise DataJointError: if there is an ongoing transaction.
209240 """
210241 if self .in_transaction :
211- raise DataJointError ("Nested connections are not supported." )
242+ raise errors . DataJointError ("Nested connections are not supported." )
212243 self .query ('START TRANSACTION WITH CONSISTENT SNAPSHOT' )
213244 self ._in_transaction = True
214245 logger .info ("Transaction started" )
@@ -252,3 +283,4 @@ def transaction(self):
252283 raise
253284 else :
254285 self .commit_transaction ()
286+
0 commit comments