Skip to content

Commit 237820e

Browse files
authored
[doc] scikit-learn and h2o (#64)
* update the scikitlearn example * modelStudio alias for dalex Explainer * add h2o example
1 parent 9dbe671 commit 237820e

9 files changed

Lines changed: 1964 additions & 133 deletions

File tree

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Generated by roxygen2: do not edit by hand
22

3+
S3method(modelStudio,dalex._explainer.object.Explainer)
34
S3method(modelStudio,explainer)
5+
S3method(modelStudio,python.builtin.object)
46
export(modelStudio)
57
export(modelStudioOptions)
68
import(progress)

R/modelStudio.R

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,15 @@ modelStudio.explainer <- function(explainer,
384384
model_studio
385385
}
386386

387+
#:# alias for reticulate pickle/dalex Explainer
388+
#' @noRd
389+
#' @export
390+
modelStudio.python.builtin.object <- modelStudio.explainer
391+
392+
#' @noRd
393+
#' @export
394+
modelStudio.dalex._explainer.object.Explainer <- modelStudio.explainer
395+
387396
#' @noRd
388397
#' @title remove_file_paths
389398
#'

README.md

Lines changed: 53 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ The main `modelStudio()` function computes various (instance and dataset level)
1616
[**explain FIFA20**](https://pbiecek.github.io/explainFIFA20/) &emsp;
1717
[explain Lung Cancer](https://github.com/hbaniecki/transparent_xai/) &emsp;
1818
[**R & Python examples**](http://modelstudio.drwhy.ai/articles/vignette_examples.html) &emsp;
19-
[More Resources](https://modeloriented.github.io/modelStudio/#more) &emsp;
19+
[More Resources](http://modelstudio.drwhy.ai/#more-resources) &emsp;
2020
[**FAQ & Troubleshooting**](https://github.com/ModelOriented/modelStudio/issues/54)
2121

2222
![](man/figures/short.gif)
@@ -73,7 +73,7 @@ install.packages("iBreakDown")
7373

7474
# packages for explainer objects
7575
install.packages("DALEX")
76-
devtools::install_github("ModelOriented/DALEXtra")
76+
install.packages("DALEXtra")
7777
```
7878

7979
### mlr [dashboard](https://modeloriented.github.io/modelStudio/mlr.html)
@@ -94,7 +94,7 @@ test <- data[-index, ]
9494
# mlr ClassifTask takes target as factor
9595
train$survived <- as.factor(train$survived)
9696

97-
# prepare the model
97+
# fit a model
9898
task <- makeClassifTask(id = "titanic",
9999
data = train,
100100
target = "survived")
@@ -137,7 +137,7 @@ test <- data[-index, ]
137137
train_matrix <- model.matrix(survived ~.-1, train)
138138
test_matrix <- model.matrix(survived ~.-1, test)
139139

140-
# prepare the model
140+
# fit a model
141141
xgb_matrix <- xgb.DMatrix(train_matrix, label = train$survived)
142142
params <- list(eta = 0.01, subsample = 0.6, max_depth = 7, min_child_weight = 3,
143143
objective = "binary:logistic", eval_metric = "auc")
@@ -161,96 +161,88 @@ modelStudio(explainer,
161161

162162
### scikit-learn [dashboard](https://modeloriented.github.io/modelStudio/scikit-learn.html)
163163

164-
Use `pickle` Python module and `reticulate` R package to easily produce modelStudio for scikit-learn model.
164+
Use `pickle` Python module and `reticulate` R package to easily make a studio for a scikit-learn model.
165165

166-
In this example we fit a Pipeline MLPClassifier on the titanic data. First install the `dalex` package.
166+
In this example we will fit a Pipeline MLPClassifier model on titanic data.
167+
168+
Install the `dalex` package.
167169

168170
```bash
169171
pip3 install dalex --force
170172
```
171173

172-
Make an explainer object in Python:
174+
First, use `dalex` in Python:
173175

174176
```python
175-
# import modules
177+
# load packages and data
176178
import dalex as dx
177-
from dalex import datasets
178179

179-
from sklearn.neural_network import MLPClassifier
180-
from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder
181-
from sklearn.impute import SimpleImputer
180+
from sklearn.model_selection import train_test_split
182181
from sklearn.pipeline import Pipeline
183-
from sklearn.tree import DecisionTreeRegressor
182+
from sklearn.preprocessing import StandardScaler, OneHotEncoder
183+
from sklearn.impute import SimpleImputer
184184
from sklearn.compose import ColumnTransformer
185+
from sklearn.neural_network import MLPClassifier
185186

186-
# load the data
187-
data = datasets.load_titanic()
187+
data = dx.datasets.load_titanic()
188188
X = data.drop(columns='survived')
189189
y = data.survived
190190

191-
# make a pipeline model
191+
# split the data
192+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=1)
193+
194+
# fit a pipeline model
192195
numeric_features = ['age', 'fare', 'sibsp', 'parch']
193-
numeric_transformer = Pipeline(steps=[
196+
numeric_transformer = Pipeline(
197+
steps=[
194198
('imputer', SimpleImputer(strategy='median')),
195-
('scaler', StandardScaler())])
196-
199+
('scaler', StandardScaler())
200+
]
201+
)
197202
categorical_features = ['gender', 'class', 'embarked']
198-
categorical_transformer = Pipeline(steps=[
203+
categorical_transformer = Pipeline(
204+
steps=[
199205
('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
200-
('onehot', OneHotEncoder(handle_unknown='ignore'))])
206+
('onehot', OneHotEncoder(handle_unknown='ignore'))
207+
]
208+
)
201209

202210
preprocessor = ColumnTransformer(
203-
transformers=[
204-
('num', numeric_transformer, numeric_features),
205-
('cat', categorical_transformer, categorical_features)])
206-
207-
208-
clf = Pipeline(steps=[('preprocessor', preprocessor),
209-
('classifier', MLPClassifier(hidden_layer_sizes=(150,100,50),
210-
max_iter=500, random_state=0))])
211-
212-
clf.fit(X, y)
211+
transformers=[
212+
('num', numeric_transformer, numeric_features),
213+
('cat', categorical_transformer, categorical_features)
214+
]
215+
)
216+
217+
model = Pipeline(
218+
steps=[
219+
('preprocessor', preprocessor),
220+
('classifier', MLPClassifier(hidden_layer_sizes=(150,100,50), max_iter=500, random_state=0))
221+
]
222+
)
223+
model.fit(X_train, y_train)
213224

214-
# make an explainer
215-
explainer = dx.Explainer(clf, X, y)
225+
# create an explainer for the model
226+
explainer = dx.Explainer(model, X_test, y_test, label = 'scikit-learn')
216227

217-
# remove these functions before dump
228+
#! remove residual_function before dump !
218229
explainer.residual_function = None
219-
explainer.predict_function = None
220230

221231
# pack the explainer into a pickle file
222-
import pickle
223-
pickle_out = open("explainer_titanic.pickle","wb")
232+
import pickle
233+
pickle_out = open("explainer_scikitlearn.pickle","wb")
224234
pickle.dump(explainer, pickle_out)
225-
pickle_out.close()
235+
pickle_out.close()
226236
```
227237

228-
Then use `modelStudio` in R:
238+
Then, use `modelStudio` in R:
229239

230240
```r
231-
# use reticulate to load the explainer from a pickle file
241+
# load the explainer from the pickle file
232242
library(reticulate)
233-
explainer <- py_load_object('explainer_titanic.pickle')
234-
235-
# make a predict_function
236-
predict_function <- function(model, data) {
237-
if ("predict_proba" %in% names(model)) {
238-
pred <- model$predict_proba(data)
239-
if (ncol(pred) == 2) {
240-
pred <- pred[,2]
241-
}
242-
} else {
243-
pred <- model$predict(data)
244-
}
245-
pred
246-
}
247-
248-
# adjust the explainer
249-
explainer$predict_function <- predict_function
250-
explainer$label <- 'scikit-learn'
251-
class(explainer) <- c(class(explainer), 'explainer')
252-
253-
# make a modelStudio
243+
explainer <- py_load_object('explainer_scikitlearn.pickle', pickle = "pickle")
244+
245+
# make a studio for the model
254246
library(modelStudio)
255247
modelStudio(explainer)
256248
```

inst/WORDLIST

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ tensorflow
2828
Shapley
2929
cran
3030
CRAN
31-
MLPCLassifier
31+
MLPClassifier
3232
keras
3333
lightGBM
34+
customizable

pkgdown/_pkgdown.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ template:
33
default_assets: false
44
params:
55
ganalytics: UA-5650686-14
6-
noindex: true
6+
noindex: true
7+

0 commit comments

Comments
 (0)