Skip to content

Commit 8f770c7

Browse files
Merge pull request #141 from PythonPredictions/#135-Remove-mandatory-id-column
Defaults id_column to None for PIGs & tests
2 parents 03372ea + 8820580 commit 8f770c7

2 files changed

Lines changed: 83 additions & 17 deletions

File tree

cobra/evaluation/pigs_tables.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
import cobra.utils as utils
99

1010
def generate_pig_tables(basetable: pd.DataFrame,
11-
id_column_name: str,
1211
target_column_name: str,
13-
preprocessed_predictors: list) -> pd.DataFrame:
12+
preprocessed_predictors: list,
13+
id_column_name: str = None) -> pd.DataFrame:
1414
"""Compute PIG tables for all predictors in preprocessed_predictors.
1515
1616
The output is a DataFrame with columns ``variable``, ``label``,
@@ -20,35 +20,41 @@ def generate_pig_tables(basetable: pd.DataFrame,
2020
----------
2121
basetable : pd.DataFrame
2222
Basetable to compute PIG tables from.
23-
id_column_name : str
24-
Name of the basetable column containing the IDs of the basetable rows
25-
(e.g. customernumber).
2623
target_column_name : str
2724
Name of the basetable column containing the target values to predict.
2825
preprocessed_predictors: list
2926
List of basetable column names containing preprocessed predictors.
30-
27+
id_column_name : str, default=None
28+
Name of the basetable column containing the IDs of the basetable rows
29+
(e.g. customernumber).
3130
Returns
3231
-------
3332
pd.DataFrame
3433
DataFrame containing a PIG table for all predictors.
3534
"""
35+
36+
# check whether an id column was given and define no_predictor accordingly
37+
if id_column_name == None:
38+
no_predictor = [target_column_name]
39+
else:
40+
no_predictor = [id_column_name, target_column_name]
41+
42+
3643
pigs = [
3744
compute_pig_table(basetable,
3845
column_name,
3946
target_column_name,
40-
id_column_name)
47+
)
4148
for column_name in sorted(preprocessed_predictors)
42-
if column_name not in [id_column_name, target_column_name]
49+
if column_name not in no_predictor
4350
]
44-
output = pd.concat(pigs)
51+
output = pd.concat(pigs, ignore_index=True)
4552
return output
4653

4754

4855
def compute_pig_table(basetable: pd.DataFrame,
4956
predictor_column_name: str,
50-
target_column_name: str,
51-
id_column_name: str) -> pd.DataFrame:
57+
target_column_name: str) -> pd.DataFrame:
5258
"""Compute the PIG table of a given predictor for a given target.
5359
5460
Parameters
@@ -59,8 +65,6 @@ def compute_pig_table(basetable: pd.DataFrame,
5965
Predictor name of which to compute the pig table.
6066
target_column_name : str
6167
Name of the target variable.
62-
id_column_name : str
63-
Name of the id column (used to count population size).
6468
6569
Returns
6670
-------
@@ -72,12 +76,18 @@ def compute_pig_table(basetable: pd.DataFrame,
7276
# group by the binned variable, compute the incidence
7377
# (= mean of the target for the given bin) and compute the bin size
7478
# (the number of rows in the bin, i.e. the group "size"). After that, rename the columns
79+
7580
res = (basetable.groupby(predictor_column_name)
76-
.agg({target_column_name: "mean", id_column_name: "size"})
81+
.agg(
82+
avg_target = (target_column_name, "mean"),
83+
pop_size = (target_column_name, "size")
84+
)
7785
.reset_index()
78-
.rename(columns={predictor_column_name: "label",
79-
target_column_name: "avg_target",
80-
id_column_name: "pop_size"}))
86+
.rename(
87+
columns={predictor_column_name: "label"}
88+
)
89+
)
90+
8191

8292
# add the column name to a variable column
8393
# add the average incidence
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import pytest
2+
3+
import pandas as pd
4+
from cobra.evaluation.pigs_tables import generate_pig_tables
5+
6+
from typing import Optional
7+
8+
9+
class TestPigTablesGeneration:
10+
@pytest.mark.parametrize(
11+
"id_col_name", [None, "col_id"]
12+
) # test None as this is the default value in generate_pig_tables
13+
def test_col_id(self, id_col_name: Optional[str]):
14+
15+
# input
16+
data = pd.DataFrame(
17+
{
18+
"col_id": [0, 1, 3, 4, 6],
19+
"survived": [0, 1, 1, 0, 0],
20+
"pclass": [3, 1, 1, 3, 1],
21+
"sex": ["male", "female", "female", "male", "male"],
22+
"age": [22.0, 38.0, 35.0, 35.0, 54.0],
23+
}
24+
)
25+
target = "survived"
26+
prep_col = ["pclass", "sex", "age"]
27+
28+
# output
29+
out = generate_pig_tables(
30+
basetable=data,
31+
target_column_name=target,
32+
preprocessed_predictors=prep_col,
33+
id_column_name=id_col_name,
34+
)
35+
36+
# expected
37+
expected = pd.DataFrame(
38+
{
39+
"variable": [
40+
"age",
41+
"age",
42+
"age",
43+
"age",
44+
"pclass",
45+
"pclass",
46+
"sex",
47+
"sex",
48+
],
49+
"label": [22.0, 35.0, 38.0, 54.0, 1, 3, "female", "male"],
50+
"pop_size": [0.2, 0.4, 0.2, 0.2, 0.6, 0.4, 0.4, 0.6],
51+
"global_avg_target": [0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4],
52+
"avg_target": [0.0, 0.5, 1.0, 0.0, 0.6666666666666666, 0.0, 1.0, 0.0],
53+
}
54+
)
55+
56+
pd.testing.assert_frame_equal(out, expected)

0 commit comments

Comments
 (0)