@@ -34,22 +34,37 @@ class TargetEncoder(BaseEstimator):
3434
3535 Attributes
3636 ----------
37- columns : list
38- A list of columns to encode, if None, all string columns will be
39- encoded.
37+ imputation_strategy : str
38+ in case there is a particular column which contains new categories,
39+ the encoding will lead to NULL values which should be imputed.
40+ Valid strategies are to replace with the global mean of the train
41+ set or the min (resp. max) incidence of the categories of that
42+ particular variable.
4043 weight : float
4144 Smoothing parameters (non-negative). The higher the value of the
4245 parameter, the bigger the contribution of the overall mean. When set to
4346 zero, there is no smoothing (e.g. the pure target incidence is used).
4447 """
4548
46- def __init__ (self , weight : float = 0.0 ):
49+ valid_strategies = ("mean" , "min" , "max" )
50+
51+ def __init__ (self , weight : float = 0.0 ,
52+ imputation_strategy : str = "mean" ):
4753
4854 if weight < 0 :
4955 raise ValueError ("The value of weight cannot be smaller than zero" )
56+ elif imputation_strategy not in self .valid_strategies :
57+ raise ValueError ("Valid options for 'imputation_strategy' are {}."
58+ " Got imputation_strategy={!r} instead"
59+ .format (self .valid_strategies ,
60+ imputation_strategy ))
5061
5162 self .weight = weight
63+ self .imputation_strategy = imputation_strategy
64+
5265 self ._mapping = {} # placeholder for fitted output
66+ # placeholder for the global incidence of the data used for fitting
67+ self ._global_mean = None
5368
5469 # not implemented yet!
5570 # randomized: bool=False, sigma=0.05
@@ -72,6 +87,8 @@ def attributes_to_dict(self) -> dict:
7287 for key , value in self ._mapping .items ()
7388 }
7489
90+ params ["_global_mean" ] = self ._global_mean
91+
7592 return params
7693
7794 def set_attributes_from_dict (self , params : dict ):
@@ -88,6 +105,14 @@ def set_attributes_from_dict(self, params: dict):
88105 if "weight" in params and type (params ["weight" ]) == float :
89106 self .weight = params ["weight" ]
90107
108+ if ("imputation_strategy" in params and
109+ params ["imputation_strategy" ] in self .valid_strategies ):
110+
111+ self .imputation_strategy = params ["imputation_strategy" ]
112+
113+ if "_global_mean" in params and type (params ["_global_mean" ]) == float :
114+ self ._global_mean = params ["_global_mean" ]
115+
91116 _mapping = {}
92117 if "_mapping" in params and type (params ["_mapping" ]) == dict :
93118 _mapping = params ["_mapping" ]
@@ -121,19 +146,17 @@ def fit(self, data: pd.DataFrame, column_names: list,
121146
122147 # compute global mean (target incidence in case of binary target)
123148 y = data [target_column ]
124- global_mean = y .sum () / y .count ()
149+ self . _global_mean = y .sum () / y .count ()
125150
126151 for column in column_names :
127152 if column not in data .columns :
128153 log .warning ("DataFrame has no column '{}', so it will be "
129154 "skipped in fitting" .format (column ))
130155 continue
131156
132- self ._mapping [column ] = self ._fit_column (data [column ], y ,
133- global_mean )
157+ self ._mapping [column ] = self ._fit_column (data [column ], y )
134158
135- def _fit_column (self , X : pd .Series , y : pd .Series ,
136- global_mean : float ) -> pd .Series :
159+ def _fit_column (self , X : pd .Series , y : pd .Series ) -> pd .Series :
137160 """Summary
138161
139162 Parameters
@@ -143,8 +166,6 @@ def _fit_column(self, X: pd.Series, y: pd.Series,
143166 categorical variable.
144167 y : pd.Series
145168 series containing the targets for each observation
146- global_mean : float
147- Global mean of the target
148169
149170 Returns
150171 -------
@@ -158,7 +179,9 @@ def _fit_column(self, X: pd.Series, y: pd.Series,
158179 # Q: do we need to do this here or during the transform phase???
159180
160181 # Note if self.weight = 0, we have the ordinary incidence replacement
161- numerator = stats ["count" ]* stats ["mean" ] + self .weight * global_mean
182+ numerator = (stats ["count" ]* stats ["mean" ]
183+ + self .weight * self ._global_mean )
184+
162185 denominator = stats ["count" ] + self .weight
163186
164187 return numerator / denominator
@@ -187,13 +210,12 @@ def transform(self, data: pd.DataFrame,
187210 method
188211
189212 """
190- if len (self ._mapping ) == 0 :
213+ if ( len (self ._mapping ) == 0 ) or ( self . _global_mean is None ) :
191214 msg = ("This {} instance is not fitted yet. Call 'fit' with "
192215 "appropriate arguments before using this method." )
193216
194217 raise NotFittedError (msg .format (self .__class__ .__name__ ))
195218
196- new_columns = []
197219 for column in column_names :
198220
199221 if column not in data .columns :
@@ -205,15 +227,47 @@ def transform(self, data: pd.DataFrame,
205227 "and will be skipped" .format (column ))
206228 continue
207229
208- new_column = TargetEncoder ._clean_column_name (column )
230+ data = self ._transform_column (data , column )
231+
232+ return data
233+
234+ def _transform_column (self , data : pd .DataFrame ,
235+ column_name : str ) -> pd .DataFrame :
236+ """Replace (e.g. encode) categories of each column with its average
237+ incidence which was computed when the fit method was called
209238
210- # Convert dtype to float because when the original dtype
211- # is of type "category", the resulting dtype is also of type
212- # "category"
213- data [new_column ] = (data [column ].map (self ._mapping [column ])
214- .astype ("float" ))
239+ Parameters
240+ ----------
241+ X : pd.DataFrame
242+ data to encode
243+ column_name : str
244+ Name of the column in data to be encoded
215245
216- new_columns .append (new_column )
246+ Returns
247+ -------
248+ pd.DataFrame
249+ transformed data
250+ """
251+ new_column = TargetEncoder ._clean_column_name (column_name )
252+
253+ # Convert dtype to float because when the original dtype
254+ # is of type "category", the resulting dtype is also of type
255+ # "category"
256+ data [new_column ] = (data [column_name ].map (self ._mapping [column_name ])
257+ .astype ("float" ))
258+
259+ # In case of categorical data, it could be that new categories will
260+ # emerge which were not present in the train set, so this will result
261+ # in missing values (which should be replaced)
262+ if data [new_column ].isnull ().sum () > 0 :
263+ if self .imputation_strategy == "mean" :
264+ data [new_column ].fillna (self ._global_mean , inplace = True )
265+ elif self .imputation_strategy == "min" :
266+ data [new_column ].fillna (data [new_column ].min (),
267+ inplace = True )
268+ elif self .imputation_strategy == "max" :
269+ data [new_column ].fillna (data [new_column ].max (),
270+ inplace = True )
217271
218272 return data
219273
0 commit comments