Skip to content

Commit 5ec9ada

Browse files
authored
test: Fix fork_rng context (#556)
* It should now work with other devices than cpu and cuda.
1 parent b6880cd commit 5ec9ada

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/utils/contexts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
@contextmanager
1212
def fork_rng(seed: int = 0) -> Generator[Any, None, None]:
13-
devices = [DEVICE] if DEVICE.type == "cuda" else []
13+
devices = [] if DEVICE.type == "cpu" else [DEVICE]
1414
with torch.random.fork_rng(devices=devices, device_type=DEVICE.type) as ctx:
1515
torch.manual_seed(seed)
1616
yield ctx

0 commit comments

Comments
 (0)