Skip to content

Commit 353b9bc

Browse files
author
Patrick Leonardy
committed
Defaults id_column to None for PIGs & tests
1 parent f37867f commit 353b9bc

2 files changed

Lines changed: 66 additions & 17 deletions

File tree

cobra/evaluation/pigs_tables.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
import cobra.utils as utils
99

1010
def generate_pig_tables(basetable: pd.DataFrame,
11-
id_column_name: str,
1211
target_column_name: str,
13-
preprocessed_predictors: list) -> pd.DataFrame:
12+
preprocessed_predictors: list,
13+
id_column_name: str = None) -> pd.DataFrame:
1414
"""Compute PIG tables for all predictors in preprocessed_predictors.
1515
1616
The output is a DataFrame with columns ``variable``, ``label``,
@@ -20,35 +20,41 @@ def generate_pig_tables(basetable: pd.DataFrame,
2020
----------
2121
basetable : pd.DataFrame
2222
Basetable to compute PIG tables from.
23-
id_column_name : str
24-
Name of the basetable column containing the IDs of the basetable rows
25-
(e.g. customernumber).
2623
target_column_name : str
2724
Name of the basetable column containing the target values to predict.
2825
preprocessed_predictors: list
2926
List of basetable column names containing preprocessed predictors.
30-
27+
id_column_name : str, default=None
28+
Name of the basetable column containing the IDs of the basetable rows
29+
(e.g. customernumber).
3130
Returns
3231
-------
3332
pd.DataFrame
3433
DataFrame containing a PIG table for all predictors.
3534
"""
35+
36+
# check whether there is an id column and define no_predictor accordingly
37+
if id_column_name == None:
38+
no_predictor = [target_column_name]
39+
else:
40+
no_predictor = [id_column_name, target_column_name]
41+
42+
3643
pigs = [
3744
compute_pig_table(basetable,
3845
column_name,
3946
target_column_name,
40-
id_column_name)
47+
)
4148
for column_name in sorted(preprocessed_predictors)
42-
if column_name not in [id_column_name, target_column_name]
49+
if column_name not in no_predictor
4350
]
44-
output = pd.concat(pigs)
51+
output = pd.concat(pigs, ignore_index=True)
4552
return output
4653

4754

4855
def compute_pig_table(basetable: pd.DataFrame,
4956
predictor_column_name: str,
50-
target_column_name: str,
51-
id_column_name: str) -> pd.DataFrame:
57+
target_column_name: str) -> pd.DataFrame:
5258
"""Compute the PIG table of a given predictor for a given target.
5359
5460
Parameters
@@ -59,8 +65,6 @@ def compute_pig_table(basetable: pd.DataFrame,
5965
Predictor name of which to compute the pig table.
6066
target_column_name : str
6167
Name of the target variable.
62-
id_column_name : str
63-
Name of the id column (used to count population size).
6468
6569
Returns
6670
-------
@@ -72,12 +76,18 @@ def compute_pig_table(basetable: pd.DataFrame,
7276
# group by the binned variable, compute the incidence
7377
# (= mean of the target for the given bin) and compute the bin size
7478
# (= number of rows in the bin). After that, rename the predictor column
79+
7580
res = (basetable.groupby(predictor_column_name)
76-
.agg({target_column_name: "mean", id_column_name: "size"})
81+
.agg(
82+
avg_target = (target_column_name, "mean"),
83+
pop_size = (target_column_name, "size")
84+
)
7785
.reset_index()
78-
.rename(columns={predictor_column_name: "label",
79-
target_column_name: "avg_target",
80-
id_column_name: "pop_size"}))
86+
.rename(
87+
columns={predictor_column_name: "label"}
88+
)
89+
)
90+
8191

8292
# add the column name to a variable column
8393
# add the average incidence
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import pytest
2+
3+
import pandas as pd
4+
from cobra.evaluation.pigs_tables import generate_pig_tables
5+
6+
class TestPigTablesGeneration:
7+
8+
@pytest.mark.parametrize("id_col_name", [None, "col_id"]) # test None as this is the default value in generate_pig_tables
9+
def test_col_id(self, id_col_name):
10+
11+
# input
12+
data = pd.DataFrame({
13+
'col_id': [0, 1, 3, 4, 6],
14+
'survived': [0, 1, 1, 0, 0],
15+
'pclass': [3, 1, 1, 3, 1],
16+
'sex': ['male', 'female', 'female', 'male', 'male'],
17+
'age': [22.0, 38.0, 35.0, 35.0, 54.0]
18+
})
19+
target = "survived"
20+
prep_col = ["pclass", "sex", "age"]
21+
22+
# output
23+
out = generate_pig_tables(
24+
basetable= data,
25+
target_column_name=target,
26+
preprocessed_predictors=prep_col,
27+
id_column_name=id_col_name
28+
)
29+
30+
# expected
31+
expected = pd.DataFrame({
32+
'variable': ['age', 'age', 'age', 'age', 'pclass', 'pclass', 'sex', 'sex'],
33+
'label': [22.0, 35.0, 38.0, 54.0, 1, 3, 'female', 'male'],
34+
'pop_size': [0.2, 0.4, 0.2, 0.2, 0.6, 0.4, 0.4, 0.6],
35+
'global_avg_target': [0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4],
36+
'avg_target': [0.0, 0.5, 1.0, 0.0, 0.6666666666666666, 0.0, 1.0, 0.0]
37+
})
38+
39+
pd.testing.assert_frame_equal(out, expected)

0 commit comments

Comments
 (0)