Skip to content

Commit 26ff6fb

Browse files
committed
fix(ci): default BaseModel fit path to single-task when num_tasks missing
1 parent e1026c1 commit 26ff6fb

1 file changed

Lines changed: 5 additions & 4 deletions

File tree

deepctr_torch/models/basemodel.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoc
225225
# Train
226226
print("Train on {0} samples, validate on {1} samples, {2} steps per epoch".format(
227227
len(train_tensor_data), len(val_y), steps_per_epoch))
228+
num_tasks = getattr(self, "num_tasks", 1)
228229
for epoch in range(initial_epoch, epochs):
229230
callbacks.on_epoch_begin(epoch)
230231
epoch_logs = {}
@@ -239,15 +240,15 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoc
239240
y = y_train.to(self.device).float()
240241

241242
y_pred = model(x)
242-
if self.num_tasks == 1 and y_pred.ndim > 1 and y_pred.shape[-1] == 1:
243+
if num_tasks == 1 and y_pred.ndim > 1 and y_pred.shape[-1] == 1:
243244
y_pred = y_pred.squeeze(-1)
244245

245246
optim.zero_grad()
246247
if isinstance(loss_func, list):
247-
assert len(loss_func) == self.num_tasks,\
248-
"the length of `loss_func` should be equal with `self.num_tasks`"
248+
assert len(loss_func) == num_tasks,\
249+
"the length of `loss_func` should be equal with `num_tasks`"
249250
loss = sum(
250-
[loss_func[i](y_pred[:, i], y[:, i], reduction='sum') for i in range(self.num_tasks)])
251+
[loss_func[i](y_pred[:, i], y[:, i], reduction='sum') for i in range(num_tasks)])
251252
else:
252253
y_for_loss = y
253254
if y_for_loss.ndim > 1 and y_for_loss.shape[-1] == 1:

0 commit comments

Comments
 (0)