Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dice_ml/explainer_interfaces/dice_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def predict_fn(self, input_instance):

def predict_fn_for_sparsity(self, input_instance):
"""prediction function for sparsity correction"""
input_instance = self.model.transformer.transform(input_instance).to_numpy()[0]
input_instance = self.model.transformer.transform(input_instance).to_numpy(dtype=np.float64)[0]
return self.predict_fn(torch.tensor(input_instance).float())

def do_cf_initializations(self, total_CFs, algorithm, features_to_vary):
Expand Down Expand Up @@ -418,7 +418,7 @@ def find_counterfactuals(self, query_instance, desired_class, optimizer, learnin
init_near_query_instance, tie_random, stopping_threshold, posthoc_sparsity_param,
posthoc_sparsity_algorithm, limit_steps_ls):
"""Finds counterfactuals by gradient-descent."""
query_instance = self.model.transformer.transform(query_instance).to_numpy()[0]
query_instance = self.model.transformer.transform(query_instance).to_numpy(dtype=np.float64)[0]
self.x1 = torch.tensor(query_instance)

# find the predicted value of query_instance
Expand Down
1 change: 1 addition & 0 deletions dice_ml/explainer_interfaces/dice_xgboost.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase


class DiceXGBoost(ExplainerBase):
def __init__(self, data_interface, model_interface):
"""Initialize with data and model interfaces"""
Expand Down
8 changes: 5 additions & 3 deletions dice_ml/model_interfaces/xgboost_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import xgboost as xgb
from dice_ml.model_interfaces.base_model import BaseModel

from dice_ml.constants import ModelTypes
from dice_ml.model_interfaces.base_model import BaseModel


class XGBoostModel(BaseModel):

def __init__(self, model=None, model_path='', backend='', func=None, kw_args=None):
super().__init__(model=model, model_path=model_path, backend='xgboost', func=func, kw_args=kw_args)
if model is None and model_path:
Expand All @@ -27,4 +29,4 @@ def get_output(self, input_instance, model_score=True):
return self.model.predict(input_instance)

def get_gradient(self):
raise NotImplementedError("XGBoost does not support gradient calculation in this context")
raise NotImplementedError("XGBoost does not support gradient calculation in this context")
Binary file modified dice_ml/utils/sample_trained_models/adult.h5
Binary file not shown.
3 changes: 2 additions & 1 deletion tests/test_dice_interface/test_dice_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def _initiate_exp_object(self, pyt_exp_object, sample_adultincome_query):
# query_instance = self.exp.data_interface.prepare_query_instance(
# query_instance=sample_adultincome_query, encoding='one-hot')
# self.query_instance = query_instance.iloc[0].values
self.query_instance = self.exp.data_interface.get_ohe_min_max_normalized_data(sample_adultincome_query).iloc[0].values
self.query_instance = self.exp.data_interface.get_ohe_min_max_normalized_data(
sample_adultincome_query).iloc[0].to_numpy(dtype=np.float64)

self.exp.initialize_CFs(self.query_instance, init_near_query_instance=True) # initialize CFs
self.exp.target_cf_class = torch.tensor(1).float() # set desired class to 1
Expand Down
6 changes: 0 additions & 6 deletions tests/test_dice_interface/test_explainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,6 @@ def test_desired_class(
ans = exp.generate_counterfactuals(query_instances=sample_custom_query_2,
features_to_vary='all',
total_CFs=2, desired_class=desired_class,
proximity_weight=0.2, sparsity_weight=0.2,
diversity_weight=5.0,
categorical_penalty=0.1,
permitted_range=None)
if method != 'kdtree':
assert all(ans.cf_examples_list[0].final_cfs_df[exp.data_interface.outcome_name].values == [desired_class] * 2)
Expand All @@ -277,9 +274,6 @@ def test_desired_class(
ans = new_exp.generate_counterfactuals(query_instances=sample_custom_query_2,
features_to_vary='all',
total_CFs=2, desired_class=desired_class,
proximity_weight=0.2, sparsity_weight=0.2,
diversity_weight=5.0,
categorical_penalty=0.1,
permitted_range=None)
if method != 'kdtree':
assert all(ans.cf_examples_list[0].final_cfs_df[new_exp.data_interface.outcome_name].values == [desired_class] * 2)
Expand Down
Loading