Skip to content

Commit 8941f82

Browse files
GitGlimpse895GitGlimpse895
andauthored
fix/rocm properties export (#487)
* layer/device: fix Device properties type annotation and add ROCMProperties validation * layer: export ROCMProperties from layer package * kernels: export ROCMProperties from top-level package * layer/device: add ROCMProperties type annotation and validation guard --------- Co-authored-by: GitGlimpse895 <sayakmondal432@gmail.com>
1 parent 559c412 commit 8941f82

3 files changed

Lines changed: 11 additions & 2 deletions

File tree

kernels/src/kernels/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
LockedFuncRepository,
1717
LockedLayerRepository,
1818
Mode,
19+
ROCMProperties,
1920
kernelize,
2021
register_kernel_mapping,
2122
replace_kernel_forward_from_hub,
@@ -42,6 +43,7 @@
4243
"Benchmark",
4344
"CUDAProperties",
4445
"Device",
46+
"ROCMProperties",
4547
"FuncRepository",
4648
"LayerRepository",
4749
"LoadedKernel",

kernels/src/kernels/layer/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .device import CUDAProperties, Device
1+
from .device import CUDAProperties, Device, ROCMProperties
22
from .func import (
33
FuncRepository,
44
LocalFuncRepository,
@@ -22,6 +22,7 @@
2222
__all__ = [
2323
"CUDAProperties",
2424
"Device",
25+
"ROCMProperties",
2526
"FuncRepository",
2627
"LayerRepository",
2728
"LocalFuncRepository",

kernels/src/kernels/layer/device.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,18 @@ class Device:
141141
"""
142142

143143
type: str
144-
properties: CUDAProperties | None = None
144+
properties: CUDAProperties | ROCMProperties | None = None
145145

146146
def __post_init__(self):
147147
if self.properties is not None and isinstance(self.properties, CUDAProperties):
148148
if self.type != "cuda":
149149
raise ValueError("CUDAProperties is only supported for 'cuda' devices.")
150+
if self.properties is not None and isinstance(self.properties, ROCMProperties):
151+
if self.type != "rocm":
152+
raise ValueError("ROCMProperties is only supported for 'rocm' devices.")
153+
if self.properties is not None and isinstance(self.properties, ROCMProperties):
154+
if self.type != "rocm":
155+
raise ValueError("ROCMProperties is only supported for 'rocm' devices.")
150156

151157
def __eq__(self, other):
152158
if not isinstance(other, Device):

0 commit comments

Comments
 (0)