diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 85db633af271..350c2ef378e0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1031,7 +1031,6 @@ common-files: &common_files | tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py | tests/unittest/_torch/sampler/test_beam_search.py | tests/unittest/_torch/sampler/test_best_of_n.py | - tests/unittest/_torch/sampler/test_torch_multi_arange.py | tests/unittest/_torch/sampler/test_trtllm_sampler.py | tests/unittest/_torch/speculative/test_draft_target.py | tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py | @@ -1046,6 +1045,7 @@ common-files: &common_files | tests/unittest/_torch/speculative/test_torch_rejection_sampling.py | tests/unittest/_torch/speculative/test_user_provided.py | tests/unittest/_torch/test_connector.py | + tests/unittest/_torch/test_torch_multi_arange.py | tests/unittest/_torch/thop/parallel/deep_gemm_tests.py | tests/unittest/_torch/thop/parallel/test_causal_conv1d_op.py | tests/unittest/_torch/thop/parallel/test_cublas_mm.py | @@ -2406,7 +2406,6 @@ legacy-files: &legacy_files | tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py | tests/unittest/_torch/sampler/test_beam_search.py | tests/unittest/_torch/sampler/test_best_of_n.py | - tests/unittest/_torch/sampler/test_torch_multi_arange.py | tests/unittest/_torch/sampler/test_trtllm_sampler.py | tests/unittest/_torch/speculative/test_draft_target.py | tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py | @@ -2421,6 +2420,7 @@ legacy-files: &legacy_files | tests/unittest/_torch/speculative/test_torch_rejection_sampling.py | tests/unittest/_torch/speculative/test_user_provided.py | tests/unittest/_torch/test_connector.py | + tests/unittest/_torch/test_torch_multi_arange.py | tests/unittest/_torch/thop/parallel/deep_gemm_tests.py | tests/unittest/_torch/thop/parallel/test_causal_conv1d_op.py | tests/unittest/_torch/thop/parallel/test_cublas_mm.py | diff --git a/legacy-files.txt b/legacy-files.txt index e646b59a9b31..5075c085b207 100644 --- a/legacy-files.txt +++ b/legacy-files.txt @@ -1023,7 +1023,6 @@ tests/unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py tests/unittest/_torch/sampler/test_beam_search.py tests/unittest/_torch/sampler/test_best_of_n.py -tests/unittest/_torch/sampler/test_torch_multi_arange.py tests/unittest/_torch/sampler/test_trtllm_sampler.py tests/unittest/_torch/speculative/test_draft_target.py tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py @@ -1038,6 +1037,7 @@ tests/unittest/_torch/speculative/test_spec_gate.py tests/unittest/_torch/speculative/test_torch_rejection_sampling.py tests/unittest/_torch/speculative/test_user_provided.py tests/unittest/_torch/test_connector.py +tests/unittest/_torch/test_torch_multi_arange.py tests/unittest/_torch/thop/parallel/deep_gemm_tests.py tests/unittest/_torch/thop/parallel/test_causal_conv1d_op.py tests/unittest/_torch/thop/parallel/test_cublas_mm.py diff --git a/pyproject.toml b/pyproject.toml index 6d95d3204a9c..ffa66004f666 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1081,7 +1081,6 @@ exclude = [ "tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py", "tests/unittest/_torch/sampler/test_beam_search.py", "tests/unittest/_torch/sampler/test_best_of_n.py", - "tests/unittest/_torch/sampler/test_torch_multi_arange.py", "tests/unittest/_torch/sampler/test_trtllm_sampler.py", "tests/unittest/_torch/speculative/test_draft_target.py", "tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py", @@ -1096,6 +1095,7 @@ exclude = [ "tests/unittest/_torch/speculative/test_torch_rejection_sampling.py", "tests/unittest/_torch/speculative/test_user_provided.py", "tests/unittest/_torch/test_connector.py", + "tests/unittest/_torch/test_torch_multi_arange.py", "tests/unittest/_torch/thop/parallel/deep_gemm_tests.py", "tests/unittest/_torch/thop/parallel/test_causal_conv1d_op.py", "tests/unittest/_torch/thop/parallel/test_cublas_mm.py", diff --git a/ruff-legacy.toml b/ruff-legacy.toml index c261c908abb1..6d3c84b0bbbc 100644 --- a/ruff-legacy.toml +++ b/ruff-legacy.toml @@ -1040,7 +1040,6 @@ include = [ "tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py", "tests/unittest/_torch/sampler/test_beam_search.py", "tests/unittest/_torch/sampler/test_best_of_n.py", - "tests/unittest/_torch/sampler/test_torch_multi_arange.py", "tests/unittest/_torch/sampler/test_trtllm_sampler.py", "tests/unittest/_torch/speculative/test_draft_target.py", "tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py", @@ -1055,6 +1054,7 @@ include = [ "tests/unittest/_torch/speculative/test_torch_rejection_sampling.py", "tests/unittest/_torch/speculative/test_user_provided.py", "tests/unittest/_torch/test_connector.py", + "tests/unittest/_torch/test_torch_multi_arange.py", "tests/unittest/_torch/thop/parallel/deep_gemm_tests.py", "tests/unittest/_torch/thop/parallel/test_causal_conv1d_op.py", "tests/unittest/_torch/thop/parallel/test_cublas_mm.py", diff --git a/tensorrt_llm/_torch/attention_backend/flashinfer.py b/tensorrt_llm/_torch/attention_backend/flashinfer.py index 77861d3c786e..059aa49f14c7 100644 --- a/tensorrt_llm/_torch/attention_backend/flashinfer.py +++ b/tensorrt_llm/_torch/attention_backend/flashinfer.py @@ -16,14 +16,13 @@ from flashinfer.jit.core import check_cuda_arch from typing_extensions import Self -from tensorrt_llm._torch.pyexecutor.sampling_utils import torch_multi_arange from tensorrt_llm._utils import nvtx_range from tensorrt_llm.functional import AttentionMaskType from tensorrt_llm.logger import logger from tensorrt_llm.models.modeling_utils import QuantConfig from ..metadata import KVCacheParams -from ..utils import get_global_attrs, get_model_extra_attrs +from ..utils import get_global_attrs, get_model_extra_attrs, torch_multi_arange from .interface import (AttentionBackend, AttentionForwardArgs, AttentionInputType, AttentionMetadata, CustomAttentionMask, MLAParams, PredefinedAttentionMask, diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 3a680cbb919a..157283aa01a2 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -14,7 +14,7 @@ import torch._dynamo.config import tensorrt_llm.bindings.internal.userbuffers as ub -from tensorrt_llm._torch.pyexecutor.sampling_utils import torch_multi_arange +from tensorrt_llm._torch.utils import torch_multi_arange from tensorrt_llm._utils import (is_trace_enabled, maybe_pin_memory, nvtx_range, prefer_pinned, release_gc, torch_dtype_to_str, trace_func) diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 1dae2b5803fc..9e98123a8f38 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -83,6 +83,7 @@ from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE from ..speculative.interface import get_force_num_accepted_tokens from ..speculative.spec_tree_manager import SpecTreeManager +from ..utils import torch_multi_arange from .finish_reason import FinishedState from .llm_request import LlmRequest, LlmRequestState, get_draft_token_length from .resource_manager import ResourceManager, ResourceManagerType @@ -100,7 +101,6 @@ resolve_sampling_strategy, sample, sample_rejected, - torch_multi_arange, ) from .scheduler import ScheduledRequests diff --git a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py index 6df7048c2c68..d04afcc27419 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py +++ b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py @@ -614,95 +614,6 @@ def sample_grouped_strategies( ) -class _AcceptSyncCompute: - pass - - -ACCEPT_SYNC_COMPUTE = _AcceptSyncCompute() - - -# Inspired by https://github.com/pytorch/pytorch/issues/80577; note also the -# suggestion to consider torch.nested. -def torch_multi_arange( - ends: torch.Tensor, - *, - output_length: int | _AcceptSyncCompute, - starts: Optional[torch.Tensor] = None, - steps: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """Efficiently compute torch.cat([torch.arange(b, e, d) for b, e, d in zip(starts, ends, steps)]). - - Starts, ends, steps need to share dtype and shape. Invalid ranges like range(1, 2, -1) are - silently discarded. 'steps' defaults to 1 and 'starts' defaults to 0. - - Provide 'output_length' to avoid synchronization when using device tensors or pass - `ACCEPT_SYNC_COMPUTE` to explicitly accept the possibility of a device sync (for device tensors) - or when tensors are known to reside on the host. - """ - if steps is not None: - assert ends.dtype == steps.dtype - assert ends.shape == steps.shape - assert ends.device == steps.device - if starts is not None: - assert ends.dtype == starts.dtype - assert ends.shape == starts.shape - assert ends.device == starts.device - output_length_arg = None if isinstance(output_length, _AcceptSyncCompute) else output_length - - if ends.numel() == 0: - return ends.clone() - - # This algorithm combines torch.repeat_interleaved() and torch.cumsum() to - # construct the result. - # - # 1. Given N ranges (characterized by starts, ends, steps), construct a sequence - # of 2N numbers, in which the non-overlapping pairs of consecutive numbers - # correspond to the ranges. For a given range, the pair (a, b) is chosen such - # that upon torch.cumsum() application 'a' turns the last element of the - # preceding range into the start element for the current range and 'b' is - # simply the step size for the current range. - # - repeats = ends # number of elements in each range - if starts is not None: - repeats = repeats.clone() - repeats -= starts - if steps is not None: - repeats *= steps.sign() - steps_abs = steps.abs() - repeats = (repeats + steps_abs - 1).div(steps_abs, rounding_mode="floor") - repeats = repeats.clip(min=0) # ignore invalid ranges - range_ends = repeats - 1 # last element in each range - if steps is not None: - range_ends *= steps - if starts is not None: - range_ends += starts - prev_range_ends = range_ends.roll(1) # last element in preceding range (or 0) - prev_range_ends[0].fill_(0) - ones = torch.ones((), dtype=ends.dtype, device=ends.device) - zeros = torch.zeros((), dtype=ends.dtype, device=ends.device) - if steps is None: - steps = ones.broadcast_to(ends.shape) - jumps = -prev_range_ends # delta from one range to the next - if starts is not None: - jumps += starts - # NB: Apply correction for empty ranges - jumps_corrections = torch.where(repeats == 0, jumps, zeros).cumsum(0, dtype=ends.dtype) - jumps += jumps_corrections - seq = torch.cat((jumps.unsqueeze(-1), steps.unsqueeze(-1)), dim=1).view(-1) - # - # 2. Construct output via torch.repeat_interleave() and torch.cumsum() - # NB: For a resulting empty range, repeats - 1 == -1. In this case, we - # should set repeats for delta and increment both to 0 instead. - jump_repeats = torch.where(repeats == 0, zeros, ones) - step_repeats = torch.where(repeats == 0, zeros, repeats - 1) - seq_repeats = torch.cat((jump_repeats.unsqueeze(-1), step_repeats.unsqueeze(-1)), dim=1).view( - -1 - ) - seq = seq.repeat_interleave(seq_repeats, output_size=output_length_arg) - seq = seq.cumsum(0, dtype=ends.dtype) - return seq - - class _Fusions: @staticmethod @torch.compile(dynamic=None, fullgraph=True) diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py index 4e9c92c9ba76..0102379d920c 100644 --- a/tensorrt_llm/_torch/utils.py +++ b/tensorrt_llm/_torch/utils.py @@ -520,3 +520,96 @@ def replace_parameter_and_save_metadata( raise ValueError(f"Invalid type {type(new_param)} for new_param") module.register_parameter(param_name, saved_param) + + +class _AcceptSyncCompute: + pass + + +ACCEPT_SYNC_COMPUTE = _AcceptSyncCompute() + + +# Inspired by https://github.com/pytorch/pytorch/issues/80577; note also the +# suggestion to consider torch.nested. +def torch_multi_arange( + ends: torch.Tensor, + *, + output_length: int | _AcceptSyncCompute, + starts: torch.Tensor | None = None, + steps: torch.Tensor | None = None, +) -> torch.Tensor: + """Efficiently compute torch.cat([torch.arange(b, e, d) for b, e, d in zip(starts, ends, steps)]). + + Starts, ends, steps need to share dtype and shape. Invalid ranges like range(1, 2, -1) are + silently discarded. 'steps' defaults to 1 and 'starts' defaults to 0. + + Provide 'output_length' to avoid synchronization when using device tensors or pass + `ACCEPT_SYNC_COMPUTE` to explicitly accept the possibility of a device sync (for device tensors) + or when tensors are known to reside on the host. + """ + if steps is not None: + assert ends.dtype == steps.dtype + assert ends.shape == steps.shape + assert ends.device == steps.device + if starts is not None: + assert ends.dtype == starts.dtype + assert ends.shape == starts.shape + assert ends.device == starts.device + output_length_arg = None if isinstance( + output_length, _AcceptSyncCompute) else output_length + + if ends.numel() == 0: + return ends.clone() + + # This algorithm combines torch.repeat_interleaved() and torch.cumsum() to + # construct the result. + # + # 1. Given N ranges (characterized by starts, ends, steps), construct a sequence + # of 2N numbers, in which the non-overlapping pairs of consecutive numbers + # correspond to the ranges. For a given range, the pair (a, b) is chosen such + # that upon torch.cumsum() application 'a' turns the last element of the + # preceding range into the start element for the current range and 'b' is + # simply the step size for the current range. + # + repeats = ends # number of elements in each range + if starts is not None: + repeats = repeats.clone() + repeats -= starts + if steps is not None: + repeats *= steps.sign() + steps_abs = steps.abs() + repeats = (repeats + steps_abs - 1).div(steps_abs, + rounding_mode="floor") + repeats = repeats.clip(min=0) # ignore invalid ranges + range_ends = repeats - 1 # last element in each range + if steps is not None: + range_ends *= steps + if starts is not None: + range_ends += starts + prev_range_ends = range_ends.roll( + 1) # last element in preceding range (or 0) + prev_range_ends[0].fill_(0) + ones = torch.ones((), dtype=ends.dtype, device=ends.device) + zeros = torch.zeros((), dtype=ends.dtype, device=ends.device) + if steps is None: + steps = ones.broadcast_to(ends.shape) + jumps = -prev_range_ends # delta from one range to the next + if starts is not None: + jumps += starts + # NB: Apply correction for empty ranges + jumps_corrections = torch.where(repeats == 0, jumps, + zeros).cumsum(0, dtype=ends.dtype) + jumps += jumps_corrections + seq = torch.cat((jumps.unsqueeze(-1), steps.unsqueeze(-1)), dim=1).view(-1) + # + # 2. Construct output via torch.repeat_interleave() and torch.cumsum() + # NB: For a resulting empty range, repeats - 1 == -1. In this case, we + # should set repeats for delta and increment both to 0 instead. + jump_repeats = torch.where(repeats == 0, zeros, ones) + step_repeats = torch.where(repeats == 0, zeros, repeats - 1) + seq_repeats = torch.cat( + (jump_repeats.unsqueeze(-1), step_repeats.unsqueeze(-1)), + dim=1).view(-1) + seq = seq.repeat_interleave(seq_repeats, output_size=output_length_arg) + seq = seq.cumsum(0, dtype=ends.dtype) + return seq diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 3e61e02ee1f6..d4b8ea7c62dc 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -15,7 +15,7 @@ l0_a10: tests: # ------------- PyTorch tests --------------- - unittest/_torch/sampler/test_torch_sampler.py - - unittest/_torch/sampler/test_torch_multi_arange.py + - unittest/_torch/test_torch_multi_arange.py - unittest/utils/test_util.py - unittest/utils/test_logger.py - unittest/_torch/test_model_config.py diff --git a/tests/unittest/_torch/sampler/test_torch_multi_arange.py b/tests/unittest/_torch/test_torch_multi_arange.py similarity index 96% rename from tests/unittest/_torch/sampler/test_torch_multi_arange.py rename to tests/unittest/_torch/test_torch_multi_arange.py index a05e059b6b50..44d32e0d0e86 100644 --- a/tests/unittest/_torch/sampler/test_torch_multi_arange.py +++ b/tests/unittest/_torch/test_torch_multi_arange.py @@ -21,8 +21,7 @@ import torch from utils.util import assert_no_cuda_sync, force_ampere -from tensorrt_llm._torch.pyexecutor.sampling_utils import (ACCEPT_SYNC_COMPUTE, - torch_multi_arange) +from tensorrt_llm._torch.utils import ACCEPT_SYNC_COMPUTE, torch_multi_arange BASE_CASES = [ (None, [], None, []),