Skip to content

Commit 09fe35b

Browse files
committed
Fix mapping of multidimensional columns
1 parent e7c832a commit 09fe35b

2 files changed

Lines changed: 35 additions & 9 deletions

File tree

datamatrix/_datamatrix/_multidimensionalcolumn.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616
You should have received a copy of the GNU General Public License
1717
along with datamatrix. If not, see <http://www.gnu.org/licenses/>.
1818
"""
19+
from datamatrix.py3compat import *
1920
import logging
20-
import os
2121
import weakref
22-
from datamatrix.py3compat import *
2322
from datamatrix import cfg
24-
from datamatrix._datamatrix._numericcolumn import NumericColumn, FloatColumn
23+
from datamatrix._datamatrix._mixedcolumn import MixedColumn
24+
from datamatrix._datamatrix._numericcolumn import NumericColumn, FloatColumn, \
25+
IntColumn
2526
from datamatrix._datamatrix._datamatrix import DataMatrix
2627
from collections.abc import Sequence, Collection
2728
from collections import OrderedDict
@@ -348,11 +349,19 @@ def _map(self, fnc):
348349

349350
# For a MultiDimensionalColumn, we need to make a special case, because
350351
# the shape of the new MultiDimensionalColumn may be different from
351-
# the shape of the original column.
352+
# the shape of the original column. The new column may even be a
353+
# different kind of column altogether.
352354
for i, cell in enumerate(self):
353355
a = fnc(cell)
354356
if not i:
355-
newcol = self.__class__(self.dm, shape=len(a))
357+
if isinstance(a, float):
358+
newcol = FloatColumn(self.dm)
359+
elif isinstance(a, int):
360+
newcol = IntColumn(self.dm)
361+
elif isinstance(a, str):
362+
newcol = MixedColumn(self.dm)
363+
else:
364+
newcol = self.__class__(self.dm, shape=len(a))
356365
newcol[i] = a
357366
return newcol
358367

testcases/test_functional.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,42 @@
1818
"""
1919

2020
from datamatrix.py3compat import *
21-
from datamatrix import DataMatrix, MixedColumn, IntColumn, FloatColumn
21+
from datamatrix import DataMatrix, MixedColumn, IntColumn, FloatColumn, \
22+
MultiDimensionalColumn
2223
from datamatrix import functional as fnc
23-
from testcases.test_tools import capture_stdout
24+
from testcases.test_tools import capture_stdout, check_series
2425

2526

26-
def test_map_():
27+
def test_map():
2728

2829
for coltype in (MixedColumn, FloatColumn, IntColumn):
2930
dm = DataMatrix(length=2, default_col_type=coltype)
3031
dm.a = 1, 2
3132
dm.a = fnc.map_(lambda x: x*2, dm.a)
3233
assert dm.a == [2, 4]
3334
assert isinstance(dm.a, coltype)
35+
dm.a = 1, 2
36+
dm.a = dm.a @ (lambda x: x*2)
37+
assert dm.a == [2, 4]
38+
assert isinstance(dm.a, coltype)
3439
dm = fnc.map_(lambda **d: {'a' : 0}, dm)
3540
assert dm.a == [0, 0]
3641
assert isinstance(dm.a, coltype)
3742

3843

39-
def test_filter_():
44+
def test_map_multidimensional():
45+
dm = DataMatrix(length=2)
46+
dm.m = MultiDimensionalColumn(shape=(3,))
47+
dm.m = [[1,2,3], [4,5,6]]
48+
dm.mean = dm.m @ (lambda a: a.mean())
49+
assert dm.mean == [2, 5]
50+
dm.half = dm.m @ (lambda a: a / 2)
51+
check_series(dm.half, [[0.5, 1., 1.5], [2, 2.5, 3.]])
52+
dm.short = dm.m @ (lambda a: a[:2])
53+
check_series(dm.short, [[1, 2], [4, 5]])
54+
55+
56+
def test_filter():
4057

4158
dm = DataMatrix(length=4)
4259
dm.a = range(4)

0 commit comments

Comments
 (0)