Skip to content

Commit 7060d9e

Browse files
Merge pull request #573 from dimitri-yatsenko/alter2
Altering table definition (Fix #110)
2 parents e8e900c + 07b02f2 commit 7060d9e

File tree

6 files changed

+199
-24
lines changed

6 files changed

+199
-24
lines changed

datajoint/connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,9 @@ def is_connected(self):
137137
"""
138138
try:
139139
self.ping()
140-
return True
141140
except:
142141
return False
142+
return True
143143

144144
def query(self, query, args=(), as_dict=False, suppress_warnings=True, reconnect=None):
145145
"""

datajoint/declare.py

Lines changed: 113 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
import re
66
import pyparsing as pp
77
import logging
8-
98
from .errors import DataJointError
109

10+
from .utils import OrderedDict
11+
1112
UUID_DATA_TYPE = 'binary(16)'
1213
MAX_TABLE_NAME_LENGTH = 64
1314
CONSTANT_LITERALS = {'CURRENT_TIMESTAMP'} # SQL literals to be used without quotes (case insensitive)
@@ -218,20 +219,7 @@ def compile_foreign_key(line, context, attributes, primary_key, attr_sql, foreig
218219
index_sql.append('UNIQUE INDEX ({attrs})'.format(attrs='`,`'.join(ref.primary_key)))
219220

220221

221-
def declare(full_table_name, definition, context):
222-
"""
223-
Parse declaration and create new SQL table accordingly.
224-
225-
:param full_table_name: full name of the table
226-
:param definition: DataJoint table definition
227-
:param context: dictionary of objects that might be referred to in the table.
228-
"""
229-
table_name = full_table_name.strip('`').split('.')[1]
230-
if len(table_name) > MAX_TABLE_NAME_LENGTH:
231-
raise DataJointError(
232-
'Table name `{name}` exceeds the max length of {max_length}'.format(
233-
name=table_name,
234-
max_length=MAX_TABLE_NAME_LENGTH))
222+
def prepare_declare(definition, context):
235223
# split definition into lines
236224
definition = re.split(r'\s*\n\s*', definition.strip())
237225
# check for optional table comment
@@ -266,7 +254,28 @@ def declare(full_table_name, definition, context):
266254
if name not in attributes:
267255
attributes.append(name)
268256
attribute_sql.append(sql)
269-
# compile SQL
257+
258+
return table_comment, primary_key, attribute_sql, foreign_key_sql, index_sql, external_stores
259+
260+
261+
def declare(full_table_name, definition, context):
262+
"""
263+
Parse declaration and generate the SQL CREATE TABLE code
264+
:param full_table_name: full name of the table
265+
:param definition: DataJoint table definition
266+
:param context: dictionary of objects that might be referred to in the table
267+
:return: SQL CREATE TABLE statement, list of external stores used
268+
"""
269+
table_name = full_table_name.strip('`').split('.')[1]
270+
if len(table_name) > MAX_TABLE_NAME_LENGTH:
271+
raise DataJointError(
272+
'Table name `{name}` exceeds the max length of {max_length}'.format(
273+
name=table_name,
274+
max_length=MAX_TABLE_NAME_LENGTH))
275+
276+
table_comment, primary_key, attribute_sql, foreign_key_sql, index_sql, external_stores = prepare_declare(
277+
definition, context)
278+
270279
if not primary_key:
271280
raise DataJointError('Table must have a primary key')
272281

@@ -276,6 +285,94 @@ def declare(full_table_name, definition, context):
276285
'\n) ENGINE=InnoDB, COMMENT "%s"' % table_comment), external_stores
277286

278287

288+
def _make_attribute_alter(new, old, primary_key):
289+
"""
290+
:param new: new attribute declarations
291+
:param old: old attribute declarations
292+
:param primary_key: primary key attributes
293+
:return: list of SQL ALTER commands
294+
"""
295+
296+
# parse attribute names
297+
name_regexp = re.compile(r"^`(?P<name>\w+)`")
298+
original_regexp = re.compile(r'COMMENT "\{\s*(?P<name>\w+)\s*\}')
299+
matched = ((name_regexp.match(d), original_regexp.search(d)) for d in new)
300+
new_names = OrderedDict((d.group('name'), n and n.group('name')) for d, n in matched)
301+
old_names = [name_regexp.search(d).group('name') for d in old]
302+
303+
# verify that original names are only used once
304+
renamed = set()
305+
for v in new_names.values():
306+
if v:
307+
if v in renamed:
308+
raise DataJointError('Alter attempted to rename attribute {%s} twice.' % v)
309+
renamed.add(v)
310+
311+
# verify that all renamed attributes existed in the old definition
312+
try:
313+
raise DataJointError(
314+
"Attribute {} does not exist in the original definition".format(
315+
next(attr for attr in renamed if attr not in old_names)))
316+
except StopIteration:
317+
pass
318+
319+
# dropping attributes
320+
to_drop = [n for n in old_names if n not in renamed and n not in new_names]
321+
sql = ['DROP `%s`' % n for n in to_drop]
322+
old_names = [name for name in old_names if name not in to_drop]
323+
324+
# add or change attributes in order
325+
prev = None
326+
for new_def, (new_name, old_name) in zip(new, new_names.items()):
327+
if new_name not in primary_key:
328+
after = None # if None, then must include the AFTER clause
329+
if prev:
330+
try:
331+
idx = old_names.index(old_name or new_name)
332+
except ValueError:
333+
after = prev[0]
334+
else:
335+
if idx >= 1 and old_names[idx - 1] != (prev[1] or prev[0]):
336+
after = prev[0]
337+
if new_def not in old or after:
338+
sql.append('{command} {new_def} {after}'.format(
339+
command=("ADD" if (old_name or new_name) not in old_names else
340+
"MODIFY" if not old_name else
341+
"CHANGE `%s`" % old_name),
342+
new_def=new_def,
343+
after="" if after is None else "AFTER `%s`" % after))
344+
prev = new_name, old_name
345+
346+
return sql
347+
348+
349+
def alter(definition, old_definition, context):
    """
    Compare two table definitions and compile the clauses needed to alter the table.

    :param definition: new table definition
    :param old_definition: current table definition
    :param context: the context in which to evaluate foreign key definitions
    :return: list of SQL ALTER clauses (empty when nothing differs),
        list of external stores used by the new definition but not by the old one
    :raises NotImplementedError: if the primary key, foreign keys, or indexes differ
    """
    (new_comment, new_key, new_attributes,
     new_foreign_keys, new_indexes, new_stores) = prepare_declare(definition, context)
    (old_comment, old_key, old_attributes,
     old_foreign_keys, old_indexes, old_stores) = prepare_declare(old_definition, context)

    # changes that are not supported are rejected before any clause is compiled
    if new_key != old_key:
        raise NotImplementedError('table.alter cannot alter the primary key (yet).')
    if new_foreign_keys != old_foreign_keys:
        raise NotImplementedError('table.alter cannot alter foreign keys (yet).')
    if new_indexes != old_indexes:
        raise NotImplementedError('table.alter cannot alter indexes (yet)')

    clauses = []
    if new_attributes != old_attributes:
        clauses.extend(_make_attribute_alter(new_attributes, old_attributes, new_key))
    if new_comment != old_comment:
        clauses.append('COMMENT="%s"' % new_comment)

    added_stores = [store for store in new_stores if store not in old_stores]
    return clauses, added_stores
374+
375+
279376
def compile_index(line, index_sql):
280377
match = index_parser.parseString(line)
281378
index_sql.append('{unique} index ({attrs})'.format(

datajoint/table.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import uuid
1010
from pymysql import OperationalError, InternalError, IntegrityError
1111
from .settings import config
12-
from .declare import declare
12+
from .declare import declare, alter
1313
from .expression import QueryExpression
1414
from . import attach, blob
1515
from .utils import user_choice
@@ -56,14 +56,15 @@ def heading(self):
5656

5757
def declare(self, context=None):
5858
"""
59-
Use self.definition to declare the table in the schema.
59+
Declare the table in the schema based on self.definition.
60+
:param context: the context for foreign key resolution. If None, foreign keys are not allowed.
6061
"""
6162
if self.connection.in_transaction:
6263
raise DataJointError('Cannot declare new tables inside a transaction, '
6364
'e.g. from inside a populate/make call')
65+
sql, external_stores = declare(self.full_table_name, self.definition, context)
66+
sql = sql.format(database=self.database)
6467
try:
65-
sql, external_stores = declare(self.full_table_name, self.definition, context)
66-
sql = sql.format(database=self.database)
6768
# declare all external tables before declaring main table
6869
for store in external_stores:
6970
self.connection.schemas[self.database].external[store]
@@ -77,6 +78,41 @@ def declare(self, context=None):
7778
else:
7879
self._log('Declared ' + self.full_table_name)
7980

81+
def alter(self, prompt=True, context=None):
    """
    Alter the table in the database so that it matches self.definition.

    :param prompt: if True, print status messages and ask for confirmation before
        executing the ALTER statement
    :param context: the context for foreign key resolution. If None, the globals and
        locals of the calling frame are used.
    :raises DataJointError: if called inside a transaction
    """
    if self.connection.in_transaction:
        raise DataJointError('Cannot update table declaration inside a transaction, '
                             'e.g. from inside a populate/make call')
    if context is None:
        # default to the caller's namespace; locals take precedence over globals
        frame = inspect.currentframe().f_back
        context = dict(frame.f_globals, **frame.f_locals)
        del frame  # do not keep a reference to the caller's frame
    old_definition = self.describe(context=context, printout=False)
    sql, external_stores = alter(self.definition, old_definition, context)
    if not sql:
        if prompt:
            print('Nothing to alter.')
        return
    sql = "ALTER TABLE {tab}\n\t".format(tab=self.full_table_name) + ",\n\t".join(sql)
    if prompt and user_choice(sql + '\n\nExecute?') != 'yes':
        return
    try:
        # declare all external tables before altering the main table
        for store in external_stores:
            self.connection.schemas[self.database].external[store]
        self.connection.query(sql)
    except OperationalError as error:
        # skip with a warning if the user lacks the privilege to alter the table
        if error.args[0] == server_error_codes['command denied']:
            logger.warning(error.args[1])
        else:
            raise
    else:
        if prompt:
            print('Table altered')
        self._log('Altered ' + self.full_table_name)
115+
80116
@property
81117
def from_clause(self):
82118
"""
@@ -544,7 +580,7 @@ def _update(self, attrname, value=None):
544580
545581
Example
546582
547-
>>> (v2p.Mice() & key).update('mouse_dob', '2011-01-01')
583+
>>> (v2p.Mice() & key).update('mouse_dob', '2011-01-01')
548584
>>> (v2p.Mice() & key).update( 'lens') # set the value to NULL
549585
550586
"""

datajoint/user_tables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
# attributes that trigger instantiation of user classes
1414
supported_class_attrs = {
15-
'key_source', 'describe', 'heading', 'populate', 'progress', 'primary_key', 'proj', 'aggr',
15+
'key_source', 'describe', 'alter', 'heading', 'populate', 'progress', 'primary_key', 'proj', 'aggr',
1616
'fetch', 'fetch1','head', 'tail',
1717
'insert', 'insert1', 'drop', 'drop_quick', 'delete', 'delete_quick'}
1818

tests/schema_advanced.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def fill(self):
5050
def make_parent(pid, parent):
5151
return dict(person_id=pid,
5252
parent=parent,
53-
parent_sex=(Person() & dict(person_id=parent)).fetch('sex')[0])
53+
parent_sex=(Person & {'person_id': parent}).fetch1('sex'))
5454

5555
self.insert(make_parent(*r) for r in (
5656
(0, 2), (0, 3), (1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 7), (4, 7),

tests/test_alter.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from nose.tools import assert_equal, assert_not_equal
2+
3+
from .schema import *
4+
5+
6+
@schema
class Experiment(dj.Imported):

    # Definition that test_alter assigns last, so that the table is altered back
    # and its final declaration can be compared with the initial one.
    original_definition = """ # information about experiments
    -> Subject
    experiment_id :smallint # experiment number for this subject
    ---
    experiment_date :date # date when experiment was started
    -> [nullable] User
    data_path="" :varchar(255) # file path to recorded data
    notes="" :varchar(2048) # e.g. purpose of experiment
    entry_time=CURRENT_TIMESTAMP :timestamp # automatic timestamp
    """

    # Modified definition that test_alter applies first. Relative to original_definition it
    # changes the table comment, omits experiment_date, redeclares data_path as int,
    # adds extra, and declares subject_notes with the {notes} marker in its comment
    # (the marker that names the attribute being renamed).
    definition1 = """ # Experiment
    -> Subject
    experiment_id :smallint # experiment number for this subject
    ---
    data_path : int # some number
    extra=null : longblob # just testing
    -> [nullable] User
    subject_notes=null :varchar(2048) # {notes} e.g. purpose of experiment
    entry_time=CURRENT_TIMESTAMP :timestamp # automatic timestamp
    """
30+
31+
32+
def test_alter():
    """Alter Experiment to definition1 and back; the table must change and then be restored."""

    def declaration():
        # the CREATE TABLE statement currently stored in the database
        query = "SHOW CREATE TABLE " + Experiment.full_table_name
        return schema.connection.query(query).fetchone()[1]

    before = declaration()

    # apply the modified definition, invoking alter on the class
    Experiment.definition = Experiment.definition1
    Experiment.alter(prompt=False)
    after_alter = declaration()

    # revert to the original definition, invoking alter on an instance
    Experiment.definition = Experiment.original_definition
    Experiment().alter(prompt=False)
    after_restore = declaration()

    assert_equal(before, after_restore)
    assert_not_equal(before, after_alter)

0 commit comments

Comments
 (0)