Skip to content

Commit c27aeb0

Browse files
committed
Add docs for get_installed_torch_platform and get_device in utils
1 parent cd099aa commit c27aeb0

1 file changed

Lines changed: 22 additions & 1 deletion

File tree

API.md

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,34 @@ gpus = get_gpus() # Returns a list of `torchruntime.device_db.GPU` instances co
2323

2424
**Important:** This API could break in the future, so if you're writing a program using this, please open a new Issue on this repo and let me know what you're trying to do.
2525

26-
## Get torch platform
26+
## Get torch platform (given any GPU)
2727
This will return the recommended torch platform to use for the PC. It will analyze the GPUs and OS on the PC, and suggest the most-performant version of torch for that.
2828

29+
*Note: this is different from the installed torch platfrom (see the next paragraph)!*
30+
2931
E.g. `cu124` or `rocm6.1` or `directml` or `ipex` or `xpu` or `cpu`.
3032

3133
```py
3234
from torchruntime.platform_detection import get_torch_platform
3335

3436
torch_platform = get_torch_platform(gpus) # use `torchruntime.device_db.get_gpus()` to get a list of recognized GPUs
3537
```
38+
39+
## Get installed torch platform
40+
This will return the installed torch platform, if any. E.g. "cuda", "mps", "cpu" etc.
41+
42+
```py
43+
from torchruntime.utils import get_installed_torch_platform
44+
45+
torch_platform = get_installed_torch_platform()[0]
46+
```
47+
48+
## Get torch device
49+
This will return an instance of `torch.device` for the given device index, from the currently installed torch platform.
50+
51+
```py
52+
from torchruntime.utils import get_device
53+
54+
device1 = get_device(0)
55+
device2 = get_device("cuda:1")
56+
```

0 commit comments

Comments
 (0)