Skip to content

Commit ec8b5f7

Browse files
committed
test: Fix fork_rng context
It should now work with other devices than cpu and cuda
1 parent b6880cd commit ec8b5f7

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

tests/utils/contexts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
@contextmanager
1212
def fork_rng(seed: int = 0) -> Generator[Any, None, None]:
13-
devices = [DEVICE] if DEVICE.type == "cuda" else []
13+
devices = [] if DEVICE.type == "cpu" else [DEVICE]
1414
with torch.random.fork_rng(devices=devices, device_type=DEVICE.type) as ctx:
1515
torch.manual_seed(seed)
1616
yield ctx

0 commit comments

Comments
 (0)