Skip to content

Commit 6cf0c06

Browse files
committed
chore: relocate torch_multi_arange
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
1 parent 2708009 commit 6cf0c06

11 files changed

Lines changed: 103 additions & 101 deletions

File tree

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,7 +1031,6 @@ common-files: &common_files |
10311031
tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py |
10321032
tests/unittest/_torch/sampler/test_beam_search.py |
10331033
tests/unittest/_torch/sampler/test_best_of_n.py |
1034-
tests/unittest/_torch/sampler/test_torch_multi_arange.py |
10351034
tests/unittest/_torch/sampler/test_trtllm_sampler.py |
10361035
tests/unittest/_torch/speculative/test_draft_target.py |
10371036
tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py |
@@ -1046,6 +1045,7 @@ common-files: &common_files |
10461045
tests/unittest/_torch/speculative/test_torch_rejection_sampling.py |
10471046
tests/unittest/_torch/speculative/test_user_provided.py |
10481047
tests/unittest/_torch/test_connector.py |
1048+
tests/unittest/_torch/test_torch_multi_arange.py |
10491049
tests/unittest/_torch/thop/parallel/deep_gemm_tests.py |
10501050
tests/unittest/_torch/thop/parallel/test_causal_conv1d_op.py |
10511051
tests/unittest/_torch/thop/parallel/test_cublas_mm.py |
@@ -2406,7 +2406,6 @@ legacy-files: &legacy_files |
24062406
tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py |
24072407
tests/unittest/_torch/sampler/test_beam_search.py |
24082408
tests/unittest/_torch/sampler/test_best_of_n.py |
2409-
tests/unittest/_torch/sampler/test_torch_multi_arange.py |
24102409
tests/unittest/_torch/sampler/test_trtllm_sampler.py |
24112410
tests/unittest/_torch/speculative/test_draft_target.py |
24122411
tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py |
@@ -2421,6 +2420,7 @@ legacy-files: &legacy_files |
24212420
tests/unittest/_torch/speculative/test_torch_rejection_sampling.py |
24222421
tests/unittest/_torch/speculative/test_user_provided.py |
24232422
tests/unittest/_torch/test_connector.py |
2423+
tests/unittest/_torch/test_torch_multi_arange.py |
24242424
tests/unittest/_torch/thop/parallel/deep_gemm_tests.py |
24252425
tests/unittest/_torch/thop/parallel/test_causal_conv1d_op.py |
24262426
tests/unittest/_torch/thop/parallel/test_cublas_mm.py |

legacy-files.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1023,7 +1023,7 @@ tests/unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py
10231023
tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py
10241024
tests/unittest/_torch/sampler/test_beam_search.py
10251025
tests/unittest/_torch/sampler/test_best_of_n.py
1026-
tests/unittest/_torch/sampler/test_torch_multi_arange.py
1026+
tests/unittest/_torch/test_torch_multi_arange.py
10271027
tests/unittest/_torch/sampler/test_trtllm_sampler.py
10281028
tests/unittest/_torch/speculative/test_draft_target.py
10291029
tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1081,7 +1081,6 @@ exclude = [
10811081
"tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py",
10821082
"tests/unittest/_torch/sampler/test_beam_search.py",
10831083
"tests/unittest/_torch/sampler/test_best_of_n.py",
1084-
"tests/unittest/_torch/sampler/test_torch_multi_arange.py",
10851084
"tests/unittest/_torch/sampler/test_trtllm_sampler.py",
10861085
"tests/unittest/_torch/speculative/test_draft_target.py",
10871086
"tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py",
@@ -1096,6 +1095,7 @@ exclude = [
10961095
"tests/unittest/_torch/speculative/test_torch_rejection_sampling.py",
10971096
"tests/unittest/_torch/speculative/test_user_provided.py",
10981097
"tests/unittest/_torch/test_connector.py",
1098+
"tests/unittest/_torch/test_torch_multi_arange.py",
10991099
"tests/unittest/_torch/thop/parallel/deep_gemm_tests.py",
11001100
"tests/unittest/_torch/thop/parallel/test_causal_conv1d_op.py",
11011101
"tests/unittest/_torch/thop/parallel/test_cublas_mm.py",

ruff-legacy.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1040,7 +1040,6 @@ include = [
10401040
"tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py",
10411041
"tests/unittest/_torch/sampler/test_beam_search.py",
10421042
"tests/unittest/_torch/sampler/test_best_of_n.py",
1043-
"tests/unittest/_torch/sampler/test_torch_multi_arange.py",
10441043
"tests/unittest/_torch/sampler/test_trtllm_sampler.py",
10451044
"tests/unittest/_torch/speculative/test_draft_target.py",
10461045
"tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py",
@@ -1055,6 +1054,7 @@ include = [
10551054
"tests/unittest/_torch/speculative/test_torch_rejection_sampling.py",
10561055
"tests/unittest/_torch/speculative/test_user_provided.py",
10571056
"tests/unittest/_torch/test_connector.py",
1057+
"tests/unittest/_torch/test_torch_multi_arange.py",
10581058
"tests/unittest/_torch/thop/parallel/deep_gemm_tests.py",
10591059
"tests/unittest/_torch/thop/parallel/test_causal_conv1d_op.py",
10601060
"tests/unittest/_torch/thop/parallel/test_cublas_mm.py",

tensorrt_llm/_torch/attention_backend/flashinfer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,13 @@
1616
from flashinfer.jit.core import check_cuda_arch
1717
from typing_extensions import Self
1818

19-
from tensorrt_llm._torch.pyexecutor.sampling_utils import torch_multi_arange
2019
from tensorrt_llm._utils import nvtx_range
2120
from tensorrt_llm.functional import AttentionMaskType
2221
from tensorrt_llm.logger import logger
2322
from tensorrt_llm.models.modeling_utils import QuantConfig
2423

2524
from ..metadata import KVCacheParams
26-
from ..utils import get_global_attrs, get_model_extra_attrs
25+
from ..utils import get_global_attrs, get_model_extra_attrs, torch_multi_arange
2726
from .interface import (AttentionBackend, AttentionForwardArgs,
2827
AttentionInputType, AttentionMetadata,
2928
CustomAttentionMask, MLAParams, PredefinedAttentionMask,

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import torch._dynamo.config
1515

1616
import tensorrt_llm.bindings.internal.userbuffers as ub
17-
from tensorrt_llm._torch.pyexecutor.sampling_utils import torch_multi_arange
17+
from tensorrt_llm._torch.utils import torch_multi_arange
1818
from tensorrt_llm._utils import (is_trace_enabled, maybe_pin_memory, nvtx_range,
1919
prefer_pinned, release_gc, torch_dtype_to_str,
2020
trace_func)

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
8484
from ..speculative.interface import get_force_num_accepted_tokens
8585
from ..speculative.spec_tree_manager import SpecTreeManager
86+
from ..utils import torch_multi_arange
8687
from .finish_reason import FinishedState
8788
from .llm_request import LlmRequest, LlmRequestState, get_draft_token_length
8889
from .resource_manager import ResourceManager, ResourceManagerType
@@ -100,7 +101,6 @@
100101
resolve_sampling_strategy,
101102
sample,
102103
sample_rejected,
103-
torch_multi_arange,
104104
)
105105
from .scheduler import ScheduledRequests
106106

tensorrt_llm/_torch/pyexecutor/sampling_utils.py

Lines changed: 0 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -614,95 +614,6 @@ def sample_grouped_strategies(
614614
)
615615

616616

617-
class _AcceptSyncCompute:
618-
pass
619-
620-
621-
ACCEPT_SYNC_COMPUTE = _AcceptSyncCompute()
622-
623-
624-
# Inspired by https://github.com/pytorch/pytorch/issues/80577; note also the
625-
# suggestion to consider torch.nested.
626-
def torch_multi_arange(
627-
ends: torch.Tensor,
628-
*,
629-
output_length: int | _AcceptSyncCompute,
630-
starts: Optional[torch.Tensor] = None,
631-
steps: Optional[torch.Tensor] = None,
632-
) -> torch.Tensor:
633-
"""Efficiently compute torch.cat([torch.arange(b, e, d) for b, e, d in zip(starts, ends, steps)]).
634-
635-
Starts, ends, steps need to share dtype and shape. Invalid ranges like range(1, 2, -1) are
636-
silently discarded. 'steps' defaults to 1 and 'starts' defaults to 0.
637-
638-
Provide 'output_length' to avoid synchronization when using device tensors or pass
639-
`ACCEPT_SYNC_COMPUTE` to explicitly accept the possibility of a device sync (for device tensors)
640-
or when tensors are known to reside on the host.
641-
"""
642-
if steps is not None:
643-
assert ends.dtype == steps.dtype
644-
assert ends.shape == steps.shape
645-
assert ends.device == steps.device
646-
if starts is not None:
647-
assert ends.dtype == starts.dtype
648-
assert ends.shape == starts.shape
649-
assert ends.device == starts.device
650-
output_length_arg = None if isinstance(output_length, _AcceptSyncCompute) else output_length
651-
652-
if ends.numel() == 0:
653-
return ends.clone()
654-
655-
# This algorithm combines torch.repeat_interleaved() and torch.cumsum() to
656-
# construct the result.
657-
#
658-
# 1. Given N ranges (characterized by starts, ends, steps), construct a sequence
659-
# of 2N numbers, in which the non-overlapping pairs of consecutive numbers
660-
# correspond to the ranges. For a given range, the pair (a, b) is chosen such
661-
# that upon torch.cumsum() application 'a' turns the last element of the
662-
# preceding range into the start element for the current range and 'b' is
663-
# simply the step size for the current range.
664-
#
665-
repeats = ends # number of elements in each range
666-
if starts is not None:
667-
repeats = repeats.clone()
668-
repeats -= starts
669-
if steps is not None:
670-
repeats *= steps.sign()
671-
steps_abs = steps.abs()
672-
repeats = (repeats + steps_abs - 1).div(steps_abs, rounding_mode="floor")
673-
repeats = repeats.clip(min=0) # ignore invalid ranges
674-
range_ends = repeats - 1 # last element in each range
675-
if steps is not None:
676-
range_ends *= steps
677-
if starts is not None:
678-
range_ends += starts
679-
prev_range_ends = range_ends.roll(1) # last element in preceding range (or 0)
680-
prev_range_ends[0].fill_(0)
681-
ones = torch.ones((), dtype=ends.dtype, device=ends.device)
682-
zeros = torch.zeros((), dtype=ends.dtype, device=ends.device)
683-
if steps is None:
684-
steps = ones.broadcast_to(ends.shape)
685-
jumps = -prev_range_ends # delta from one range to the next
686-
if starts is not None:
687-
jumps += starts
688-
# NB: Apply correction for empty ranges
689-
jumps_corrections = torch.where(repeats == 0, jumps, zeros).cumsum(0, dtype=ends.dtype)
690-
jumps += jumps_corrections
691-
seq = torch.cat((jumps.unsqueeze(-1), steps.unsqueeze(-1)), dim=1).view(-1)
692-
#
693-
# 2. Construct output via torch.repeat_interleave() and torch.cumsum()
694-
# NB: For a resulting empty range, repeats - 1 == -1. In this case, we
695-
# should set repeats for delta and increment both to 0 instead.
696-
jump_repeats = torch.where(repeats == 0, zeros, ones)
697-
step_repeats = torch.where(repeats == 0, zeros, repeats - 1)
698-
seq_repeats = torch.cat((jump_repeats.unsqueeze(-1), step_repeats.unsqueeze(-1)), dim=1).view(
699-
-1
700-
)
701-
seq = seq.repeat_interleave(seq_repeats, output_size=output_length_arg)
702-
seq = seq.cumsum(0, dtype=ends.dtype)
703-
return seq
704-
705-
706617
class _Fusions:
707618
@staticmethod
708619
@torch.compile(dynamic=None, fullgraph=True)

tensorrt_llm/_torch/utils.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,3 +520,96 @@ def replace_parameter_and_save_metadata(
520520
raise ValueError(f"Invalid type {type(new_param)} for new_param")
521521

522522
module.register_parameter(param_name, saved_param)
523+
524+
525+
class _AcceptSyncCompute:
526+
pass
527+
528+
529+
ACCEPT_SYNC_COMPUTE = _AcceptSyncCompute()
530+
531+
532+
# Inspired by https://github.com/pytorch/pytorch/issues/80577; note also the
533+
# suggestion to consider torch.nested.
534+
def torch_multi_arange(
535+
ends: torch.Tensor,
536+
*,
537+
output_length: int | _AcceptSyncCompute,
538+
starts: torch.Tensor | None = None,
539+
steps: torch.Tensor | None = None,
540+
) -> torch.Tensor:
541+
"""Efficiently compute torch.cat([torch.arange(b, e, d) for b, e, d in zip(starts, ends, steps)]).
542+
543+
Starts, ends, steps need to share dtype and shape. Invalid ranges like range(1, 2, -1) are
544+
silently discarded. 'steps' defaults to 1 and 'starts' defaults to 0.
545+
546+
Provide 'output_length' to avoid synchronization when using device tensors or pass
547+
`ACCEPT_SYNC_COMPUTE` to explicitly accept the possibility of a device sync (for device tensors)
548+
or when tensors are known to reside on the host.
549+
"""
550+
if steps is not None:
551+
assert ends.dtype == steps.dtype
552+
assert ends.shape == steps.shape
553+
assert ends.device == steps.device
554+
if starts is not None:
555+
assert ends.dtype == starts.dtype
556+
assert ends.shape == starts.shape
557+
assert ends.device == starts.device
558+
output_length_arg = None if isinstance(
559+
output_length, _AcceptSyncCompute) else output_length
560+
561+
if ends.numel() == 0:
562+
return ends.clone()
563+
564+
# This algorithm combines torch.repeat_interleaved() and torch.cumsum() to
565+
# construct the result.
566+
#
567+
# 1. Given N ranges (characterized by starts, ends, steps), construct a sequence
568+
# of 2N numbers, in which the non-overlapping pairs of consecutive numbers
569+
# correspond to the ranges. For a given range, the pair (a, b) is chosen such
570+
# that upon torch.cumsum() application 'a' turns the last element of the
571+
# preceding range into the start element for the current range and 'b' is
572+
# simply the step size for the current range.
573+
#
574+
repeats = ends # number of elements in each range
575+
if starts is not None:
576+
repeats = repeats.clone()
577+
repeats -= starts
578+
if steps is not None:
579+
repeats *= steps.sign()
580+
steps_abs = steps.abs()
581+
repeats = (repeats + steps_abs - 1).div(steps_abs,
582+
rounding_mode="floor")
583+
repeats = repeats.clip(min=0) # ignore invalid ranges
584+
range_ends = repeats - 1 # last element in each range
585+
if steps is not None:
586+
range_ends *= steps
587+
if starts is not None:
588+
range_ends += starts
589+
prev_range_ends = range_ends.roll(
590+
1) # last element in preceding range (or 0)
591+
prev_range_ends[0].fill_(0)
592+
ones = torch.ones((), dtype=ends.dtype, device=ends.device)
593+
zeros = torch.zeros((), dtype=ends.dtype, device=ends.device)
594+
if steps is None:
595+
steps = ones.broadcast_to(ends.shape)
596+
jumps = -prev_range_ends # delta from one range to the next
597+
if starts is not None:
598+
jumps += starts
599+
# NB: Apply correction for empty ranges
600+
jumps_corrections = torch.where(repeats == 0, jumps,
601+
zeros).cumsum(0, dtype=ends.dtype)
602+
jumps += jumps_corrections
603+
seq = torch.cat((jumps.unsqueeze(-1), steps.unsqueeze(-1)), dim=1).view(-1)
604+
#
605+
# 2. Construct output via torch.repeat_interleave() and torch.cumsum()
606+
# NB: For a resulting empty range, repeats - 1 == -1. In this case, we
607+
# should set repeats for delta and increment both to 0 instead.
608+
jump_repeats = torch.where(repeats == 0, zeros, ones)
609+
step_repeats = torch.where(repeats == 0, zeros, repeats - 1)
610+
seq_repeats = torch.cat(
611+
(jump_repeats.unsqueeze(-1), step_repeats.unsqueeze(-1)),
612+
dim=1).view(-1)
613+
seq = seq.repeat_interleave(seq_repeats, output_size=output_length_arg)
614+
seq = seq.cumsum(0, dtype=ends.dtype)
615+
return seq

tests/integration/test_lists/test-db/l0_a10.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ l0_a10:
1515
tests:
1616
# ------------- PyTorch tests ---------------
1717
- unittest/_torch/sampler/test_torch_sampler.py
18-
- unittest/_torch/sampler/test_torch_multi_arange.py
18+
- unittest/_torch/test_torch_multi_arange.py
1919
- unittest/utils/test_util.py
2020
- unittest/utils/test_logger.py
2121
- unittest/_torch/test_model_config.py

0 commit comments

Comments
 (0)