Skip to content

Commit faad7dc

Browse files
committed
chore: relocate torch_multi_arange
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
1 parent 800c7ee commit faad7dc

11 files changed

Lines changed: 103 additions & 101 deletions

File tree

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,7 +1012,6 @@ common-files: &common_files |
10121012
tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py |
10131013
tests/unittest/_torch/sampler/test_beam_search.py |
10141014
tests/unittest/_torch/sampler/test_best_of_n.py |
1015-
tests/unittest/_torch/sampler/test_torch_multi_arange.py |
10161015
tests/unittest/_torch/sampler/test_trtllm_sampler.py |
10171016
tests/unittest/_torch/speculative/test_draft_target.py |
10181017
tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py |
@@ -1027,6 +1026,7 @@ common-files: &common_files |
10271026
tests/unittest/_torch/speculative/test_torch_rejection_sampling.py |
10281027
tests/unittest/_torch/speculative/test_user_provided.py |
10291028
tests/unittest/_torch/test_connector.py |
1029+
tests/unittest/_torch/test_torch_multi_arange.py |
10301030
tests/unittest/_torch/thop/parallel/deep_gemm_tests.py |
10311031
tests/unittest/_torch/thop/parallel/test_causal_conv1d_op.py |
10321032
tests/unittest/_torch/thop/parallel/test_cublas_mm.py |
@@ -2368,7 +2368,6 @@ legacy-files: &legacy_files |
23682368
tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py |
23692369
tests/unittest/_torch/sampler/test_beam_search.py |
23702370
tests/unittest/_torch/sampler/test_best_of_n.py |
2371-
tests/unittest/_torch/sampler/test_torch_multi_arange.py |
23722371
tests/unittest/_torch/sampler/test_trtllm_sampler.py |
23732372
tests/unittest/_torch/speculative/test_draft_target.py |
23742373
tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py |
@@ -2383,6 +2382,7 @@ legacy-files: &legacy_files |
23832382
tests/unittest/_torch/speculative/test_torch_rejection_sampling.py |
23842383
tests/unittest/_torch/speculative/test_user_provided.py |
23852384
tests/unittest/_torch/test_connector.py |
2385+
tests/unittest/_torch/test_torch_multi_arange.py |
23862386
tests/unittest/_torch/thop/parallel/deep_gemm_tests.py |
23872387
tests/unittest/_torch/thop/parallel/test_causal_conv1d_op.py |
23882388
tests/unittest/_torch/thop/parallel/test_cublas_mm.py |

legacy-files.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1004,7 +1004,7 @@ tests/unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py
10041004
tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py
10051005
tests/unittest/_torch/sampler/test_beam_search.py
10061006
tests/unittest/_torch/sampler/test_best_of_n.py
1007-
tests/unittest/_torch/sampler/test_torch_multi_arange.py
1007+
tests/unittest/_torch/test_torch_multi_arange.py
10081008
tests/unittest/_torch/sampler/test_trtllm_sampler.py
10091009
tests/unittest/_torch/speculative/test_draft_target.py
10101010
tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1062,7 +1062,6 @@ exclude = [
10621062
"tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py",
10631063
"tests/unittest/_torch/sampler/test_beam_search.py",
10641064
"tests/unittest/_torch/sampler/test_best_of_n.py",
1065-
"tests/unittest/_torch/sampler/test_torch_multi_arange.py",
10661065
"tests/unittest/_torch/sampler/test_trtllm_sampler.py",
10671066
"tests/unittest/_torch/speculative/test_draft_target.py",
10681067
"tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py",
@@ -1077,6 +1076,7 @@ exclude = [
10771076
"tests/unittest/_torch/speculative/test_torch_rejection_sampling.py",
10781077
"tests/unittest/_torch/speculative/test_user_provided.py",
10791078
"tests/unittest/_torch/test_connector.py",
1079+
"tests/unittest/_torch/test_torch_multi_arange.py",
10801080
"tests/unittest/_torch/thop/parallel/deep_gemm_tests.py",
10811081
"tests/unittest/_torch/thop/parallel/test_causal_conv1d_op.py",
10821082
"tests/unittest/_torch/thop/parallel/test_cublas_mm.py",

ruff-legacy.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1021,7 +1021,6 @@ include = [
10211021
"tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py",
10221022
"tests/unittest/_torch/sampler/test_beam_search.py",
10231023
"tests/unittest/_torch/sampler/test_best_of_n.py",
1024-
"tests/unittest/_torch/sampler/test_torch_multi_arange.py",
10251024
"tests/unittest/_torch/sampler/test_trtllm_sampler.py",
10261025
"tests/unittest/_torch/speculative/test_draft_target.py",
10271026
"tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py",
@@ -1036,6 +1035,7 @@ include = [
10361035
"tests/unittest/_torch/speculative/test_torch_rejection_sampling.py",
10371036
"tests/unittest/_torch/speculative/test_user_provided.py",
10381037
"tests/unittest/_torch/test_connector.py",
1038+
"tests/unittest/_torch/test_torch_multi_arange.py",
10391039
"tests/unittest/_torch/thop/parallel/deep_gemm_tests.py",
10401040
"tests/unittest/_torch/thop/parallel/test_causal_conv1d_op.py",
10411041
"tests/unittest/_torch/thop/parallel/test_cublas_mm.py",

tensorrt_llm/_torch/attention_backend/flashinfer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,13 @@
1010
from flashinfer.jit.core import check_cuda_arch
1111
from typing_extensions import Self
1212

13-
from tensorrt_llm._torch.pyexecutor.sampling_utils import torch_multi_arange
1413
from tensorrt_llm._utils import nvtx_range
1514
from tensorrt_llm.functional import AttentionMaskType
1615
from tensorrt_llm.logger import logger
1716
from tensorrt_llm.models.modeling_utils import QuantConfig
1817

1918
from ..metadata import KVCacheParams
20-
from ..utils import get_global_attrs, get_model_extra_attrs
19+
from ..utils import get_global_attrs, get_model_extra_attrs, torch_multi_arange
2120
from .interface import (AttentionBackend, AttentionForwardArgs,
2221
AttentionInputType, AttentionMetadata,
2322
CustomAttentionMask, MLAParams, PredefinedAttentionMask,

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
8484
from ..speculative.interface import get_force_num_accepted_tokens
8585
from ..speculative.spec_tree_manager import SpecTreeManager
86+
from ..utils import torch_multi_arange
8687
from .finish_reason import FinishedState
8788
from .llm_request import LlmRequest, LlmRequestState, get_draft_token_length
8889
from .resource_manager import ResourceManager, ResourceManagerType
@@ -100,7 +101,6 @@
100101
resolve_sampling_strategy,
101102
sample,
102103
sample_rejected,
103-
torch_multi_arange,
104104
)
105105
from .scheduler import ScheduledRequests
106106

tensorrt_llm/_torch/pyexecutor/sampling_utils.py

Lines changed: 0 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -614,95 +614,6 @@ def sample_grouped_strategies(
614614
)
615615

616616

617-
class _AcceptSyncCompute:
618-
pass
619-
620-
621-
ACCEPT_SYNC_COMPUTE = _AcceptSyncCompute()
622-
623-
624-
# Inspired by https://github.com/pytorch/pytorch/issues/80577; note also the
625-
# suggestion to consider torch.nested.
626-
def torch_multi_arange(
627-
ends: torch.Tensor,
628-
*,
629-
output_length: int | _AcceptSyncCompute,
630-
starts: Optional[torch.Tensor] = None,
631-
steps: Optional[torch.Tensor] = None,
632-
) -> torch.Tensor:
633-
"""Efficiently compute torch.cat([torch.arange(b, e, d) for b, e, d in zip(starts, ends, steps)]).
634-
635-
Starts, ends, steps need to share dtype and shape. Invalid ranges like range(1, 2, -1) are
636-
silently discarded. 'steps' defaults to 1 and 'starts' defaults to 0.
637-
638-
Provide 'output_length' to avoid synchronization when using device tensors or pass
639-
`ACCEPT_SYNC_COMPUTE` to explicitly accept the possibility of a device sync (for device tensors)
640-
or when tensors are known to reside on the host.
641-
"""
642-
if steps is not None:
643-
assert ends.dtype == steps.dtype
644-
assert ends.shape == steps.shape
645-
assert ends.device == steps.device
646-
if starts is not None:
647-
assert ends.dtype == starts.dtype
648-
assert ends.shape == starts.shape
649-
assert ends.device == starts.device
650-
output_length_arg = None if isinstance(output_length, _AcceptSyncCompute) else output_length
651-
652-
if ends.numel() == 0:
653-
return ends.clone()
654-
655-
# This algorithm combines torch.repeat_interleaved() and torch.cumsum() to
656-
# construct the result.
657-
#
658-
# 1. Given N ranges (characterized by starts, ends, steps), construct a sequence
659-
# of 2N numbers, in which the non-overlapping pairs of consecutive numbers
660-
# correspond to the ranges. For a given range, the pair (a, b) is chosen such
661-
# that upon torch.cumsum() application 'a' turns the last element of the
662-
# preceding range into the start element for the current range and 'b' is
663-
# simply the step size for the current range.
664-
#
665-
repeats = ends # number of elements in each range
666-
if starts is not None:
667-
repeats = repeats.clone()
668-
repeats -= starts
669-
if steps is not None:
670-
repeats *= steps.sign()
671-
steps_abs = steps.abs()
672-
repeats = (repeats + steps_abs - 1).div(steps_abs, rounding_mode="floor")
673-
repeats = repeats.clip(min=0) # ignore invalid ranges
674-
range_ends = repeats - 1 # last element in each range
675-
if steps is not None:
676-
range_ends *= steps
677-
if starts is not None:
678-
range_ends += starts
679-
prev_range_ends = range_ends.roll(1) # last element in preceding range (or 0)
680-
prev_range_ends[0].fill_(0)
681-
ones = torch.ones((), dtype=ends.dtype, device=ends.device)
682-
zeros = torch.zeros((), dtype=ends.dtype, device=ends.device)
683-
if steps is None:
684-
steps = ones.broadcast_to(ends.shape)
685-
jumps = -prev_range_ends # delta from one range to the next
686-
if starts is not None:
687-
jumps += starts
688-
# NB: Apply correction for empty ranges
689-
jumps_corrections = torch.where(repeats == 0, jumps, zeros).cumsum(0, dtype=ends.dtype)
690-
jumps += jumps_corrections
691-
seq = torch.cat((jumps.unsqueeze(-1), steps.unsqueeze(-1)), dim=1).view(-1)
692-
#
693-
# 2. Construct output via torch.repeat_interleave() and torch.cumsum()
694-
# NB: For a resulting empty range, repeats - 1 == -1. In this case, we
695-
# should set repeats for delta and increment both to 0 instead.
696-
jump_repeats = torch.where(repeats == 0, zeros, ones)
697-
step_repeats = torch.where(repeats == 0, zeros, repeats - 1)
698-
seq_repeats = torch.cat((jump_repeats.unsqueeze(-1), step_repeats.unsqueeze(-1)), dim=1).view(
699-
-1
700-
)
701-
seq = seq.repeat_interleave(seq_repeats, output_size=output_length_arg)
702-
seq = seq.cumsum(0, dtype=ends.dtype)
703-
return seq
704-
705-
706617
class _Fusions:
707618
@staticmethod
708619
@torch.compile(dynamic=None, fullgraph=True)

tensorrt_llm/_torch/utils.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,3 +520,96 @@ def replace_parameter_and_save_metadata(
520520
raise ValueError(f"Invalid type {type(new_param)} for new_param")
521521

522522
module.register_parameter(param_name, saved_param)
523+
524+
525+
class _AcceptSyncCompute:
526+
pass
527+
528+
529+
ACCEPT_SYNC_COMPUTE = _AcceptSyncCompute()
530+
531+
532+
# Inspired by https://github.com/pytorch/pytorch/issues/80577; note also the
533+
# suggestion to consider torch.nested.
534+
def torch_multi_arange(
535+
ends: torch.Tensor,
536+
*,
537+
output_length: int | _AcceptSyncCompute,
538+
starts: torch.Tensor | None = None,
539+
steps: torch.Tensor | None = None,
540+
) -> torch.Tensor:
541+
"""Efficiently compute torch.cat([torch.arange(b, e, d) for b, e, d in zip(starts, ends, steps)]).
542+
543+
Starts, ends, steps need to share dtype and shape. Invalid ranges like range(1, 2, -1) are
544+
silently discarded. 'steps' defaults to 1 and 'starts' defaults to 0.
545+
546+
Provide 'output_length' to avoid synchronization when using device tensors or pass
547+
`ACCEPT_SYNC_COMPUTE` to explicitly accept the possibility of a device sync (for device tensors)
548+
or when tensors are known to reside on the host.
549+
"""
550+
if steps is not None:
551+
assert ends.dtype == steps.dtype
552+
assert ends.shape == steps.shape
553+
assert ends.device == steps.device
554+
if starts is not None:
555+
assert ends.dtype == starts.dtype
556+
assert ends.shape == starts.shape
557+
assert ends.device == starts.device
558+
output_length_arg = None if isinstance(
559+
output_length, _AcceptSyncCompute) else output_length
560+
561+
if ends.numel() == 0:
562+
return ends.clone()
563+
564+
# This algorithm combines torch.repeat_interleaved() and torch.cumsum() to
565+
# construct the result.
566+
#
567+
# 1. Given N ranges (characterized by starts, ends, steps), construct a sequence
568+
# of 2N numbers, in which the non-overlapping pairs of consecutive numbers
569+
# correspond to the ranges. For a given range, the pair (a, b) is chosen such
570+
# that upon torch.cumsum() application 'a' turns the last element of the
571+
# preceding range into the start element for the current range and 'b' is
572+
# simply the step size for the current range.
573+
#
574+
repeats = ends # number of elements in each range
575+
if starts is not None:
576+
repeats = repeats.clone()
577+
repeats -= starts
578+
if steps is not None:
579+
repeats *= steps.sign()
580+
steps_abs = steps.abs()
581+
repeats = (repeats + steps_abs - 1).div(steps_abs,
582+
rounding_mode="floor")
583+
repeats = repeats.clip(min=0) # ignore invalid ranges
584+
range_ends = repeats - 1 # last element in each range
585+
if steps is not None:
586+
range_ends *= steps
587+
if starts is not None:
588+
range_ends += starts
589+
prev_range_ends = range_ends.roll(
590+
1) # last element in preceding range (or 0)
591+
prev_range_ends[0].fill_(0)
592+
ones = torch.ones((), dtype=ends.dtype, device=ends.device)
593+
zeros = torch.zeros((), dtype=ends.dtype, device=ends.device)
594+
if steps is None:
595+
steps = ones.broadcast_to(ends.shape)
596+
jumps = -prev_range_ends # delta from one range to the next
597+
if starts is not None:
598+
jumps += starts
599+
# NB: Apply correction for empty ranges
600+
jumps_corrections = torch.where(repeats == 0, jumps,
601+
zeros).cumsum(0, dtype=ends.dtype)
602+
jumps += jumps_corrections
603+
seq = torch.cat((jumps.unsqueeze(-1), steps.unsqueeze(-1)), dim=1).view(-1)
604+
#
605+
# 2. Construct output via torch.repeat_interleave() and torch.cumsum()
606+
# NB: For a resulting empty range, repeats - 1 == -1. In this case, we
607+
# should set repeats for delta and increment both to 0 instead.
608+
jump_repeats = torch.where(repeats == 0, zeros, ones)
609+
step_repeats = torch.where(repeats == 0, zeros, repeats - 1)
610+
seq_repeats = torch.cat(
611+
(jump_repeats.unsqueeze(-1), step_repeats.unsqueeze(-1)),
612+
dim=1).view(-1)
613+
seq = seq.repeat_interleave(seq_repeats, output_size=output_length_arg)
614+
seq = seq.cumsum(0, dtype=ends.dtype)
615+
return seq

tensorrt_llm/llmapi/llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from tqdm import tqdm
1717
from transformers import PreTrainedTokenizerBase
1818

19-
from tensorrt_llm._torch.pyexecutor.sampling_utils import torch_multi_arange
19+
from tensorrt_llm._torch.utils import torch_multi_arange
2020
from tensorrt_llm._utils import mpi_disabled
2121
from tensorrt_llm.inputs.multimodal import (DisaggPrefillMultimodalInputs,
2222
MultimodalParams)

tests/integration/test_lists/test-db/l0_a10.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ l0_a10:
1515
tests:
1616
# ------------- PyTorch tests ---------------
1717
- unittest/_torch/sampler/test_torch_sampler.py
18-
- unittest/_torch/sampler/test_torch_multi_arange.py
18+
- unittest/_torch/test_torch_multi_arange.py
1919
- unittest/utils/test_util.py
2020
- unittest/utils/test_logger.py
2121
- unittest/_torch/test_model_config.py

0 commit comments

Comments
 (0)