Skip to content

Commit 8820580

Browse files
author
Patrick Leonardy
committed
Add new line at end of file, Format with black
1 parent a9c21ca commit 8820580

1 file changed

Lines changed: 36 additions & 22 deletions

File tree

tests/preprocessing/test_pig_tables.py

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,38 +5,52 @@
55

66
from typing import Optional
77

8-
class TestPigTablesGeneration:
98

10-
@pytest.mark.parametrize("id_col_name", [None, "col_id"]) # test None as this is the default value in generate pig tabels
9+
class TestPigTablesGeneration:
10+
@pytest.mark.parametrize(
11+
"id_col_name", [None, "col_id"]
12+
) # test None as this is the default value in generate pig tabels
1113
def test_col_id(self, id_col_name: Optional[str]):
12-
14+
1315
# input
14-
data = pd.DataFrame({
15-
'col_id': [0, 1, 3, 4, 6],
16-
'survived': [0, 1, 1, 0, 0],
17-
'pclass': [3, 1, 1, 3, 1],
18-
'sex': ['male', 'female', 'female', 'male', 'male'],
19-
'age': [22.0, 38.0, 35.0, 35.0, 54.0]
20-
})
16+
data = pd.DataFrame(
17+
{
18+
"col_id": [0, 1, 3, 4, 6],
19+
"survived": [0, 1, 1, 0, 0],
20+
"pclass": [3, 1, 1, 3, 1],
21+
"sex": ["male", "female", "female", "male", "male"],
22+
"age": [22.0, 38.0, 35.0, 35.0, 54.0],
23+
}
24+
)
2125
target = "survived"
2226
prep_col = ["pclass", "sex", "age"]
23-
27+
2428
# output
2529
out = generate_pig_tables(
26-
basetable= data,
30+
basetable=data,
2731
target_column_name=target,
2832
preprocessed_predictors=prep_col,
29-
id_column_name=id_col_name
33+
id_column_name=id_col_name,
3034
)
31-
35+
3236
# expected
33-
expected = pd.DataFrame({
34-
'variable': ['age', 'age', 'age', 'age', 'pclass', 'pclass', 'sex', 'sex'],
35-
'label': [22.0, 35.0, 38.0, 54.0, 1, 3, 'female', 'male'],
36-
'pop_size': [0.2, 0.4, 0.2, 0.2, 0.6, 0.4, 0.4, 0.6],
37-
'global_avg_target': [0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4],
38-
'avg_target': [0.0, 0.5, 1.0, 0.0, 0.6666666666666666, 0.0, 1.0, 0.0]
39-
})
37+
expected = pd.DataFrame(
38+
{
39+
"variable": [
40+
"age",
41+
"age",
42+
"age",
43+
"age",
44+
"pclass",
45+
"pclass",
46+
"sex",
47+
"sex",
48+
],
49+
"label": [22.0, 35.0, 38.0, 54.0, 1, 3, "female", "male"],
50+
"pop_size": [0.2, 0.4, 0.2, 0.2, 0.6, 0.4, 0.4, 0.6],
51+
"global_avg_target": [0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4],
52+
"avg_target": [0.0, 0.5, 1.0, 0.0, 0.6666666666666666, 0.0, 1.0, 0.0],
53+
}
54+
)
4055

4156
pd.testing.assert_frame_equal(out, expected)
42-

0 commit comments

Comments
 (0)