Commit 3ad4f4f
[Fix] Re-expand target_input on OOM in get_max_batch_size (NVIDIA#1374)
## Summary

- `get_max_batch_size` halved `target_data_batch` on `torch.cuda.OutOfMemoryError` but never rebuilt `target_input`, so each retry re-fed the same too-large tensor; the retry loop was effectively a no-op.
- Refactor the expand logic into an `_expand_to(batch)` helper, rebuild `target_input` after halving, and call `torch.cuda.empty_cache()` between attempts.

## Test plan

- [x] New unit test `test_get_max_batch_size_oom_retry_shrinks_input` mocks `torch.cuda.*` and asserts that the retry after halving receives the shrunken tensor (batch sizes seen: `[1, 10, 5]`, regulated result `4`).
- [x] `pytest tests/unit/torch/utils/test_dataset_utils.py`: 14/14 pass (skipping the network-only minipile test).
- [x] `pre-commit` (ruff, mypy, bandit, license headers) clean on commit.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

## Summary by CodeRabbit

- **Bug Fixes**
  - Enhanced GPU memory management during batch size detection. When out-of-memory errors occur during the initial probing phase, the system now adapts input tensors to smaller batch sizes and clears the GPU cache before retry attempts, giving more reliable recovery and stable batch sizing across diverse hardware.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
1 parent bb08094 commit 3ad4f4f

2 files changed

Lines changed: 62 additions & 7 deletions
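For context, the essence of the fix as a minimal standalone sketch (hypothetical helper names, not the actual `modelopt` code): the input tensor must be rebuilt at the new batch size on every retry, otherwise halving the counter changes nothing about what the model sees.

```python
import torch


def _shrink_until_fits(expand_to, forward, target_batch: int) -> int:
    """Illustrative retry loop: rebuild the input at each smaller batch size."""
    while target_batch > 1:
        try:
            forward(expand_to(target_batch))  # input now matches the current target
            break  # this batch size fits in memory
        except torch.cuda.OutOfMemoryError:
            target_batch //= 2         # shrink the target batch
            torch.cuda.empty_cache()   # release cached blocks before the next attempt
    return target_batch
```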


modelopt/torch/utils/dataset_utils.py

Lines changed: 12 additions & 6 deletions
```diff
@@ -519,12 +519,16 @@ def _get_free_gpu_mem():
             target_data_batch = 1  # pragma: no cover
         else:
             target_data_batch = max(int(free_mem_before / mem_diff_per_data_batch), 1)
-        target_input = sample_input_single_batch.expand(
-            [
-                target_data_batch if index == 0 else dim
-                for index, dim in enumerate(sample_input_single_batch.shape)
-            ]
-        )
+
+        def _expand_to(batch: int) -> torch.Tensor:
+            return sample_input_single_batch.expand(
+                [
+                    batch if index == 0 else dim
+                    for index, dim in enumerate(sample_input_single_batch.shape)
+                ]
+            )
+
+        target_input = _expand_to(target_data_batch)
 
         # For some models on multi GPU, we observe the memory per batch is not a constant.
         # So we just test the target batch size and make sure we do not go OOM.
@@ -535,6 +539,8 @@ def _get_free_gpu_mem():
                 break
             except torch.cuda.OutOfMemoryError:  # pragma: no cover - GPU OOM retry path
                 target_data_batch = target_data_batch // 2  # pragma: no cover
+                target_input = _expand_to(target_data_batch)  # pragma: no cover
+                torch.cuda.empty_cache()  # pragma: no cover
 
         # Regulate the data batch target to be 1, 2, 4, 8, 12, ..., capped at 64
         if target_data_batch < 2:
```
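The helper leans on `torch.Tensor.expand`, which broadcasts the size-1 batch dimension to the requested size as a view, without copying data, so rebuilding `target_input` on every retry is essentially free. A standalone illustration (not part of the diff), mirroring the test's `(1, 8)` input:

```python
import torch

single = torch.ones((1, 8), dtype=torch.int32)  # batch dim is 1
bigger = single.expand([4, 8])                  # batch-4 view; no data is copied
assert bigger.shape == (4, 8)
assert bigger.data_ptr() == single.data_ptr()   # same underlying storage
```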

tests/unit/torch/utils/test_dataset_utils.py

Lines changed: 50 additions & 1 deletion
```diff
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from unittest.mock import Mock
+from unittest.mock import Mock, patch
 
 import pytest
 import torch
@@ -24,6 +24,7 @@
     _forward_loop,
     _process_batch,
     get_dataset_samples,
+    get_max_batch_size,
 )
 
 
@@ -231,6 +232,54 @@ def test_disable_use_cache_restores_on_exception():
     assert model.config.use_cache is True
 
 
+def test_get_max_batch_size_oom_retry_shrinks_input():
+    """On OOM, target_input must be re-expanded to the halved batch size.
+
+    Regression test: previously target_input was built once and never shrunk,
+    so OOM retries kept feeding the same too-large tensor to infer_method.
+    """
+    seq_len = 8
+    sample_input = torch.ones((1, seq_len), dtype=torch.int32)
+
+    seen_batch_sizes: list[int] = []
+
+    def fake_forward(x):
+        seen_batch_sizes.append(x.shape[0])
+        # First call is the single-batch probe — succeeds.
+        # Second call is the target-batch attempt — OOMs.
+        # Third call (after halving) — succeeds.
+        if len(seen_batch_sizes) == 2:
+            raise torch.cuda.OutOfMemoryError
+
+    model = Mock(spec=torch.nn.Module)
+    model.forward = fake_forward
+    model.__class__.__name__ = "DummyModel"  # not enc/dec
+
+    free_before = 1000
+    free_after = 900  # 100 bytes per batch -> target = 1000/100 = 10
+
+    device_props = Mock()
+    device_props.total_memory = 10**12
+
+    with (
+        patch("torch.cuda.empty_cache"),
+        patch("torch.cuda.get_device_properties", return_value=device_props),
+        patch("torch.cuda.device_count", return_value=1),
+        patch("torch.cuda.mem_get_info", side_effect=[(free_before, 0), (free_after, 0)]),
+        patch("torch.cuda.max_memory_allocated", side_effect=[0, 0]),
+    ):
+        result = get_max_batch_size(
+            model,
+            max_sample_length=seq_len,
+            sample_input_single_batch=sample_input,
+        )
+
+    # Forward calls: probe(1), retry-at-target(10), retry-after-halve(5)
+    assert seen_batch_sizes == [1, 10, 5]
+    # Final batch is 5 -> regulated to 4 (5 // 4 * 4 = 4).
+    assert result == 4
+
+
 @pytest.mark.parametrize("test_local_path", [True, False])
 def test_get_dataset_samples_with_unsupported_minipile_dataset(tmp_path, test_local_path):
     pytest.importorskip("datasets")
```
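The final `result == 4` comes from the regulation step visible in the diff ("Regulate the data batch target to be 1, 2, 4, 8, 12, ..., capped at 64"). A sketch of that arithmetic, assuming the `batch // 4 * 4` rounding the test comment cites; `regulate` here is a hypothetical stand-in, not the library's code:

```python
def regulate(batch: int) -> int:
    # Assumed from the series 1, 2, 4, 8, 12, ... capped at 64 and the
    # test comment (5 // 4 * 4 == 4); not the verbatim library logic.
    if batch < 2:
        return 1
    if batch < 4:
        return 2
    return min(batch // 4 * 4, 64)


assert regulate(5) == 4    # the test's final batch of 5 regulates to 4
assert regulate(10) == 8   # the target of 10 would regulate to 8 had it not OOMed
```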
