Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions Orange/preprocess/tests/test_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,48 @@ def test_lookup(self):
self.assertNotEqual(t1, t1a)
self.assertNotEqual(hash(t1), hash(t1a))

def test_safe_lookup_table_equal(self):
sl = Lookup._safe_lookup_table_equal
self.assertTrue(sl([1, 2, 3], [1, 2, 3]))
self.assertTrue(sl(np.array([1, 2, 3]), np.array([1, 2, 3])))
self.assertTrue(sl(np.array("foo bar baz".split()), np.array("foo bar baz".split())))

self.assertFalse(sl([1, 2, 3], [1, 2, 4]))
self.assertFalse(sl(np.array([1, 2, 3]), np.array([1, 2, 4])))
self.assertFalse(sl(np.array("foo bar baz".split()), np.array("foo bar qux".split())))
self.assertFalse(sl([1, 2, 3], [1, 2, 3, 4]))
self.assertFalse(sl(np.array([1, 2, 3]), np.array([1, 2, 3, 4])))
self.assertFalse(sl(np.array("foo bar baz".split()), np.array("foo bar baz qux".split())))

self.assertFalse(sl(np.array([1, 2, 3]), np.array("foo bar baz".split())))

def test_eq(self):
self.assertEqual(
Lookup(self.disc1, np.array([0, 2, 1]), 1),
Lookup(self.disc1a, np.array([0, 2, 1]), 1))
self.assertEqual(
Lookup(self.disc1, np.array(["foo", "bar", "baz"]), ""),
Lookup(self.disc1a, np.array(["foo", "bar", "baz"]), ""))

self.assertNotEqual(
Lookup(self.disc1, np.array([0, 2, 1]), 1),
Lookup(self.disc1a, np.array([0, 1, 2]), 1))
self.assertNotEqual(
Lookup(self.disc1, np.array([0, 2, 1]), 1),
Lookup(self.disc1a, np.array([0, 2, 1, 5]), 1))
self.assertNotEqual(
Lookup(self.disc1, np.array([0, 2, 1]), 1),
Lookup(self.disc1a, np.array([0, 2, 1]), 2))
self.assertNotEqual(
Lookup(self.disc1, np.array(["foo", "bar", "baz"]), ""),
Lookup(self.disc1a, np.array(["foo", "baz", "bar"]), ""))
self.assertNotEqual(
Lookup(self.disc1, np.array(["foo", "bar", "baz"]), ""),
Lookup(self.disc1a, np.array(["foo", "bar", "baz"]), "qux"))
self.assertNotEqual(
Lookup(self.disc1, np.array([0, 1, 2]), 1),
Lookup(self.disc1a, np.array(["foo", "bar", "baz"]), "qux"))

def test_mapping(self):
def test_equal(a, b):
self.assertEqual(a, b)
Expand Down
29 changes: 22 additions & 7 deletions Orange/preprocess/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,28 @@ def __init__(self, variable, lookup_table, unknown=np.nan):
:type variable: int or str or :obj:`~Orange.data.DiscreteVariable`
:param lookup_table: transformations for each value of `self.variable`
:type lookup_table: np.array
:param unknown: The value to be used as unknown value.
:type unknown: float or int
:param unknown : The value to be used as unknown value.
:type unknown: float or int or str
"""
super().__init__(variable)
self.lookup_table = lookup_table
self.unknown = unknown

@staticmethod
def _safe_lookup_table_equal(a, b):
a = np.asarray(a)
b = np.asarray(b)

if a.shape != b.shape:
return False

a_is_num = np.issubdtype(a.dtype, np.number)
b_is_num = np.issubdtype(b.dtype, np.number)
if a_is_num and b_is_num:
return np.allclose(a, b, equal_nan=True)

return np.array_equal(a, b)

def transform(self, column):
# Densify DiscreteVariable values coming from sparse datasets.
if sp.issparse(column):
Expand All @@ -236,11 +251,11 @@ def transform(self, column):
return np.where(mask, self.unknown, values)

def __eq__(self, other):
return super().__eq__(other) \
and np.allclose(self.lookup_table, other.lookup_table,
equal_nan=True) \
and np.allclose(self.unknown, other.unknown, equal_nan=True)

return (
super().__eq__(other)
and self._safe_lookup_table_equal(self.lookup_table, other.lookup_table)
and self._safe_lookup_table_equal(self.unknown, other.unknown)
)
def __hash__(self):
return hash(
(
Expand Down
Loading