Skip to content

Commit 1c55f7b

Browse files
Fix variable name (#341)
* Fix variable name Signed-off-by: Daiki Katsuragawa <50144563+daikikatsuragawa@users.noreply.github.com> * Fix variable name Signed-off-by: Daiki Katsuragawa <50144563+daikikatsuragawa@users.noreply.github.com> Signed-off-by: Daiki Katsuragawa <50144563+daikikatsuragawa@users.noreply.github.com>
1 parent d176e12 commit 1c55f7b

2 files changed

Lines changed: 4 additions & 4 deletions

File tree

dice_ml/explainer_interfaces/feasible_base_vae.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,10 @@ def train(self, pre_trained=False):
138138

139139
train_dataset = torch.tensor(self.vae_train_feat).float()
140140
train_dataset = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
141-
for train_x in enumerate(train_dataset):
141+
for train in enumerate(train_dataset):
142142
self.cf_vae_optimizer.zero_grad()
143143

144-
train_x = train_x[1]
144+
train_x = train[1]
145145
train_y = 1.0-torch.argmax(self.pred_model(train_x), dim=1)
146146
train_size += train_x.shape[0]
147147

dice_ml/explainer_interfaces/feasible_model_approx.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,10 @@ def train(self, constraint_type, constraint_variables, constraint_direction, con
8383

8484
train_dataset = torch.tensor(self.vae_train_feat).float()
8585
train_dataset = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
86-
for train_x in enumerate(train_dataset):
86+
for train in enumerate(train_dataset):
8787
self.cf_vae_optimizer.zero_grad()
8888

89-
train_x = train_x[1]
89+
train_x = train[1]
9090
train_y = 1.0-torch.argmax(self.pred_model(train_x), dim=1)
9191
train_size += train_x.shape[0]
9292

0 commit comments

Comments
 (0)