We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 737f7c9 commit e7c23deCopy full SHA for e7c23de
1 file changed
torchcontrol/system/system_cfg.py
@@ -4,6 +4,7 @@
4
"""
5
from __future__ import annotations
6
7
+import torch
8
from dataclasses import MISSING, dataclass
9
from .system_base import SystemBase
10
@@ -47,4 +48,4 @@ def __post_init__(self):
47
48
assert self.state_dim > 0, "state_dim must be greater than 0"
49
assert self.action_dim > 0, "action_dim must be greater than 0"
50
assert self.dt > 0, "dt must be greater than 0"
- assert self.device in ["cpu", "cuda"], "device must be 'cpu' or 'cuda'"
51
+ self.device = torch.device(self.device)
0 commit comments