@@ -35,8 +35,9 @@ class TruncatedStandardNormal(Distribution):
3535
3636 def __init__ (self , a , b , validate_args = None , device = None ):
3737 self .a , self .b = broadcast_all (a , b )
38- self .a = self .a .to (device )
39- self .b = self .b .to (device )
38+ _non_blocking = device is not None and torch .device (device ).type == "cuda"
39+ self .a = self .a .to (device , non_blocking = _non_blocking )
40+ self .b = self .b .to (device , non_blocking = _non_blocking )
4041 if isinstance (a , Number ) and isinstance (b , Number ):
4142 batch_shape = torch .Size ()
4243 else :
@@ -146,8 +147,9 @@ class TruncatedNormal(TruncatedStandardNormal):
146147 def __init__ (self , loc , scale , a , b , validate_args = None , device = None ):
147148 scale = scale .clamp_min (self .eps )
148149 self .loc , self .scale , a , b = broadcast_all (loc , scale , a , b )
149- a = a .to (device )
150- b = b .to (device )
150+ _non_blocking = device is not None and torch .device (device ).type == "cuda"
151+ a = a .to (device , non_blocking = _non_blocking )
152+ b = b .to (device , non_blocking = _non_blocking )
151153 self ._non_std_a = a
152154 self ._non_std_b = b
153155 a = (a - self .loc ) / self .scale
0 commit comments