55
66from typing import Optional
77
class TestPigTablesGeneration:
    """Tests for ``generate_pig_tables``."""

    # None is included on purpose: per the original author's note it is the
    # default value of the id column argument in generate_pig_tables.
    @pytest.mark.parametrize("id_col_name", [None, "col_id"])
    def test_col_id(self, id_col_name: Optional[str]):
        """Compare the generated table against a fixed expected frame.

        The same expected frame is asserted for every parametrized value of
        ``id_col_name``.

        Args:
            id_col_name: Value forwarded as ``id_column_name``; either
                ``None`` or ``"col_id"``.
        """
        # input
        data = pd.DataFrame(
            {
                "col_id": [0, 1, 3, 4, 6],
                "survived": [0, 1, 1, 0, 0],
                "pclass": [3, 1, 1, 3, 1],
                "sex": ["male", "female", "female", "male", "male"],
                "age": [22.0, 38.0, 35.0, 35.0, 54.0],
            }
        )
        target = "survived"
        prep_col = ["pclass", "sex", "age"]

        # output
        out = generate_pig_tables(
            basetable=data,
            target_column_name=target,
            preprocessed_predictors=prep_col,
            id_column_name=id_col_name,
        )

        # expected: one row per (variable, label) pair
        expected = pd.DataFrame(
            {
                "variable": [
                    "age",
                    "age",
                    "age",
                    "age",
                    "pclass",
                    "pclass",
                    "sex",
                    "sex",
                ],
                "label": [22.0, 35.0, 38.0, 54.0, 1, 3, "female", "male"],
                "pop_size": [0.2, 0.4, 0.2, 0.2, 0.6, 0.4, 0.4, 0.6],
                "global_avg_target": [0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4],
                "avg_target": [0.0, 0.5, 1.0, 0.0, 0.6666666666666666, 0.0, 1.0, 0.0],
            }
        )

        pd.testing.assert_frame_equal(out, expected)
0 commit comments