Skip to content

Commit d315590

Browse files
sandervh14Sander Vanden Hautte
andauthored
Closing issue #33. (#51)
* Preventing floating point rounding issues on verifying equality of train_prop + selection_prop + validation_prop against 1.0. * Merging new unit test in the existing one. * My initiative to slightly alter the error message broke a unit test. Oopsie. Fixed. Co-authored-by: Sander Vanden Hautte <sander.vandenhautte@tobania.be>
1 parent 12de11c commit d315590

2 files changed

Lines changed: 13 additions & 6 deletions

File tree

cobra/preprocessing/preprocessor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import inspect
1515
from datetime import datetime
1616
import time
17+
import math
1718
import logging
1819
from random import shuffle
1920

@@ -361,9 +362,9 @@ def train_selection_validation_split(data: pd.DataFrame,
361362
pd.DataFrame
362363
DataFrame with additional split column
363364
"""
364-
if train_prop + selection_prop + validation_prop != 1.0:
365+
if not math.isclose(train_prop + selection_prop + validation_prop, 1.0):
365366
raise ValueError("The sum of train_prop, selection_prop and "
366-
"validation_prop cannot differ from 1.0")
367+
"validation_prop must be 1.0.")
367368

368369
if train_prop == 0.0:
369370
raise ValueError("train_prop cannot be zero!")

tests/preprocessing/test_preprocessor.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,19 @@ def does_not_raise():
1616

1717
class TestPreProcessor:
1818

19-
@pytest.mark.parametrize(("train_prop, selection_prop, "
20-
"validation_prop, expected_sizes"),
19+
@pytest.mark.parametrize("train_prop, selection_prop, validation_prop, "
20+
"expected_sizes",
2121
[(0.6, 0.2, 0.2, {"train": 6,
2222
"selection": 2,
2323
"validation": 2}),
2424
(0.7, 0.3, 0.0, {"train": 7,
25-
"selection": 3})])
25+
"selection": 3}),
26+
# Error "The sum of train_prop, selection_prop and
27+
# validation_prop must be 1.0." should not be
28+
# raised:
29+
(0.7, 0.2, 0.1, {"train": 7,
30+
"selection": 2,
31+
"validation": 1})])
2632
def test_train_selection_validation_split(self, train_prop: float,
2733
selection_prop: float,
2834
validation_prop: float,
@@ -50,7 +56,7 @@ def test_train_selection_validation_split(self, train_prop: float,
5056
def test_train_selection_validation_split_error_wrong_prop(self):
5157

5258
error_msg = ("The sum of train_prop, selection_prop and "
53-
"validation_prop cannot differ from 1.0")
59+
"validation_prop must be 1.0.")
5460
train_prop = 0.7
5561
selection_prop = 0.3
5662

0 commit comments

Comments
 (0)