Skip to content

Commit 7d6c5e3

Browse files
authored
[Feature] Make custom_range public in ActionDiscretizer (#3333)
1 parent 9ee9e90 commit 7d6c5e3

1 file changed

Lines changed: 44 additions & 3 deletions

File tree

torchrl/envs/transforms/transforms.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10395,6 +10395,21 @@ class ActionDiscretizer(Transform):
1039510395
>>> assert (r["action"] < base_env.action_spec.high).all()
1039610396
>>> assert (r["action"] > base_env.action_spec.low).all()
1039710397
10398+
.. note:: Custom Sampling Strategies
10399+
10400+
To implement a custom sampling strategy beyond the built-in options
10401+
(``MEDIAN``, ``LOW``, ``HIGH``, ``RANDOM``), subclass ``ActionDiscretizer``
10402+
and override the :meth:`~ActionDiscretizer.custom_arange` method. This
10403+
method computes the normalized interval positions (values in ``[0, 1)``)
10404+
that determine where each discrete action maps within the continuous
10405+
action interval.
10406+
10407+
Example:
10408+
>>> class LogSpacedActionDiscretizer(ActionDiscretizer):
10409+
... def custom_arange(self, nint, device):
10410+
... # Use logarithmic spacing instead of linear
10411+
... return torch.logspace(-2, 0, nint, device=device) - 0.01
10412+
1039810413
"""
1039910414

1040010415
class SamplingStrategy(IntEnum):
@@ -10441,7 +10456,33 @@ def _indent(s):
1044110456
f"\n{_indent(out_action_key)},\n{_indent(sampling)},\n{_indent(categorical)})"
1044210457
)
1044310458

10444-
def _custom_arange(self, nint, device):
10459+
def custom_arange(self, nint, device):
10460+
"""Compute the normalized interval positions for discretization.
10461+
10462+
This method generates values in the range [0, 1) that determine where
10463+
each discrete action maps within the continuous action interval.
10464+
10465+
Override this method in a subclass to implement custom sampling
10466+
strategies beyond the built-in ``MEDIAN``, ``LOW``, ``HIGH``, and
10467+
``RANDOM`` strategies.
10468+
10469+
Args:
10470+
nint (int): the number of intervals (discrete actions) for this
10471+
action dimension.
10472+
device (torch.device): the device on which to create the tensor.
10473+
10474+
Returns:
10475+
torch.Tensor: a 1D tensor of shape ``(nint,)`` with values in
10476+
``[0, 1)`` representing the normalized positions within each
10477+
interval.
10478+
10479+
Example:
10480+
>>> class CustomActionDiscretizer(ActionDiscretizer):
10481+
... def custom_arange(self, nint, device):
10482+
... # Custom sampling: use logarithmic spacing
10483+
... return torch.logspace(-2, 0, nint, device=device) - 0.01
10484+
10485+
"""
1044510486
result = torch.arange(
1044610487
start=0.0,
1044710488
end=1.0,
@@ -10491,7 +10532,7 @@ def transform_input_spec(self, input_spec):
1049110532

1049210533
if isinstance(num_intervals, int):
1049310534
arange = (
10494-
self._custom_arange(num_intervals, action_spec.device).expand(
10535+
self.custom_arange(num_intervals, action_spec.device).expand(
1049510536
(*n_act, num_intervals)
1049610537
)
1049710538
* interval
@@ -10502,7 +10543,7 @@ def transform_input_spec(self, input_spec):
1050210543
self.register_buffer("intervals", low + arange)
1050310544
else:
1050410545
arange = [
10505-
self._custom_arange(_num_intervals, action_spec.device) * interval
10546+
self.custom_arange(_num_intervals, action_spec.device) * interval
1050610547
for _num_intervals, interval in zip(
1050710548
num_intervals.tolist(), interval.unbind(-2)
1050810549
)

0 commit comments

Comments
 (0)