Skip to content

Commit 07338fb

Browse files
author
sborms
committed
add model_type functionality in PreProcessor class & tests
1 parent ecb7d90 commit 07338fb

3 files changed

Lines changed: 39 additions & 26 deletions

File tree

cobra/preprocessing/categorical_data_processor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class CategoricalDataProcessor(BaseEstimator):
4646
keep_missing : bool
4747
Whether or not to keep missing as a separate category.
4848
model_type : str
49-
Model type ("classification" or "regression").
49+
Model type (``classification`` or ``regression``).
5050
p_value_threshold : float
5151
Significance threshold for regrouping.
5252
regroup : bool
@@ -442,7 +442,7 @@ def _compute_p_value(X: pd.Series, y: pd.Series, category: str,
442442
category : str
443443
Category for which we carry out the test.
444444
model_type : str
445-
Model type ("classification" or "regression").
445+
Model type (``classification`` or ``regression``).
446446
scale_contingency_table : bool
447447
Whether we scale contingency table with incidence rate.
448448
Only used when model_type = "classification".

cobra/preprocessing/preprocessor.py

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -44,23 +44,27 @@ class PreProcessor(BaseEstimator):
4444
----------
4545
categorical_data_processor : CategoricalDataProcessor
4646
Instance of CategoricalDataProcessor to do the preprocessing of
47-
categorical variables
47+
categorical variables. The model_type variable is specified
48+
here (``classification`` or ``regression``).
4849
discretizer : KBinsDiscretizer
4950
Instance of KBinsDiscretizer to do the preprocessing of continuous
50-
variables by means of discretization
51+
variables by means of discretization.
5152
serialization_path : str
52-
path to save the pipeline to
53+
Path to save the pipeline to.
5354
stratify_split : bool
54-
Whether or not to stratify the train-test split
55+
Whether or not to stratify the train-test split.
5556
target_encoder : TargetEncoder
56-
Instance of TargetEncoder to do the incidence replacement
57+
Instance of TargetEncoder to do the incidence replacement.
5758
"""
5859

59-
def __init__(self, categorical_data_processor: CategoricalDataProcessor,
60+
def __init__(self,
61+
categorical_data_processor: CategoricalDataProcessor,
6062
discretizer: KBinsDiscretizer,
6163
target_encoder: TargetEncoder,
6264
is_fitted: bool = False):
6365

66+
self.model_type = categorical_data_processor.model_type
67+
6468
self._categorical_data_processor = categorical_data_processor
6569
self._discretizer = discretizer
6670
self._target_encoder = target_encoder
@@ -69,6 +73,7 @@ def __init__(self, categorical_data_processor: CategoricalDataProcessor,
6973

7074
@classmethod
7175
def from_params(cls,
76+
model_type: str = "classification",
7277
n_bins: int = 10,
7378
strategy: str = "quantile",
7479
closed: str = "right",
@@ -91,16 +96,18 @@ def from_params(cls,
9196
9297
Parameters
9398
----------
99+
model_type : str
100+
Model type (``classification`` or ``regression``).
94101
n_bins : int, optional
95102
Number of bins to produce. Raises ValueError if ``n_bins < 2``.
96103
strategy : str, optional
97104
Binning strategy. Currently only ``uniform`` and ``quantile``
98-
e.g. equifrequency is supported
105+
(i.e. equifrequency) are supported.
99106
closed : str, optional
100-
Whether to close the bins (intervals) from the left or right
107+
Whether to close the bins (intervals) from the left or right.
101108
auto_adapt_bins : bool, optional
102-
reduces the number of bins (starting from n_bins) as a function of
103-
the number of missings
109+
Reduces the number of bins (starting from n_bins) as a function of
110+
the number of missings.
104111
starting_precision : int, optional
105112
Initial precision for the bin edges to start from,
106113
can also be negative. Given a list of bin edges, the class will
@@ -110,33 +117,32 @@ def from_params(cls,
110117
will be made to round up the numbers of the bin edges
111118
e.g. ``5.55 -> 10``, ``146 -> 100``, ...
112119
label_format : str, optional
113-
format string to display the bin labels
120+
Format string to display the bin labels
114121
e.g. ``min - max``, ``(min, max]``, ...
115122
change_endpoint_format : bool, optional
116123
Whether or not to change the format of the lower and upper bins
117124
into ``< x`` and ``> y`` resp.
118125
regroup : bool
119-
Whether or not to regroup categories
126+
Whether or not to regroup categories.
120127
regroup_name : str
121-
New name of the non-significant regrouped variables
128+
New name of the non-significant regrouped variables.
122129
keep_missing : bool
123-
Whether or not to keep missing as a separate category
130+
Whether or not to keep missing as a separate category.
124131
category_size_threshold : int
125-
minimal size of a category to keep it as a separate category
132+
Minimal size of a category to keep it as a separate category.
126133
p_value_threshold : float
127134
Significance threshold for regrouping.
128135
forced_categories : dict
129136
Map to prevent certain categories from being grouped into ``Other``
130137
for each column - dict of the form ``{col:[forced vars]}``.
131138
scale_contingency_table : bool
132-
Whether contingency table should be scaled before chi^2.'
139+
Whether contingency table should be scaled before chi^2.
133140
weight : float, optional
134141
Smoothing parameter (non-negative). The higher the value of the
135142
parameter, the bigger the contribution of the overall mean.
136-
When set to zero, there is no smoothing
137-
(e.g. the pure target incidence is used).
143+
When set to zero, there is no smoothing (i.e. the pure target incidence is used).
138144
imputation_strategy : str, optional
139-
in case there is a particular column which contains new categories,
145+
In case there is a particular column which contains new categories,
140146
the encoding will lead to NULL values which should be imputed.
141147
Valid strategies are to replace with the global mean of the train
142148
set or the min (resp. max) incidence of the categories of that
@@ -145,25 +151,29 @@ def from_params(cls,
145151
Returns
146152
-------
147153
PreProcessor
148-
Description
154+
class encapsulating CategoricalDataProcessor,
155+
KBinsDiscretizer, and TargetEncoder instances
149156
"""
150157
categorical_data_processor = CategoricalDataProcessor(
158+
model_type,
151159
regroup,
152160
regroup_name,
153161
keep_missing,
154162
category_size_threshold,
155163
p_value_threshold,
156164
scale_contingency_table,
157165
forced_categories)
166+
158167
discretizer = KBinsDiscretizer(n_bins, strategy, closed,
159168
auto_adapt_bins,
160169
starting_precision,
161170
label_format,
162171
change_endpoint_format)
163172

164-
target_encoder = TargetEncoder(weight)
173+
target_encoder = TargetEncoder(weight, imputation_strategy)
165174

166-
return cls(categorical_data_processor, discretizer, target_encoder)
175+
return cls(model_type,
176+
categorical_data_processor, discretizer, target_encoder)
167177

168178
@classmethod
169179
def from_pipeline(cls, pipeline: dict):
@@ -187,20 +197,22 @@ def from_pipeline(cls, pipeline: dict):
187197
"""
188198

189199
if not PreProcessor._is_valid_pipeline(pipeline):
190-
raise ValueError("Invalid pipeline") # To do: specify error
200+
raise ValueError("Invalid pipeline") ## TODO: specify error
191201

192202
categorical_data_processor = CategoricalDataProcessor()
193203
categorical_data_processor.set_attributes_from_dict(
194204
pipeline["categorical_data_processor"]
195205
)
206+
model_type = categorical_data_processor.model_type
196207

197208
discretizer = KBinsDiscretizer()
198209
discretizer.set_attributes_from_dict(pipeline["discretizer"])
199210

200211
target_encoder = TargetEncoder()
201212
target_encoder.set_attributes_from_dict(pipeline["target_encoder"])
202213

203-
return cls(categorical_data_processor, discretizer, target_encoder,
214+
return cls(model_type,
215+
categorical_data_processor, discretizer, target_encoder,
204216
is_fitted=pipeline["_is_fitted"])
205217

206218
def fit(self, train_data: pd.DataFrame, continuous_vars: list,

tests/preprocessing/test_preprocessor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def test_is_valid_pipeline(self, injection_location: str,
9797
# is_valid_pipeline only checks for relevant keys atm
9898
pipeline_dict = {
9999
"categorical_data_processor": {
100+
"model_type": None,
100101
"regroup": None,
101102
"regroup_name": None,
102103
"keep_missing": None,

0 commit comments

Comments
 (0)