77import pymysql as client
88import logging
99from getpass import getpass
10- from pymysql import err
1110
1211from .settings import config
13- from .errors import DataJointError , server_error_codes , is_connection_error
12+ from . import errors
1413from .dependencies import Dependencies
1514
15+ # client errors to catch
16+ client_errors = (client .err .InterfaceError , client .err .DatabaseError )
17+
18+
19+ def translate_query_error (client_error , query , args ):
20+ """
21+ Take client error and original query and return the corresponding DataJoint exception.
22+ :param client_error: the exception raised by the client interface
23+ :param query: sql query with placeholders
24+ :param args: values for query placeholders
25+ :return: an instance of the corresponding subclass of datajoint.errors.DataJointError
26+ """
27+ # Loss of connection errors
28+ if isinstance (client_error , client .err .InterfaceError ) and client_error .args [0 ] == "(0, '')" :
29+ return errors .LostConnectionError ('Server connection lost due to an interface error.' , * client_error .args [1 :])
30+ disconnect_codes = {
31+ 2006 : "Connection timed out" ,
32+ 2013 : "Server connection lost" }
33+ if isinstance (client_error , client .err .OperationalError ) and client_error .args [0 ] in disconnect_codes :
34+ return errors .LostConnectionError (disconnect_codes [client_error .args [0 ]], * client_error .args [1 :])
35+ # Access errors
36+ if isinstance (client_error , client .err .OperationalError ) and client_error .args [0 ] in (1044 , 1142 ):
37+ return errors .AccessError ('Insufficient privileges.' , * client_error .args [1 :], query )
38+ # Integrity errors
39+ if isinstance (client_error , client .err .IntegrityError ) and client_error .args [0 ] == 1062 :
40+ return errors .DuplicateError (* client_error .args [1 :])
41+ if isinstance (client_error , client .err .IntegrityError ) and client_error .args [0 ] == 1452 :
42+ return errors .IntegrityError (* client_error .args [1 :])
43+ # Syntax Errors
44+ if isinstance (client_error , client .err .ProgrammingError ) and client_error .args [0 ] == 1064 :
45+ return errors .QuerySyntaxError (* client_error .args [1 :], query )
46+ # Existence Errors
47+ if isinstance (client_error , client .err .ProgrammingError ) and client_error .args [0 ] == 1146 :
48+ return errors .MissingTableError (* args [1 :], query )
49+ if isinstance (client_error , client .err .InternalError ) and client_error .args [0 ] == 1364 :
50+ return errors .MissingAttributeValueError (* args [1 :])
51+ raise client_error
52+
1653
1754logger = logging .getLogger (__name__ )
1855
@@ -60,6 +97,7 @@ class Connection:
6097 :param init_fun: connection initialization function (SQL)
6198 :param use_tls: TLS encryption option
6299 """
100+
63101 def __init__ (self , host , user , password , port = None , init_fun = None , use_tls = None ):
64102 if ':' in host :
65103 # the port in the hostname overrides the port argument
@@ -79,7 +117,7 @@ def __init__(self, host, user, password, port=None, init_fun=None, use_tls=None)
79117 logger .info ("Connected {user}@{host}:{port}" .format (** self .conn_info ))
80118 self .connection_id = self .query ('SELECT connection_id()' ).fetchone ()[0 ]
81119 else :
82- raise DataJointError ('Connection failed.' )
120+ raise errors . ConnectionError ('Connection failed.' )
83121 self ._in_transaction = False
84122 self .schemas = dict ()
85123 self .dependencies = Dependencies (self )
@@ -103,16 +141,16 @@ def connect(self):
103141 self ._conn = client .connect (
104142 init_command = self .init_fun ,
105143 sql_mode = "NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO,"
106- "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION" ,
144+ "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION" ,
107145 charset = config ['connection.charset' ],
108146 ** self .conn_info )
109- except err .InternalError :
147+ except client . err .InternalError :
110148 if ssl_input is None :
111149 self .conn_info .pop ('ssl' )
112150 self ._conn = client .connect (
113151 init_command = self .init_fun ,
114152 sql_mode = "NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO,"
115- "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION" ,
153+ "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION" ,
116154 charset = config ['connection.charset' ],
117155 ** self .conn_info )
118156 self .conn_info ['ssl_input' ] = ssl_input
@@ -141,50 +179,46 @@ def is_connected(self):
141179 return False
142180 return True
143181
144- def query (self , query , args = (), as_dict = False , suppress_warnings = True , reconnect = None ):
182+ @staticmethod
183+ def __execute_query (cursor , query , args , cursor_class , suppress_warnings ):
184+ try :
185+ with warnings .catch_warnings ():
186+ if suppress_warnings :
187+ # suppress all warnings arising from underlying SQL library
188+ warnings .simplefilter ("ignore" )
189+ cursor .execute (query , args )
190+ except client_errors as err :
191+ raise translate_query_error (err , query , args )
192+
193+ def query (self , query , args = (), * , as_dict = False , suppress_warnings = True , reconnect = None ):
145194 """
146195 Execute the specified query and return the tuple generator (cursor).
147-
148- :param query: mysql query
196+ :param query: SQL query
149197 :param args: additional arguments for the client.cursor
150198 :param as_dict: If as_dict is set to True, the returned cursor objects returns
151199 query results as dictionary.
152200 :param suppress_warnings: If True, suppress all warnings arising from underlying query library
201+ :param reconnect: when None, get from config, when True, attempt to reconnect if disconnected
153202 """
154203 if reconnect is None :
155204 reconnect = config ['database.reconnect' ]
156-
157- cursor = client .cursors .DictCursor if as_dict else client .cursors .Cursor
158- cur = self ._conn .cursor (cursor = cursor )
159-
160205 logger .debug ("Executing SQL:" + query [0 :300 ])
206+ cursor_class = client .cursors .DictCursor if as_dict else client .cursors .Cursor
207+ cursor = self ._conn .cursor (cursor = cursor_class )
161208 try :
162- with warnings .catch_warnings ():
163- if suppress_warnings :
164- # suppress all warnings arising from underlying SQL library
165- warnings .simplefilter ("ignore" )
166- cur .execute (query , args )
167- except (err .InterfaceError , err .OperationalError ) as e :
168- if is_connection_error (e ) and reconnect :
169- warnings .warn ("Mysql server has gone away. Reconnecting to the server." )
170- self .connect ()
171- if self ._in_transaction :
172- self .cancel_transaction ()
173- raise DataJointError ("Connection was lost during a transaction." ) from None
174- else :
175- logger .debug ("Re-executing SQL" )
176- cur = self .query (query , args = args , as_dict = as_dict , suppress_warnings = suppress_warnings , reconnect = False )
177- else :
178- logger .debug ("Caught InterfaceError/OperationalError." )
209+ self .__execute_query (cursor , query , args , cursor_class , suppress_warnings )
210+ except errors .LostConnectionError :
211+ if not reconnect :
179212 raise
180- except err .ProgrammingError as e :
181- if e .args [0 ] == server_error_codes ['parse error' ]:
182- raise DataJointError ("\n " .join ((
183- "Error in query:" , query ,
184- "Please check spelling, syntax, and existence of tables and attributes." ,
185- "When restricting a relation by a condition in a string, enclose attributes in backquotes."
186- ))) from None
187- return cur
213+ warnings .warn ("MySQL server has gone away. Reconnecting to the server." )
214+ self .connect ()
215+ if self ._in_transaction :
216+ self .cancel_transaction ()
217+ raise errors .LostConnectionError ("Connection was lost during a transaction." ) from None
218+ logger .debug ("Re-executing" )
219+ cursor = self ._conn .cursor (cursor = cursor_class )
220+ self .__execute_query (cursor , query , args , cursor_class , suppress_warnings )
221+ return cursor
188222
189223 def get_user (self ):
190224 """
@@ -204,11 +238,9 @@ def in_transaction(self):
204238 def start_transaction (self ):
205239 """
206240 Starts a transaction error.
207-
208- :raise DataJointError: if there is an ongoing transaction.
209241 """
210242 if self .in_transaction :
211- raise DataJointError ("Nested connections are not supported." )
243+ raise errors . DataJointError ("Nested connections are not supported." )
212244 self .query ('START TRANSACTION WITH CONSISTENT SNAPSHOT' )
213245 self ._in_transaction = True
214246 logger .info ("Transaction started" )
0 commit comments