Skip to content

Commit 9a6db29

Browse files
refactor error handling
1 parent edd4779 commit 9a6db29

File tree

10 files changed

+187
-102
lines changed

10 files changed

+187
-102
lines changed

datajoint/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@
2424
'Not', 'AndList', 'U', 'Diagram', 'Di', 'ERD',
2525
'set_password', 'kill',
2626
'MatCell', 'MatStruct',
27-
'DataJointError', 'DuplicateError', 'key']
28-
27+
'errors', 'DataJointError', 'key']
2928

3029
from .version import __version__
3130
from .settings import config
@@ -38,7 +37,8 @@
3837
from .diagram import Diagram
3938
from .admin import set_password, kill
4039
from .blob import MatCell, MatStruct
41-
from .errors import DataJointError, DuplicateError
4240
from .fetch import key
41+
from . import errors
42+
from .errors import DataJointError
4343

4444
ERD = Di = Diagram # Aliases for Diagram

datajoint/connection.py

Lines changed: 73 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,49 @@
77
import pymysql as client
88
import logging
99
from getpass import getpass
10-
from pymysql import err
1110

1211
from .settings import config
13-
from .errors import DataJointError, server_error_codes, is_connection_error
12+
from . import errors
1413
from .dependencies import Dependencies
1514

15+
# client errors to catch
16+
client_errors = (client.err.InterfaceError, client.err.DatabaseError)
17+
18+
19+
def translate_query_error(client_error, query, args):
20+
"""
21+
Take client error and original query and return the corresponding DataJoint exception.
22+
:param client_error: the exception raised by the client interface
23+
:param query: sql query with placeholders
24+
:param args: values for query placeholders
25+
:return: an instance of the corresponding subclass of datajoint.errors.DataJointError
26+
"""
27+
# Loss of connection errors
28+
if isinstance(client_error, client.err.InterfaceError) and client_error.args[0] == "(0, '')":
29+
return errors.LostConnectionError('Server connection lost due to an interface error.', *client_error.args[1:])
30+
disconnect_codes = {
31+
2006: "Connection timed out",
32+
2013: "Server connection lost"}
33+
if isinstance(client_error, client.err.OperationalError) and client_error.args[0] in disconnect_codes:
34+
return errors.LostConnectionError(disconnect_codes[client_error.args[0]], *client_error.args[1:])
35+
# Access errors
36+
if isinstance(client_error, client.err.OperationalError) and client_error.args[0] in (1044, 1142):
37+
return errors.AccessError('Insufficient privileges.', *client_error.args[1:], query)
38+
# Integrity errors
39+
if isinstance(client_error, client.err.IntegrityError) and client_error.args[0] == 1062:
40+
return errors.DuplicateError(*client_error.args[1:])
41+
if isinstance(client_error, client.err.IntegrityError) and client_error.args[0] == 1452:
42+
return errors.IntegrityError(*client_error.args[1:])
43+
# Syntax Errors
44+
if isinstance(client_error, client.err.ProgrammingError) and client_error.args[0] == 1064:
45+
return errors.QuerySyntaxError(*client_error.args[1:], query)
46+
# Existence Errors
47+
if isinstance(client_error, client.err.ProgrammingError) and client_error.args[0] == 1146:
48+
return errors.MissingTableError(*args[1:], query)
49+
if isinstance(client_error, client.err.InternalError) and client_error.args[0] == 1364:
50+
return errors.MissingAttributeValueError(*args[1:])
51+
raise client_error
52+
1653

1754
logger = logging.getLogger(__name__)
1855

@@ -60,6 +97,7 @@ class Connection:
6097
:param init_fun: connection initialization function (SQL)
6198
:param use_tls: TLS encryption option
6299
"""
100+
63101
def __init__(self, host, user, password, port=None, init_fun=None, use_tls=None):
64102
if ':' in host:
65103
# the port in the hostname overrides the port argument
@@ -79,7 +117,7 @@ def __init__(self, host, user, password, port=None, init_fun=None, use_tls=None)
79117
logger.info("Connected {user}@{host}:{port}".format(**self.conn_info))
80118
self.connection_id = self.query('SELECT connection_id()').fetchone()[0]
81119
else:
82-
raise DataJointError('Connection failed.')
120+
raise errors.ConnectionError('Connection failed.')
83121
self._in_transaction = False
84122
self.schemas = dict()
85123
self.dependencies = Dependencies(self)
@@ -103,16 +141,16 @@ def connect(self):
103141
self._conn = client.connect(
104142
init_command=self.init_fun,
105143
sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO,"
106-
"STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION",
144+
"STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION",
107145
charset=config['connection.charset'],
108146
**self.conn_info)
109-
except err.InternalError:
147+
except client.err.InternalError:
110148
if ssl_input is None:
111149
self.conn_info.pop('ssl')
112150
self._conn = client.connect(
113151
init_command=self.init_fun,
114152
sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO,"
115-
"STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION",
153+
"STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION",
116154
charset=config['connection.charset'],
117155
**self.conn_info)
118156
self.conn_info['ssl_input'] = ssl_input
@@ -141,50 +179,46 @@ def is_connected(self):
141179
return False
142180
return True
143181

144-
def query(self, query, args=(), as_dict=False, suppress_warnings=True, reconnect=None):
182+
@staticmethod
183+
def __execute_query(cursor, query, args, cursor_class, suppress_warnings):
184+
try:
185+
with warnings.catch_warnings():
186+
if suppress_warnings:
187+
# suppress all warnings arising from underlying SQL library
188+
warnings.simplefilter("ignore")
189+
cursor.execute(query, args)
190+
except client_errors as err:
191+
raise translate_query_error(err, query, args)
192+
193+
def query(self, query, args=(), *, as_dict=False, suppress_warnings=True, reconnect=None):
145194
"""
146195
Execute the specified query and return the tuple generator (cursor).
147-
148-
:param query: mysql query
196+
:param query: SQL query
149197
:param args: additional arguments for the client.cursor
150198
:param as_dict: If as_dict is set to True, the returned cursor objects returns
151199
query results as dictionary.
152200
:param suppress_warnings: If True, suppress all warnings arising from underlying query library
201+
:param reconnect: when None, get from config, when True, attempt to reconnect if disconnected
153202
"""
154203
if reconnect is None:
155204
reconnect = config['database.reconnect']
156-
157-
cursor = client.cursors.DictCursor if as_dict else client.cursors.Cursor
158-
cur = self._conn.cursor(cursor=cursor)
159-
160205
logger.debug("Executing SQL:" + query[0:300])
206+
cursor_class = client.cursors.DictCursor if as_dict else client.cursors.Cursor
207+
cursor = self._conn.cursor(cursor=cursor_class)
161208
try:
162-
with warnings.catch_warnings():
163-
if suppress_warnings:
164-
# suppress all warnings arising from underlying SQL library
165-
warnings.simplefilter("ignore")
166-
cur.execute(query, args)
167-
except (err.InterfaceError, err.OperationalError) as e:
168-
if is_connection_error(e) and reconnect:
169-
warnings.warn("Mysql server has gone away. Reconnecting to the server.")
170-
self.connect()
171-
if self._in_transaction:
172-
self.cancel_transaction()
173-
raise DataJointError("Connection was lost during a transaction.") from None
174-
else:
175-
logger.debug("Re-executing SQL")
176-
cur = self.query(query, args=args, as_dict=as_dict, suppress_warnings=suppress_warnings, reconnect=False)
177-
else:
178-
logger.debug("Caught InterfaceError/OperationalError.")
209+
self.__execute_query(cursor, query, args, cursor_class, suppress_warnings)
210+
except errors.LostConnectionError:
211+
if not reconnect:
179212
raise
180-
except err.ProgrammingError as e:
181-
if e.args[0] == server_error_codes['parse error']:
182-
raise DataJointError("\n".join((
183-
"Error in query:", query,
184-
"Please check spelling, syntax, and existence of tables and attributes.",
185-
"When restricting a relation by a condition in a string, enclose attributes in backquotes."
186-
))) from None
187-
return cur
213+
warnings.warn("MySQL server has gone away. Reconnecting to the server.")
214+
self.connect()
215+
if self._in_transaction:
216+
self.cancel_transaction()
217+
raise errors.LostConnectionError("Connection was lost during a transaction.") from None
218+
logger.debug("Re-executing")
219+
cursor = self._conn.cursor(cursor=cursor_class)
220+
self.__execute_query(cursor, query, args, cursor_class, suppress_warnings)
221+
return cursor
188222

189223
def get_user(self):
190224
"""
@@ -204,11 +238,9 @@ def in_transaction(self):
204238
def start_transaction(self):
205239
"""
206240
Starts a transaction error.
207-
208-
:raise DataJointError: if there is an ongoing transaction.
209241
"""
210242
if self.in_transaction:
211-
raise DataJointError("Nested connections are not supported.")
243+
raise errors.DataJointError("Nested connections are not supported.")
212244
self.query('START TRANSACTION WITH CONSISTENT SNAPSHOT')
213245
self._in_transaction = True
214246
logger.info("Transaction started")

datajoint/declare.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,8 @@ def compile_foreign_key(line, context, attributes, primary_key, attr_sql, foreig
210210
# declare the foreign key
211211
foreign_key_sql.append(
212212
'FOREIGN KEY (`{fk}`) REFERENCES {ref} (`{pk}`) ON UPDATE CASCADE ON DELETE RESTRICT'.format(
213-
fk='`,`'.join(ref.primary_key),
214-
pk='`,`'.join(base.primary_key),
213+
fk='`,`'.join(ref.primary_key) or '_', # dimensionless tables use _ as primary key
214+
pk='`,`'.join(base.primary_key) or '_',
215215
ref=base.full_table_name))
216216

217217
# declare unique index
@@ -235,7 +235,7 @@ def prepare_declare(definition, context):
235235
external_stores = []
236236

237237
for line in definition:
238-
if line.startswith('#'): # ignore additional comments
238+
if not line or line.startswith('#'): # ignore additional comments
239239
pass
240240
elif line.startswith('---') or line.startswith('___'):
241241
in_key = False # start parsing dependent attributes
@@ -277,7 +277,9 @@ def declare(full_table_name, definition, context):
277277
definition, context)
278278

279279
if not primary_key:
280-
raise DataJointError('Table must have a primary key')
280+
# singular (dimensionless) table -- can contain only one element
281+
attribute_sql.insert(0, '`_` char(1) not null default "" COMMENT "dimensionless primary key"')
282+
primary_key = ['_']
281283

282284
return (
283285
'CREATE TABLE IF NOT EXISTS %s (\n' % full_table_name +

datajoint/errors.py

Lines changed: 58 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,72 @@
1-
from pymysql import err
1+
"""
2+
Exception classes for the DataJoint library
3+
"""
24

3-
server_error_codes = {
4-
'unknown column': 1054,
5-
'duplicate entry': 1062,
6-
'parse error': 1064,
7-
'command denied': 1142,
8-
'table does not exist': 1146,
9-
'syntax error': 1149,
10-
}
115

12-
operation_error_codes = {
13-
'connection timedout': 2006,
14-
'lost connection': 2013,
15-
}
6+
# --- Top Level ---
7+
class DataJointError(Exception):
8+
"""
9+
Base class for errors specific to DataJoint internal operation.
10+
"""
11+
def suggest(self, *args):
12+
"""
13+
regenerate the exception with additional arguments
14+
:param args: addition arguments
15+
:return: a new exception of the same type with the additional arguments
16+
"""
17+
return self.__class__(*self.args, *args)
1618

1719

18-
def is_connection_error(e):
20+
# --- Second Level ---
21+
class LostConnectionError(DataJointError):
1922
"""
20-
Checks if error e pertains to a connection issue
23+
Loss of server connection
2124
"""
22-
return (isinstance(e, err.InterfaceError) and e.args[0] == "(0, '')") or\
23-
(isinstance(e, err.OperationalError) and e.args[0] in operation_error_codes.values())
2425

2526

26-
class DataJointError(Exception):
27+
class QueryError(DataJointError):
2728
"""
28-
Base class for errors specific to DataJoint internal operation.
29+
Errors arising from queries to the database
2930
"""
30-
pass
3131

3232

33-
class DuplicateError(DataJointError):
33+
# --- Third Level: QueryErrors ---
34+
class QuerySyntaxError(QueryError):
35+
"""
36+
Errors arising from incorrect query syntax
37+
"""
38+
39+
40+
class AccessError(QueryError):
41+
"""
42+
User access error: insufficient privileges.
43+
"""
44+
45+
46+
class UnknownAttributeError(DataJointError):
47+
"""
48+
Error caused by referencing to a non-existing attributes
49+
"""
50+
51+
52+
class MissingTableError(DataJointError):
53+
"""
54+
Query on a table that has not been declared
55+
"""
56+
57+
58+
class DuplicateError(QueryError):
59+
"""
60+
An integrity error caused by a duplicate entry into a unique key
61+
"""
62+
63+
64+
class IntegrityError(QueryError):
65+
"""
66+
An integrity error triggered by foreign key constraints
67+
"""
68+
69+
class MissingAttributeValueError(QueryError):
3470
"""
35-
Error caused by a violation of a unique constraint when inserting data
71+
An error arising when a required attribute value is not provided in INSERT
3672
"""
37-
pass

datajoint/heading.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,9 @@ def init_from_database(self, conn, database, table_name):
182182
for k, v in x.items() if k not in fields_to_drop}
183183
for x in attributes]
184184

185+
# exclude attributes that begin with an underscore: these are reserved for DataJoint internals
186+
attributes = [a for a in attributes if not a['name'].startswith('_')]
187+
185188
numeric_types = {
186189
('float', False): np.float64,
187190
('float', True): np.float64,

datajoint/jobs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import platform
44
from .table import Table
5-
from .errors import DuplicateError
5+
from .errors import DuplicateError, IntegrityError
66

77
ERROR_MESSAGE_LENGTH = 2047
88
TRUNCATION_APPENDIX = '...truncated'

0 commit comments

Comments
 (0)