Skip to content

Commit e1026c1

Browse files
committed
fix(multitask): handle batch_size=1 safely in fit/gating
- keep Linear accumulator on input device to avoid cross-device errors - avoid global squeeze in BaseModel.fit for multi-task outputs - use squeeze(1) in MMOE/PLE expert-gating outputs - add batch_size=1 regression tests for MMOE and PLE
1 parent 761a175 commit e1026c1

5 files changed

Lines changed: 42 additions & 6 deletions

File tree

deepctr_torch/models/basemodel.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ def forward(self, X, sparse_feat_refine_weight=None):
7171

7272
sparse_embedding_list += varlen_embedding_list
7373

74-
linear_logit = torch.zeros([X.shape[0], 1]).to(self.device)
74+
# Keep accumulator on the same device as current input tensor.
75+
linear_logit = X.new_zeros((X.shape[0], 1))
7576
if len(sparse_embedding_list) > 0:
7677
sparse_embedding_cat = torch.cat(sparse_embedding_list, dim=-1)
7778
if sparse_feat_refine_weight is not None:
@@ -237,7 +238,9 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoc
237238
x = x_train.to(self.device).float()
238239
y = y_train.to(self.device).float()
239240

240-
y_pred = model(x).squeeze()
241+
y_pred = model(x)
242+
if self.num_tasks == 1 and y_pred.ndim > 1 and y_pred.shape[-1] == 1:
243+
y_pred = y_pred.squeeze(-1)
241244

242245
optim.zero_grad()
243246
if isinstance(loss_func, list):
@@ -246,7 +249,10 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoc
246249
loss = sum(
247250
[loss_func[i](y_pred[:, i], y[:, i], reduction='sum') for i in range(self.num_tasks)])
248251
else:
249-
loss = loss_func(y_pred, y.squeeze(), reduction='sum')
252+
y_for_loss = y
253+
if y_for_loss.ndim > 1 and y_for_loss.shape[-1] == 1:
254+
y_for_loss = y_for_loss.squeeze(-1)
255+
loss = loss_func(y_pred, y_for_loss, reduction='sum')
250256
reg_loss = self.get_regularization_loss()
251257

252258
total_loss = loss + reg_loss + self.aux_loss

deepctr_torch/models/multitask/mmoe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def forward(self, X):
127127
else:
128128
gate_dnn_out = self.gate_dnn_final_layer[i](dnn_input)
129129
gate_mul_expert = torch.matmul(gate_dnn_out.softmax(1).unsqueeze(1), expert_outs) # (bs, 1, dim)
130-
mmoe_outs.append(gate_mul_expert.squeeze())
130+
mmoe_outs.append(gate_mul_expert.squeeze(1))
131131

132132
# tower dnn (task-specific)
133133
task_outs = []

deepctr_torch/models/multitask/ple.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def cgc_net(self, inputs, level_num):
177177
else:
178178
gate_dnn_out = self.specific_gate_dnn_final_layer[level_num][i](inputs[i])
179179
gate_mul_expert = torch.matmul(gate_dnn_out.softmax(1).unsqueeze(1), cur_experts_outputs) # (bs, 1, dim)
180-
cgc_outs.append(gate_mul_expert.squeeze())
180+
cgc_outs.append(gate_mul_expert.squeeze(1))
181181

182182
# gates for shared experts
183183
cur_experts_outputs = specific_expert_outputs + shared_expert_outputs
@@ -189,7 +189,7 @@ def cgc_net(self, inputs, level_num):
189189
else:
190190
gate_dnn_out = self.shared_gate_dnn_final_layer[level_num](inputs[-1])
191191
gate_mul_expert = torch.matmul(gate_dnn_out.softmax(1).unsqueeze(1), cur_experts_outputs) # (bs, 1, dim)
192-
cgc_outs.append(gate_mul_expert.squeeze())
192+
cgc_outs.append(gate_mul_expert.squeeze(1))
193193

194194
return cgc_outs
195195

tests/models/multitask/MMOE_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,20 @@ def test_MMOE(num_experts, expert_dnn_hidden_units, gate_dnn_hidden_units, tower
2929
check_mtl_model(model, model_name, x, y_list, task_types)
3030

3131

32+
def test_MMOE_batch_size_one_multitask_fit():
33+
sample_size = 8
34+
x, y_list, feature_columns = get_mtl_test_data(
35+
sample_size, sparse_feature_num=2, dense_feature_num=1, task_types=['binary', 'binary'])
36+
37+
model = MMOE(feature_columns, task_types=['binary', 'binary'], device=get_device(use_cuda=False))
38+
model.compile('adam', ['binary_crossentropy', 'binary_crossentropy'], metrics=['binary_crossentropy'])
39+
40+
history = model.fit(x, y_list, batch_size=1, epochs=1, verbose=0)
41+
assert "loss" in history.history
42+
43+
pred = model.predict(x, batch_size=1)
44+
assert pred.shape == (sample_size, 2)
45+
46+
3247
if __name__ == "__main__":
3348
pass

tests/models/multitask/PLE_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,20 @@ def test_PLE(shared_expert_num, specific_expert_num, num_levels, expert_dnn_hidd
3030
check_mtl_model(model, model_name, x, y_list, task_types)
3131

3232

33+
def test_PLE_batch_size_one_multitask_fit():
34+
sample_size = 8
35+
x, y_list, feature_columns = get_mtl_test_data(
36+
sample_size, sparse_feature_num=2, dense_feature_num=1, task_types=['binary', 'binary'])
37+
38+
model = PLE(feature_columns, task_types=['binary', 'binary'], device=get_device(use_cuda=False))
39+
model.compile('adam', ['binary_crossentropy', 'binary_crossentropy'], metrics=['binary_crossentropy'])
40+
41+
history = model.fit(x, y_list, batch_size=1, epochs=1, verbose=0)
42+
assert "loss" in history.history
43+
44+
pred = model.predict(x, batch_size=1)
45+
assert pred.shape == (sample_size, 2)
46+
47+
3348
if __name__ == "__main__":
3449
pass

0 commit comments

Comments
 (0)