@@ -10395,6 +10395,21 @@ class ActionDiscretizer(Transform):
1039510395 >>> assert (r["action"] < base_env.action_spec.high).all()
1039610396 >>> assert (r["action"] > base_env.action_spec.low).all()
1039710397
10398+ .. note:: Custom Sampling Strategies
10399+
10400+ To implement a custom sampling strategy beyond the built-in options
10401+ (``MEDIAN``, ``LOW``, ``HIGH``, ``RANDOM``), subclass ``ActionDiscretizer``
10402+ and override the :meth:`~ActionDiscretizer.custom_arange` method. This
10403+ method computes the normalized interval positions (values in ``[0, 1)``)
10404+ that determine where each discrete action maps within the continuous
10405+ action interval.
10406+
10407+ Example:
10408+ >>> class LogSpacedActionDiscretizer(ActionDiscretizer):
10409+ ... def custom_arange(self, nint, device):
10410+ ... # Log spacing; the 0.01 shift maps the first value (10**-2) to 0, keeping values in [0, 1)
10411+ ... return torch.logspace(-2, 0, nint, device=device) - 0.01
10412+
1039810413 """
1039910414
1040010415 class SamplingStrategy (IntEnum ):
@@ -10441,7 +10456,33 @@ def _indent(s):
1044110456 f"\n { _indent (out_action_key )} ,\n { _indent (sampling )} ,\n { _indent (categorical )} )"
1044210457 )
1044310458
10444- def _custom_arange (self , nint , device ):
10459+ def custom_arange (self , nint , device ):
10460+ """Compute the normalized interval positions for discretization.
10461+
10462+ This method generates values in the range [0, 1) that determine where
10463+ each discrete action maps within the continuous action interval.
10464+
10465+ Override this method in a subclass to implement custom sampling
10466+ strategies beyond the built-in ``MEDIAN``, ``LOW``, ``HIGH``, and
10467+ ``RANDOM`` strategies.
10468+
10469+ Args:
10470+ nint (int): the number of intervals (discrete actions) for this
10471+ action dimension.
10472+ device (torch.device): the device on which to create the tensor.
10473+
10474+ Returns:
10475+ torch.Tensor: a 1D tensor of shape ``(nint,)`` with values in
10476+ ``[0, 1)``, one per discrete action, giving that action's
10477+ normalized position within the continuous action interval.
10478+
10479+ Example:
10480+ >>> class CustomActionDiscretizer(ActionDiscretizer):
10481+ ... def custom_arange(self, nint, device):
10482+ ... # Custom sampling: log spacing; the 0.01 shift maps the first value (10**-2) to 0, keeping values in [0, 1)
10483+ ... return torch.logspace(-2, 0, nint, device=device) - 0.01
10484+
10485+ """
1044510486 result = torch .arange (
1044610487 start = 0.0 ,
1044710488 end = 1.0 ,
@@ -10491,7 +10532,7 @@ def transform_input_spec(self, input_spec):
1049110532
1049210533 if isinstance (num_intervals , int ):
1049310534 arange = (
10494- self ._custom_arange (num_intervals , action_spec .device ).expand (
10535+ self .custom_arange (num_intervals , action_spec .device ).expand (
1049510536 (* n_act , num_intervals )
1049610537 )
1049710538 * interval
@@ -10502,7 +10543,7 @@ def transform_input_spec(self, input_spec):
1050210543 self .register_buffer ("intervals" , low + arange )
1050310544 else :
1050410545 arange = [
10505- self ._custom_arange (_num_intervals , action_spec .device ) * interval
10546+ self .custom_arange (_num_intervals , action_spec .device ) * interval
1050610547 for _num_intervals , interval in zip (
1050710548 num_intervals .tolist (), interval .unbind (- 2 )
1050810549 )
0 commit comments