Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 40 additions & 7 deletions cobra/model_building/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,30 @@ class LogisticRegressionModel:
scikit-learn logistic regression model.
predictors : list
List of predictors used in the model.
kwargs: dict, optional
Pass a dictionary here (optional!), to override Cobra's default
choice of hyperparameter values for the scikit-learn
LogisticRegression model that is used behind the scenes. Our defaults
are: fit_intercept=True, C=1e9, solver='liblinear', random_state=42.
See scikit-learn's documentation of the possible hyperparameters and
values that can be set:
https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html
"""

def __init__(self):
self.logit = LogisticRegression(fit_intercept=True, C=1e9,
solver='liblinear', random_state=42)
def __init__(self, **kwargs):
# Initialize a scikit-learn linear regression model,
# with custom arguments passed by the data scientist (if any),
# supplemented with Cobra's default arguments, if a custom value was
# not provided by the data scientist for overriding purposes:
default_kwargs = dict(fit_intercept=True, C=1e9, solver='liblinear',
random_state=42)
for kwarg, val in default_kwargs.items():
if kwarg not in kwargs:
kwargs[kwarg] = val
self.logit = LogisticRegression(**kwargs)

self._is_fitted = False
# placeholder to keep track of a list of predictors
self.predictors = []
self.predictors = [] # placeholder to keep track of a list of predictors
self._eval_metrics_by_split = {}

def serialize(self) -> dict:
Expand Down Expand Up @@ -258,10 +274,27 @@ class LinearRegressionModel:
scikit-learn linear regression model.
predictors : list
List of predictors used in the model.
kwargs: dict, optional
Pass a dictionary here (optional!), to override Cobra's default
choice of hyperparameter values for the scikit-learn
LinearRegression model that is used behind the scenes. Our default
setting is only fit_intercept=True.
See scikit-learn's documentation of the possible hyperparameters and
values that can be set:
https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html
"""

def __init__(self):
self.linear = LinearRegression(fit_intercept=True)
def __init__(self, **kwargs):
# Initialize a scikit-learn linear regression model,
# with custom arguments passed by the data scientist (if any),
# supplemented with Cobra's default arguments, if a custom value was
# not provided by the data scientist for overriding purposes:
default_kwargs = dict(fit_intercept=True)
for kwarg, val in default_kwargs.items():
if kwarg not in kwargs:
kwargs[kwarg] = val
self.linear = LinearRegression(**kwargs)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This for loop can be replaced by a oneliner by using .update() (see https://www.programiz.com/python-programming/methods/dictionary/update). In that case, default_kwargs best to be renamed to something like model_kwargs.

Suggested change
default_kwargs = dict(fit_intercept=True)
for kwarg, val in default_kwargs.items():
if kwarg not in kwargs:
kwargs[kwarg] = val
self.linear = LinearRegression(**kwargs)
model_kwargs = dict(fit_intercept=True)
model_kwargs.update(kwargs)
self.linear = LinearRegression(**model_kwargs)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice Jano, I was going to commit it straight away with the github interface, but then it isn't applied to LogisitcRegression as well. I'll have a look at it this afternoon and will include Sam's remarks

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An idea to resolve duplicate code (and documentation to some extent) is to create a BaseModel class from which both the LinearRegression and LogisticRegression class inherit.
But this is of course out of scope for this PR and should be considered in a separate PR.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An idea to resolve duplicate code (and documentation to some extent) is to create a BaseModel class from which both the LinearRegression and LogisticRegression class inherit.

I remembered this comment while fixing #126, I also found there was quite some duplication and I had time to do this, so I've done this today with the solution to #126, yippee! :-). For details, see explanation in #128 (comment).

But this is of course out of scope for this PR and should be considered in a separate PR.

I took the liberty to include the superclassing abstraction in #126 anyway, instead of a new issue & PR dedicated to it, since the evaluate() which I was fixing unit tests for, takes up a BIG chunk of the code of both LinearRegression and LogisticRegressionModel.


self._is_fitted = False
self.predictors = [] # placeholder to keep track of a list of predictors
self._eval_metrics_by_split = {}
Expand Down