Commit 3ad4f4f
[Fix] Re-expand target_input on OOM in get_max_batch_size (NVIDIA#1374)
## Summary
- `get_max_batch_size` halved `target_data_batch` on
`torch.cuda.OutOfMemoryError` but never rebuilt `target_input`, so each
retry re-fed the same too-large tensor — the retry loop was effectively
a no-op.
- Refactor the expand logic into an `_expand_to(batch)` helper, rebuild
`target_input` after halving, and call `torch.cuda.empty_cache()`
between attempts.
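
The fix described above can be sketched without a GPU. This is a minimal illustration, not the code from the commit: `SimulatedOOM` stands in for `torch.cuda.OutOfMemoryError`, a plain list stands in for the tensor, and `probe_batch_size`, `make_forward`, and the `empty_cache` parameter are hypothetical names. Only `_expand_to` is named in the commit message, and its body here is assumed.

```python
from typing import Callable, List, Sequence


class SimulatedOOM(RuntimeError):
    """Stand-in for torch.cuda.OutOfMemoryError so the sketch runs without a GPU."""


def _expand_to(sample: Sequence[int], batch: int) -> List[Sequence[int]]:
    """Build an input holding `batch` copies of one sample (assumed behavior)."""
    return [sample] * batch


def probe_batch_size(
    sample: Sequence[int],
    target_batch: int,
    forward: Callable[[List[Sequence[int]]], None],
    empty_cache: Callable[[], None] = lambda: None,
) -> int:
    """Halve the batch on OOM and rebuild the input before each retry."""
    batch = target_batch
    target_input = _expand_to(sample, batch)
    while batch > 1:
        try:
            forward(target_input)
            break
        except SimulatedOOM:
            batch = max(batch // 2, 1)
            # The fix: without rebuilding, the next attempt would reuse the
            # old, too-large input and the halving would have no effect.
            target_input = _expand_to(sample, batch)
            empty_cache()  # plays the role of torch.cuda.empty_cache()
    return batch


def make_forward(limit: int, seen: List[int]) -> Callable[[List[Sequence[int]]], None]:
    """Return a fake forward pass that fails above `limit` rows and logs sizes."""

    def forward(batch_input: List[Sequence[int]]) -> None:
        seen.append(len(batch_input))
        if len(batch_input) > limit:
            raise SimulatedOOM(f"batch of {len(batch_input)} does not fit")

    return forward


seen_sizes: List[int] = []
result = probe_batch_size([0, 1, 2], target_batch=10, forward=make_forward(5, seen_sizes))
print(result, seen_sizes)  # 5 [10, 5]
```

If the `target_input = _expand_to(sample, batch)` line inside the `except` branch is removed, the fake forward pass keeps receiving 10 rows until `batch` reaches 1, which is the no-op retry behavior the commit fixes.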
## Test plan
- [x] New unit test `test_get_max_batch_size_oom_retry_shrinks_input`
mocks `torch.cuda.*` and asserts the second retry receives the halved
tensor (shapes seen: `[1, 10, 5]`, regulated result `4`).
- [x] `pytest tests/unit/torch/utils/test_dataset_utils.py` — 14/14 pass
(skipping the network-only minipile test).
- [x] `pre-commit` (ruff, mypy, bandit, license headers) clean on
commit.
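
A test in this style might look like the sketch below. It is not the test added by the commit: `retry_with_halving` is a hypothetical stand-in for the loop under test, `SimulatedOOM` replaces `torch.cuda.OutOfMemoryError`, and the batch sizes are illustrative. It shows only the mocking pattern, with `unittest.mock.Mock` and `side_effect` recording what each attempt receives.

```python
from unittest import mock


class SimulatedOOM(RuntimeError):
    """Stand-in for torch.cuda.OutOfMemoryError (assumption for this sketch)."""


def retry_with_halving(forward, build_input, batch, empty_cache):
    """Tiny hypothetical stand-in for the loop under test."""
    target_input = build_input(batch)
    while batch > 1:
        try:
            forward(target_input)
            break
        except SimulatedOOM:
            batch //= 2
            target_input = build_input(batch)  # rebuild at the smaller size
            empty_cache()
    return batch


def test_oom_retry_shrinks_input():
    seen = []

    def fake_forward(batch_input):
        # Record the size of every input, then fail only on the first call.
        seen.append(len(batch_input))
        if len(seen) == 1:
            raise SimulatedOOM("simulated OOM")

    forward = mock.Mock(side_effect=fake_forward)
    empty_cache = mock.Mock()

    result = retry_with_halving(forward, lambda n: [0] * n, 10, empty_cache)

    assert seen == [10, 5]           # the retry received the halved input
    assert result == 5
    assert forward.call_count == 2
    empty_cache.assert_called_once()  # cache cleared once, between attempts


test_oom_retry_shrinks_input()
```

Asserting on the recorded sizes rather than only on the return value is what catches the original bug: a loop that halves its counter but reuses the old input would still return 5 while `seen` would read `[10, 10]`.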
🤖 Generated with [Claude Code](https://claude.com/claude-code)
## Summary by CodeRabbit
* **Bug Fixes**
  * Enhanced GPU memory management during batch size detection. When out-of-memory errors occur during the initial probing phase, the system now properly adapts input tensors to smaller batch sizes and clears GPU cache before retry attempts, resulting in more reliable recovery and stable batch sizing across diverse hardware environments.
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>

1 parent bb08094 commit 3ad4f4f
2 files changed
Lines changed: 62 additions & 7 deletions
Diff contents were not captured; only the changed line ranges survive.

| File | Original lines | New lines | Change |
|---|---|---|---|
| 1 (name not captured) | 522-527 | 522-531 | 6 lines removed, 10 added |
| 1 (name not captured) | | 542-543 | 2 lines added |
| 2 (name not captured) | 16 | 16 | 1 line removed, 1 added |
| 2 (name not captured) | | 27 | 1 line added |
| 2 (name not captured) | | 235-282 | 48 lines added |