|
| 1 | +# Copyright (c) OpenMMLab. All rights reserved. |
| 2 | +import pytest |
| 3 | + |
| 4 | +from lmdeploy.messages import PytorchEngineConfig |
| 5 | +from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner |
| 6 | +from lmdeploy.pytorch.config import CacheConfig |
| 7 | +from lmdeploy.pytorch.engine.config_builder import ConfigBuilder |
| 8 | + |
| 9 | + |
| 10 | +def _cache_config(max_batches=8, cudagraph_capture_batch_sizes=None): |
| 11 | + return CacheConfig(max_batches=max_batches, |
| 12 | + block_size=64, |
| 13 | + num_cpu_blocks=0, |
| 14 | + num_gpu_blocks=1, |
| 15 | + cudagraph_capture_batch_sizes=cudagraph_capture_batch_sizes) |
| 16 | + |
| 17 | + |
| 18 | +def test_custom_capture_batch_sizes_include_max_batch_size(): |
| 19 | + engine_config = PytorchEngineConfig(max_batch_size=8, cudagraph_capture_batch_sizes=[4, 1, 4, 16]) |
| 20 | + |
| 21 | + engine_config = ConfigBuilder.update_engine_config(engine_config) |
| 22 | + |
| 23 | + assert engine_config.cudagraph_capture_batch_sizes == [1, 4, 8] |
| 24 | + |
| 25 | + |
| 26 | +def test_cache_config_normalizes_capture_batch_sizes(): |
| 27 | + cache_config = _cache_config(max_batches=8, cudagraph_capture_batch_sizes=[4, 1, 4, 16]) |
| 28 | + |
| 29 | + assert cache_config.cudagraph_capture_batch_sizes == [1, 4, 8] |
| 30 | + |
| 31 | + |
| 32 | +@pytest.mark.parametrize('sizes', [[], [0], [-1], [1.5], ['1'], [16]]) |
| 33 | +def test_invalid_capture_batch_sizes_raise(sizes): |
| 34 | + with pytest.raises(AssertionError): |
| 35 | + _cache_config(max_batches=8, cudagraph_capture_batch_sizes=sizes) |
| 36 | + |
| 37 | + |
| 38 | +def test_capture_batch_size_miss_raises(): |
| 39 | + engine_config = PytorchEngineConfig(max_batch_size=8, cudagraph_capture_batch_sizes=[1, 4]) |
| 40 | + engine_config = ConfigBuilder.update_engine_config(engine_config) |
| 41 | + runner = object.__new__(CUDAGraphRunner) |
| 42 | + runner.cache_config = ConfigBuilder.build_cache_config(engine_config) |
| 43 | + |
| 44 | + assert runner._get_capture_tokens(5) == 8 |
| 45 | + with pytest.raises(AssertionError): |
| 46 | + runner._get_capture_tokens(9) |
| 47 | + |
| 48 | + |
| 49 | +def test_graph_runner_defensively_normalizes_capture_batch_sizes(): |
| 50 | + cache_config = _cache_config(max_batches=8, cudagraph_capture_batch_sizes=[1, 8]) |
| 51 | + cache_config.cudagraph_capture_batch_sizes = [4, 1, 4, 16] |
| 52 | + runner = object.__new__(CUDAGraphRunner) |
| 53 | + runner.cache_config = cache_config |
| 54 | + |
| 55 | + assert runner.get_capture_batch_sizes() == [1, 4, 8] |
0 commit comments