Skip to content

Commit 58a3d69

Browse files
committed
Update
[ghstack-poisoned]
1 parent fada3f6 commit 58a3d69

3 files changed

Lines changed: 16 additions & 11 deletions

File tree

torchrl/modules/distributions/continuous.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -370,14 +370,15 @@ def __init__(
370370
if not all(high > low):
371371
raise RuntimeError(err_msg)
372372

373+
_non_blocking = loc.device.type == "cuda"
373374
if not isinstance(high, torch.Tensor):
374375
high = torch.as_tensor(high, device=loc.device)
375376
elif high.device != loc.device:
376-
high = high.to(loc.device)
377+
high = high.to(loc.device, non_blocking=_non_blocking)
377378
if not isinstance(low, torch.Tensor):
378379
low = torch.as_tensor(low, device=loc.device)
379380
elif low.device != loc.device:
380-
low = low.to(loc.device)
381+
low = low.to(loc.device, non_blocking=_non_blocking)
381382
if not is_compiling() and not safe_is_current_stream_capturing():
382383
self.non_trivial_max = (high != 1.0).any()
383384
self.non_trivial_min = (low != -1.0).any()
@@ -391,10 +392,10 @@ def __init__(
391392
self.upscale = (
392393
upscale
393394
if not isinstance(upscale, torch.Tensor)
394-
else upscale.to(self.device)
395+
else upscale.to(self.device, non_blocking=_non_blocking)
395396
)
396397

397-
low = low.to(loc.device)
398+
low = low.to(loc.device, non_blocking=_non_blocking)
398399
self.low = low
399400
self.high = high
400401

torchrl/modules/distributions/truncated_normal.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@ class TruncatedStandardNormal(Distribution):
3535

3636
def __init__(self, a, b, validate_args=None, device=None):
3737
self.a, self.b = broadcast_all(a, b)
38-
self.a = self.a.to(device)
39-
self.b = self.b.to(device)
38+
_non_blocking = device is not None and torch.device(device).type == "cuda"
39+
self.a = self.a.to(device, non_blocking=_non_blocking)
40+
self.b = self.b.to(device, non_blocking=_non_blocking)
4041
if isinstance(a, Number) and isinstance(b, Number):
4142
batch_shape = torch.Size()
4243
else:
@@ -146,8 +147,9 @@ class TruncatedNormal(TruncatedStandardNormal):
146147
def __init__(self, loc, scale, a, b, validate_args=None, device=None):
147148
scale = scale.clamp_min(self.eps)
148149
self.loc, self.scale, a, b = broadcast_all(loc, scale, a, b)
149-
a = a.to(device)
150-
b = b.to(device)
150+
_non_blocking = device is not None and torch.device(device).type == "cuda"
151+
a = a.to(device, non_blocking=_non_blocking)
152+
b = b.to(device, non_blocking=_non_blocking)
151153
self._non_std_a = a
152154
self._non_std_b = b
153155
a = (a - self.loc) / self.scale

torchrl/modules/distributions/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,23 @@
1616

1717
def _cast_device(elt: torch.Tensor | float, device) -> torch.Tensor | float:
1818
if isinstance(elt, torch.Tensor):
19-
return elt.to(device)
19+
_non_blocking = device is not None and torch.device(device).type == "cuda"
20+
return elt.to(device, non_blocking=_non_blocking)
2021
return elt
2122

2223

2324
def _cast_transform_device(transform, device):
2425
if transform is None:
2526
return transform
26-
elif isinstance(transform, d.ComposeTransform):
27+
_non_blocking = device is not None and torch.device(device).type == "cuda"
28+
if isinstance(transform, d.ComposeTransform):
2729
for i, t in enumerate(transform.parts):
2830
transform.parts[i] = _cast_transform_device(t, device)
2931
elif isinstance(transform, d.Transform):
3032
for attribute in dir(transform):
3133
value = getattr(transform, attribute)
3234
if isinstance(value, torch.Tensor):
33-
setattr(transform, attribute, value.to(device))
35+
setattr(transform, attribute, value.to(device, non_blocking=_non_blocking))
3436
return transform
3537
else:
3638
raise TypeError(

0 commit comments

Comments
 (0)