@@ -520,3 +520,96 @@ def replace_parameter_and_save_metadata(
520520 raise ValueError (f"Invalid type { type (new_param )} for new_param" )
521521
522522 module .register_parameter (param_name , saved_param )
523+
524+
525+ class _AcceptSyncCompute :
526+ pass
527+
528+
529+ ACCEPT_SYNC_COMPUTE = _AcceptSyncCompute ()
530+
531+
532+ # Inspired by https://github.com/pytorch/pytorch/issues/80577; note also the
533+ # suggestion to consider torch.nested.
534+ def torch_multi_arange (
535+ ends : torch .Tensor ,
536+ * ,
537+ output_length : int | _AcceptSyncCompute ,
538+ starts : torch .Tensor | None = None ,
539+ steps : torch .Tensor | None = None ,
540+ ) -> torch .Tensor :
541+ """Efficiently compute torch.cat([torch.arange(b, e, d) for b, e, d in zip(starts, ends, steps)]).
542+
543+ Starts, ends, steps need to share dtype and shape. Invalid ranges like range(1, 2, -1) are
544+ silently discarded. 'steps' defaults to 1 and 'starts' defaults to 0.
545+
546+ Provide 'output_length' to avoid synchronization when using device tensors or pass
547+ `ACCEPT_SYNC_COMPUTE` to explicitly accept the possibility of a device sync (for device tensors)
548+ or when tensors are known to reside on the host.
549+ """
550+ if steps is not None :
551+ assert ends .dtype == steps .dtype
552+ assert ends .shape == steps .shape
553+ assert ends .device == steps .device
554+ if starts is not None :
555+ assert ends .dtype == starts .dtype
556+ assert ends .shape == starts .shape
557+ assert ends .device == starts .device
558+ output_length_arg = None if isinstance (
559+ output_length , _AcceptSyncCompute ) else output_length
560+
561+ if ends .numel () == 0 :
562+ return ends .clone ()
563+
564+ # This algorithm combines torch.repeat_interleaved() and torch.cumsum() to
565+ # construct the result.
566+ #
567+ # 1. Given N ranges (characterized by starts, ends, steps), construct a sequence
568+ # of 2N numbers, in which the non-overlapping pairs of consecutive numbers
569+ # correspond to the ranges. For a given range, the pair (a, b) is chosen such
570+ # that upon torch.cumsum() application 'a' turns the last element of the
571+ # preceding range into the start element for the current range and 'b' is
572+ # simply the step size for the current range.
573+ #
574+ repeats = ends # number of elements in each range
575+ if starts is not None :
576+ repeats = repeats .clone ()
577+ repeats -= starts
578+ if steps is not None :
579+ repeats *= steps .sign ()
580+ steps_abs = steps .abs ()
581+ repeats = (repeats + steps_abs - 1 ).div (steps_abs ,
582+ rounding_mode = "floor" )
583+ repeats = repeats .clip (min = 0 ) # ignore invalid ranges
584+ range_ends = repeats - 1 # last element in each range
585+ if steps is not None :
586+ range_ends *= steps
587+ if starts is not None :
588+ range_ends += starts
589+ prev_range_ends = range_ends .roll (
590+ 1 ) # last element in preceding range (or 0)
591+ prev_range_ends [0 ].fill_ (0 )
592+ ones = torch .ones ((), dtype = ends .dtype , device = ends .device )
593+ zeros = torch .zeros ((), dtype = ends .dtype , device = ends .device )
594+ if steps is None :
595+ steps = ones .broadcast_to (ends .shape )
596+ jumps = - prev_range_ends # delta from one range to the next
597+ if starts is not None :
598+ jumps += starts
599+ # NB: Apply correction for empty ranges
600+ jumps_corrections = torch .where (repeats == 0 , jumps ,
601+ zeros ).cumsum (0 , dtype = ends .dtype )
602+ jumps += jumps_corrections
603+ seq = torch .cat ((jumps .unsqueeze (- 1 ), steps .unsqueeze (- 1 )), dim = 1 ).view (- 1 )
604+ #
605+ # 2. Construct output via torch.repeat_interleave() and torch.cumsum()
606+ # NB: For a resulting empty range, repeats - 1 == -1. In this case, we
607+ # should set repeats for delta and increment both to 0 instead.
608+ jump_repeats = torch .where (repeats == 0 , zeros , ones )
609+ step_repeats = torch .where (repeats == 0 , zeros , repeats - 1 )
610+ seq_repeats = torch .cat (
611+ (jump_repeats .unsqueeze (- 1 ), step_repeats .unsqueeze (- 1 )),
612+ dim = 1 ).view (- 1 )
613+ seq = seq .repeat_interleave (seq_repeats , output_size = output_length_arg )
614+ seq = seq .cumsum (0 , dtype = ends .dtype )
615+ return seq
0 commit comments