@@ -315,6 +315,8 @@ class DecisionTreeLogisticRegression(BaseEstimator, ClassifierMixin):
315315 where *p* is the proportion of samples falling in the first
316316 fold.
317317 :param verbose: prints out information about the training
318+ :param strategy: `'parallel'` or `'perpendicular'`,
319+ see below
318320
319321 Fitted attributes:
320322
@@ -323,6 +325,14 @@ class DecisionTreeLogisticRegression(BaseEstimator, ClassifierMixin):
323325 or a list of arrays of class labels (multi-output problem).
324326 * `tree_`: Tree
325327 The underlying Tree object.
328+
329+ The class implements two strategies to build the tree.
330+ The first one `'parallel'` splits the feature space using
331+ the hyperplane defined by a logistic regression, the second
332+ strategy `'perpendicular'` splits the feature space based on
333+ a hyperplane perpendicular to a logistic regression. By doing
334+ this, two logistic regressions fit on both sub-parts must
335+ necessarily decrease the training error.
326336 """
327337
328338 _fit_improve_algo_values = (
@@ -332,7 +342,7 @@ def __init__(self, estimator=None,
332342 max_depth = 20 , min_samples_split = 2 ,
333343 min_samples_leaf = 2 , min_weight_fraction_leaf = 0.0 ,
334344 fit_improve_algo = 'auto' , p1p2 = 0.09 ,
335- gamma = 1. , verbose = 0 ):
345+ gamma = 1. , verbose = 0 , strategy = 'parallel' ):
336346 "constructor"
337347 ClassifierMixin .__init__ (self )
338348 BaseEstimator .__init__ (self )
@@ -354,6 +364,7 @@ def __init__(self, estimator=None,
354364 self .p1p2 = p1p2
355365 self .gamma = gamma
356366 self .verbose = verbose
367+ self .strategy = strategy
357368
358369 if self .fit_improve_algo not in DecisionTreeLogisticRegression ._fit_improve_algo_values :
359370 raise ValueError (
@@ -392,13 +403,27 @@ def fit(self, X, y, sample_weight=None):
392403 raise RuntimeError (
393404 "The model only supports binary classification but labels are "
394405 "{}." .format (self .classes_ ))
406+
407+ if self .strategy == 'parallel' :
408+ return self ._fit_parallel (X , y , sample_weight )
409+ if self .strategy == 'perpendicular' :
410+ return self ._fit_perpendicular (X , y , sample_weight )
411+ raise ValueError (
412+ "Unknown strategy '{}'." .format (self .strategy ))
413+
def _fit_parallel(self, X, y, sample_weight):
    """
    Trains the tree following the *parallel* strategy.

    The labels are converted into an int32 indicator of
    ``self.classes_[1]``, a root node is created around a clone of
    ``self.estimator`` and fitted on the data.

    :param X: features
    :param y: labels
    :param sample_weight: sample weights, forwarded as is to the root node
    :return: self

    The method sets attributes ``tree_`` (the root node) and
    ``n_nodes_`` (value returned by the root node ``fit`` plus one).
    """
    # Indicator vector: 1 where the label equals the second class.
    is_second_class = y == self.classes_[1]
    binary_target = is_second_class.astype(numpy.int32)

    # The root node works on a fresh copy of the user estimator.
    root = _DecisionTreeLogisticRegressionNode(clone(self.estimator), 0.5)

    # The root is stored before fitting because the instance itself
    # is handed over to the node.
    self.tree_ = root
    fitted = self.tree_.fit(
        X, binary_target, sample_weight, self, X.shape[0])
    self.n_nodes_ = fitted + 1
    return self
401422
423+ def _fit_perpendicular (self , X , y , sample_weight ):
424+ "Implements the perpendicular strategy."
425+ raise NotImplementedError ()
426+
402427 def predict (self , X ):
403428 """
404429 Runs the predictions.
0 commit comments