Skip to content

Compat refresh for modern PyTorch/Python + multi-task batch_size=1 fixes#310

Merged
shenweichen merged 15 commits intomasterfrom
torch2.x
Apr 18, 2026
Merged

Compat refresh for modern PyTorch/Python + multi-task batch_size=1 fixes#310
shenweichen merged 15 commits intomasterfrom
torch2.x

Conversation

@shenweichen
Copy link
Copy Markdown
Owner

Summary

This PR refreshes DeepCTR-Torch compatibility for modern environments and fixes multi-task edge cases observed on newer PyTorch/scikit-learn stacks.

Core compatibility updates

  • remove TensorFlow hard dependency from install requirements
  • replace TensorFlow private callback imports with native deepctr_torch.callbacks implementations
  • update package constraints and metadata for modern Python/PyTorch support
  • improve docs/RTD configuration and markdown parser support

Runtime/training fixes

  • fix metric input handling for multi-task outputs in BaseModel evaluation/logging path
  • fix batch_size=1 multi-task training shape issues by avoiding unsafe global squeeze() in fit path
  • keep expert-gating outputs shape-safe in MMOE/PLE (squeeze(1) instead of global squeeze)
  • initialize linear accumulator on input tensor device to avoid cross-device mismatch under multi-GPU setups

CI/tests

  • add/adjust smoke tests for examples in CI
  • add regression coverage for multi-task batch_size=1 behavior

Linked Issues

Closes #281
Closes #288
Closes #303
Closes #305
Closes #263
Closes #309
Refs #301

Validation

  • python -m pytest -q tests/models/multitask/MMOE_test.py tests/models/multitask/PLE_test.py -k batch_size_one
  • python -m pytest -q tests/models/multitask

@codecov
Copy link
Copy Markdown

codecov Bot commented Apr 18, 2026

Codecov Report

❌ Patch coverage is 98.46154% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 93.19%. Comparing base (f685425) to head (f7cc1e0).
⚠️ Report is 1 commits behind head on master.

Files with missing lines Patch % Lines
deepctr_torch/models/basemodel.py 93.54% 2 Missing ⚠️
deepctr_torch/callbacks.py 99.37% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #310      +/-   ##
==========================================
+ Coverage   92.56%   93.19%   +0.62%     
==========================================
  Files          30       35       +5     
  Lines        2274     2483     +209     
==========================================
+ Hits         2105     2314     +209     
  Misses        169      169              
Flag Coverage Δ
pytest 93.19% <98.46%> (+0.62%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

- add dedicated callbacks unit tests covering EarlyStopping/ModelCheckpoint branches

- add DeepFM column-vector target fit regression

- add callback constructor docstrings and explicit metric imports for lint
@shenweichen shenweichen marked this pull request as ready for review April 18, 2026 10:39
@shenweichen shenweichen merged commit 70f31d7 into master Apr 18, 2026
13 of 24 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment