Skip to content

Commit 272531b

Browse files
Merge pull request #651 from dimitri-yatsenko/refactor-errors
Refactor error handling and fix schema diagrams
2 parents 74dfe41 + dcac226 commit 272531b

File tree

14 files changed

+190
-125
lines changed

14 files changed

+190
-125
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
### 0.11.3 -- Jul 26, 2019
2525
* Fix incompatibility with pyparsing 2.4.1 (#629) PR #631
2626

27+
### 0.11.2 -- July 25, 2019
28+
* Fix #628 - incompatibility with pyparsing 2.4.1
29+
2730
### 0.11.1 -- Nov 15, 2018
2831
* Fix ordering of attributes in proj (#483 and #516)
2932
* Prohibit direct insert into auto-populated tables (#511)

datajoint/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@
2424
'Not', 'AndList', 'U', 'Diagram', 'Di', 'ERD',
2525
'set_password', 'kill',
2626
'MatCell', 'MatStruct',
27-
'DataJointError', 'DuplicateError', 'key']
28-
27+
'errors', 'DataJointError', 'key']
2928

3029
from .version import __version__
3130
from .settings import config
@@ -38,7 +37,8 @@
3837
from .diagram import Diagram
3938
from .admin import set_password, kill
4039
from .blob import MatCell, MatStruct
41-
from .errors import DataJointError, DuplicateError
4240
from .fetch import key
41+
from . import errors
42+
from .errors import DataJointError
4343

4444
ERD = Di = Diagram # Aliases for Diagram

datajoint/connection.py

Lines changed: 73 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,48 @@
77
import pymysql as client
88
import logging
99
from getpass import getpass
10-
from pymysql import err
1110

1211
from .settings import config
13-
from .errors import DataJointError, server_error_codes, is_connection_error
12+
from . import errors
1413
from .dependencies import Dependencies
1514

15+
# client errors to catch
16+
client_errors = (client.err.InterfaceError, client.err.DatabaseError)
17+
18+
19+
def translate_query_error(client_error, query):
20+
"""
21+
Take client error and original query and return the corresponding DataJoint exception.
22+
:param client_error: the exception raised by the client interface
23+
:param query: sql query with placeholders
24+
:return: an instance of the corresponding subclass of datajoint.errors.DataJointError
25+
"""
26+
# Loss of connection errors
27+
if isinstance(client_error, client.err.InterfaceError) and client_error.args[0] == "(0, '')":
28+
return errors.LostConnectionError('Server connection lost due to an interface error.', *client_error.args[1:])
29+
disconnect_codes = {
30+
2006: "Connection timed out",
31+
2013: "Server connection lost"}
32+
if isinstance(client_error, client.err.OperationalError) and client_error.args[0] in disconnect_codes:
33+
return errors.LostConnectionError(disconnect_codes[client_error.args[0]], *client_error.args[1:])
34+
# Access errors
35+
if isinstance(client_error, client.err.OperationalError) and client_error.args[0] in (1044, 1142):
36+
return errors.AccessError('Insufficient privileges.', client_error.args[1], query)
37+
# Integrity errors
38+
if isinstance(client_error, client.err.IntegrityError) and client_error.args[0] == 1062:
39+
return errors.DuplicateError(*client_error.args[1:])
40+
if isinstance(client_error, client.err.IntegrityError) and client_error.args[0] == 1452:
41+
return errors.IntegrityError(*client_error.args[1:])
42+
# Syntax Errors
43+
if isinstance(client_error, client.err.ProgrammingError) and client_error.args[0] == 1064:
44+
return errors.QuerySyntaxError(client_error.args[1], query)
45+
# Existence Errors
46+
if isinstance(client_error, client.err.ProgrammingError) and client_error.args[0] == 1146:
47+
return errors.MissingTableError(client_error.args[1], query)
48+
if isinstance(client_error, client.err.InternalError) and client_error.args[0] == 1364:
49+
return errors.MissingAttributeError(*client_error.args[1:])
50+
raise client_error
51+
1652

1753
logger = logging.getLogger(__name__)
1854

@@ -60,6 +96,7 @@ class Connection:
6096
:param init_fun: connection initialization function (SQL)
6197
:param use_tls: TLS encryption option
6298
"""
99+
63100
def __init__(self, host, user, password, port=None, init_fun=None, use_tls=None):
64101
if ':' in host:
65102
# the port in the hostname overrides the port argument
@@ -79,7 +116,7 @@ def __init__(self, host, user, password, port=None, init_fun=None, use_tls=None)
79116
logger.info("Connected {user}@{host}:{port}".format(**self.conn_info))
80117
self.connection_id = self.query('SELECT connection_id()').fetchone()[0]
81118
else:
82-
raise DataJointError('Connection failed.')
119+
raise errors.ConnectionError('Connection failed.')
83120
self._in_transaction = False
84121
self.schemas = dict()
85122
self.dependencies = Dependencies(self)
@@ -103,16 +140,16 @@ def connect(self):
103140
self._conn = client.connect(
104141
init_command=self.init_fun,
105142
sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO,"
106-
"STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION",
143+
"STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION",
107144
charset=config['connection.charset'],
108145
**self.conn_info)
109-
except err.InternalError:
146+
except client.err.InternalError:
110147
if ssl_input is None:
111148
self.conn_info.pop('ssl')
112149
self._conn = client.connect(
113150
init_command=self.init_fun,
114151
sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO,"
115-
"STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION",
152+
"STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION",
116153
charset=config['connection.charset'],
117154
**self.conn_info)
118155
self.conn_info['ssl_input'] = ssl_input
@@ -141,50 +178,46 @@ def is_connected(self):
141178
return False
142179
return True
143180

144-
def query(self, query, args=(), as_dict=False, suppress_warnings=True, reconnect=None):
181+
@staticmethod
182+
def __execute_query(cursor, query, args, cursor_class, suppress_warnings):
183+
try:
184+
with warnings.catch_warnings():
185+
if suppress_warnings:
186+
# suppress all warnings arising from underlying SQL library
187+
warnings.simplefilter("ignore")
188+
cursor.execute(query, args)
189+
except client_errors as err:
190+
raise translate_query_error(err, query)
191+
192+
def query(self, query, args=(), *, as_dict=False, suppress_warnings=True, reconnect=None):
145193
"""
146194
Execute the specified query and return the tuple generator (cursor).
147-
148-
:param query: mysql query
195+
:param query: SQL query
149196
:param args: additional arguments for the client.cursor
150197
:param as_dict: If as_dict is set to True, the returned cursor objects returns
151198
query results as dictionary.
152199
:param suppress_warnings: If True, suppress all warnings arising from underlying query library
200+
:param reconnect: when None, get from config, when True, attempt to reconnect if disconnected
153201
"""
154202
if reconnect is None:
155203
reconnect = config['database.reconnect']
156-
157-
cursor = client.cursors.DictCursor if as_dict else client.cursors.Cursor
158-
cur = self._conn.cursor(cursor=cursor)
159-
160204
logger.debug("Executing SQL:" + query[0:300])
205+
cursor_class = client.cursors.DictCursor if as_dict else client.cursors.Cursor
206+
cursor = self._conn.cursor(cursor=cursor_class)
161207
try:
162-
with warnings.catch_warnings():
163-
if suppress_warnings:
164-
# suppress all warnings arising from underlying SQL library
165-
warnings.simplefilter("ignore")
166-
cur.execute(query, args)
167-
except (err.InterfaceError, err.OperationalError) as e:
168-
if is_connection_error(e) and reconnect:
169-
warnings.warn("Mysql server has gone away. Reconnecting to the server.")
170-
self.connect()
171-
if self._in_transaction:
172-
self.cancel_transaction()
173-
raise DataJointError("Connection was lost during a transaction.") from None
174-
else:
175-
logger.debug("Re-executing SQL")
176-
cur = self.query(query, args=args, as_dict=as_dict, suppress_warnings=suppress_warnings, reconnect=False)
177-
else:
178-
logger.debug("Caught InterfaceError/OperationalError.")
208+
self.__execute_query(cursor, query, args, cursor_class, suppress_warnings)
209+
except errors.LostConnectionError:
210+
if not reconnect:
179211
raise
180-
except err.ProgrammingError as e:
181-
if e.args[0] == server_error_codes['parse error']:
182-
raise DataJointError("\n".join((
183-
"Error in query:", query,
184-
"Please check spelling, syntax, and existence of tables and attributes.",
185-
"When restricting a relation by a condition in a string, enclose attributes in backquotes."
186-
))) from None
187-
return cur
212+
warnings.warn("MySQL server has gone away. Reconnecting to the server.")
213+
self.connect()
214+
if self._in_transaction:
215+
self.cancel_transaction()
216+
raise errors.LostConnectionError("Connection was lost during a transaction.") from None
217+
logger.debug("Re-executing")
218+
cursor = self._conn.cursor(cursor=cursor_class)
219+
self.__execute_query(cursor, query, args, cursor_class, suppress_warnings)
220+
return cursor
188221

189222
def get_user(self):
190223
"""
@@ -204,11 +237,9 @@ def in_transaction(self):
204237
def start_transaction(self):
205238
"""
206239
Starts a transaction error.
207-
208-
:raise DataJointError: if there is an ongoing transaction.
209240
"""
210241
if self.in_transaction:
211-
raise DataJointError("Nested connections are not supported.")
242+
raise errors.DataJointError("Nested connections are not supported.")
212243
self.query('START TRANSACTION WITH CONSISTENT SNAPSHOT')
213244
self._in_transaction = True
214245
logger.info("Transaction started")
@@ -252,3 +283,4 @@ def transaction(self):
252283
raise
253284
else:
254285
self.commit_transaction()
286+

datajoint/declare.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def prepare_declare(definition, context):
235235
external_stores = []
236236

237237
for line in definition:
238-
if line.startswith('#'): # ignore additional comments
238+
if not line or line.startswith('#'): # ignore additional comments
239239
pass
240240
elif line.startswith('---') or line.startswith('___'):
241241
in_key = False # start parsing dependent attributes

datajoint/dependencies.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,10 @@ def load(self):
6262
# add edges to the graph
6363
for fk in fks.values():
6464
props = dict(
65-
primary=all(attr in pks[fk['referencing_table']] for attr in fk['attr_map']),
65+
primary=set(fk['attr_map']) <= set(pks[fk['referencing_table']]),
6666
attr_map=fk['attr_map'],
6767
aliased=any(k != v for k, v in fk['attr_map'].items()),
68-
multi=not all(a in fk['attr_map'] for a in pks[fk['referencing_table']]))
68+
multi=set(fk['attr_map']) != set(pks[fk['referencing_table']]))
6969
if not props['aliased']:
7070
self.add_edge(fk['referenced_table'], fk['referencing_table'], **props)
7171
else:

datajoint/diagram.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,8 @@ def __add__(self, arg):
191191
new = nx.algorithms.boundary.node_boundary(self, self.nodes_to_show)
192192
if not new:
193193
break
194+
# add nodes referenced by aliased nodes
195+
new.update(nx.algorithms.boundary.node_boundary(self, (a for a in new if a.isdigit())))
194196
self.nodes_to_show.update(new)
195197
return self
196198

@@ -207,9 +209,12 @@ def __sub__(self, arg):
207209
self.nodes_to_show.remove(arg.full_table_name)
208210
except AttributeError:
209211
for i in range(arg):
210-
new = nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), self.nodes_to_show)
212+
graph = nx.DiGraph(self).reverse()
213+
new = nx.algorithms.boundary.node_boundary(graph, self.nodes_to_show)
211214
if not new:
212215
break
216+
# add nodes referenced by aliased nodes
217+
new.update(nx.algorithms.boundary.node_boundary(graph, (a for a in new if a.isdigit())))
213218
self.nodes_to_show.update(new)
214219
return self
215220

@@ -229,9 +234,10 @@ def _make_graph(self):
229234
"""
230235
# mark "distinguished" tables, i.e. those that introduce new primary key attributes
231236
for name in self.nodes_to_show:
232-
foreign_attributes = set(attr for p in self.in_edges(name, data=True) for attr in p[2]['attr_map'])
233-
self.node[name]['distinguished'] = ('primary_key' in self.node[name] and
234-
foreign_attributes < self.node[name]['primary_key'])
237+
foreign_attributes = set(
238+
attr for p in self.in_edges(name, data=True) for attr in p[2]['attr_map'] if p[2]['primary'])
239+
self.node[name]['distinguished'] = (
240+
'primary_key' in self.node[name] and foreign_attributes < self.node[name]['primary_key'])
235241
# include aliased nodes that are sandwiched between two displayed nodes
236242
gaps = set(nx.algorithms.boundary.node_boundary(self, self.nodes_to_show)).intersection(
237243
nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), self.nodes_to_show))

datajoint/errors.py

Lines changed: 59 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,76 @@
1-
from pymysql import err
1+
"""
2+
Exception classes for the DataJoint library
3+
"""
24

3-
server_error_codes = {
4-
'unknown column': 1054,
5-
'duplicate entry': 1062,
6-
'parse error': 1064,
7-
'command denied': 1142,
8-
'table does not exist': 1146,
9-
'syntax error': 1149,
10-
}
115

12-
operation_error_codes = {
13-
'connection timedout': 2006,
14-
'lost connection': 2013,
15-
}
6+
# --- Top Level ---
7+
class DataJointError(Exception):
8+
"""
9+
Base class for errors specific to DataJoint internal operation.
10+
"""
11+
def suggest(self, *args):
12+
"""
13+
regenerate the exception with additional arguments
14+
:param args: addition arguments
15+
:return: a new exception of the same type with the additional arguments
16+
"""
17+
return self.__class__(*(self.args + args))
1618

1719

18-
def is_connection_error(e):
20+
# --- Second Level ---
21+
class LostConnectionError(DataJointError):
1922
"""
20-
Checks if error e pertains to a connection issue
23+
Loss of server connection
2124
"""
22-
return (isinstance(e, err.InterfaceError) and e.args[0] == "(0, '')") or\
23-
(isinstance(e, err.OperationalError) and e.args[0] in operation_error_codes.values())
2425

2526

26-
class DataJointError(Exception):
27+
class QueryError(DataJointError):
2728
"""
28-
Base class for errors specific to DataJoint internal operation.
29+
Errors arising from queries to the database
30+
"""
31+
32+
33+
# --- Third Level: QueryErrors ---
34+
class QuerySyntaxError(QueryError):
35+
"""
36+
Errors arising from incorrect query syntax
37+
"""
38+
39+
40+
class AccessError(QueryError):
41+
"""
42+
User access error: insufficient privileges.
43+
"""
44+
45+
46+
class UnknownAttributeError(DataJointError):
47+
"""
48+
Error caused by referencing to a non-existing attributes
49+
"""
50+
51+
52+
class MissingTableError(DataJointError):
53+
"""
54+
Query on a table that has not been declared
55+
"""
56+
57+
58+
class DuplicateError(QueryError):
59+
"""
60+
An integrity error caused by a duplicate entry into a unique key
61+
"""
62+
63+
64+
class IntegrityError(QueryError):
65+
"""
66+
An integrity error triggered by foreign key constraints
2967
"""
30-
pass
3168

3269

33-
class DuplicateError(DataJointError):
70+
class MissingAttributeError(QueryError):
3471
"""
35-
Error caused by a violation of a unique constraint when inserting data
72+
An error arising when a required attribute value is not provided in INSERT
3673
"""
37-
pass
3874

3975

4076
class MissingExternalFile(DataJointError):

datajoint/fetch.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,6 @@ def __call__(self, *attrs, offset=None, limit=None, order_by=None, format=None,
121121
:param download_path: for fetches that download data, e.g. attachments
122122
:return: the contents of the relation in the form of a structured numpy.array or a dict list
123123
"""
124-
125124
if order_by is not None:
126125
# if 'order_by' passed in a string, make into list
127126
if isinstance(order_by, str):

datajoint/jobs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import platform
44
from .table import Table
5-
from .errors import DuplicateError
5+
from .errors import DuplicateError, IntegrityError
66

77
ERROR_MESSAGE_LENGTH = 2047
88
TRUNCATION_APPENDIX = '...truncated'

0 commit comments

Comments
 (0)