Skip to content

Commit 7168cbf

Browse files
authored
kernels-data: add and expose supported backend archs (#480)
* kernels-data: add and expose supported backend archs * Fix stub
1 parent a9fc299 commit 7168cbf

5 files changed

Lines changed: 73 additions & 19 deletions

File tree

kernel-builder/src/pyproject/common.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ pub fn write_metadata(general: &General, file_set: &mut FileSet) -> Result<()> {
3838
upstream: general.upstream.clone(),
3939
python_depends,
4040
backend: BackendInfo {
41+
archs: None,
4142
backend_type: *backend,
4243
},
4344
};

kernels-data/bindings/python/kernels_data.pyi

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""Type stubs for kernels_data module."""
22

3+
import os
34
from enum import Enum
45
from typing import Optional, final
5-
import os
66

7-
__all__ = ["Backend", "KernelName", "Metadata", "Version", "__version__"]
7+
__all__ = ["Backend", "BackendInfo", "KernelName", "Metadata", "Version", "__version__"]
88

99
__version__: str
1010

@@ -36,6 +36,22 @@ class Backend(Enum):
3636
def __str__(self) -> str: ...
3737
def __repr__(self) -> str: ...
3838

39+
@final
40+
class BackendInfo:
41+
"""Backend information."""
42+
43+
@property
44+
def backend_type(self) -> Backend:
45+
"""Return the backend type."""
46+
...
47+
48+
@property
49+
def archs(self) -> Optional[list[str]]:
50+
"""Optional list of target architectures."""
51+
...
52+
53+
def __repr__(self) -> str: ...
54+
3955
@final
4056
class Version:
4157
"""A dotted numeric version (e.g. ``12.8.0``).
@@ -105,5 +121,5 @@ class Metadata:
105121
@property
106122
def python_depends(self) -> list[str]: ...
107123
@property
108-
def backend(self) -> Backend: ...
124+
def backend(self) -> BackendInfo: ...
109125
def __repr__(self) -> str: ...

kernels-data/bindings/python/src/lib.rs

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::path::PathBuf;
22
use std::str::FromStr;
33

44
use kernels_data::config::{Backend, KernelName};
5-
use kernels_data::metadata::{Metadata, parse_metadata};
5+
use kernels_data::metadata::{BackendInfo, Metadata, parse_metadata};
66
use kernels_data::version::Version;
77
use pyo3::Bound as PyBound;
88
use pyo3::exceptions::PyValueError;
@@ -146,6 +146,44 @@ impl PyBackend {
146146
}
147147
}
148148

149+
/// Backend information
150+
#[pyclass(name = "BackendInfo", frozen)]
151+
#[derive(Clone, Debug)]
152+
struct PyBackendInfo {
153+
backend_type: PyBackend,
154+
archs: Option<Vec<String>>,
155+
}
156+
157+
impl From<BackendInfo> for PyBackendInfo {
158+
fn from(backend_info: BackendInfo) -> Self {
159+
Self {
160+
backend_type: backend_info.backend_type.into(),
161+
archs: backend_info.archs,
162+
}
163+
}
164+
}
165+
166+
#[pymethods]
167+
impl PyBackendInfo {
168+
fn __repr__(&self) -> String {
169+
format!(
170+
"BackendInfo(backend_type={}, archs={:?})",
171+
self.backend_type.__repr__(),
172+
self.archs
173+
)
174+
}
175+
176+
#[getter]
177+
fn backend_type(&self) -> PyBackend {
178+
self.backend_type
179+
}
180+
181+
#[getter]
182+
fn archs(&self) -> Option<&[String]> {
183+
self.archs.as_deref()
184+
}
185+
}
186+
149187
/// Parsed `metadata.json` for a kernel build variant.
150188
#[pyclass(name = "Metadata", frozen)]
151189
#[derive(Clone, Debug)]
@@ -154,7 +192,7 @@ struct PyMetadata {
154192
license: Option<String>,
155193
upstream: Option<String>,
156194
python_depends: Vec<String>,
157-
backend: PyBackend,
195+
backend: PyBackendInfo,
158196
}
159197

160198
impl From<Metadata> for PyMetadata {
@@ -164,7 +202,7 @@ impl From<Metadata> for PyMetadata {
164202
license: m.license,
165203
upstream: m.upstream.map(|u| u.to_string()),
166204
python_depends: m.python_depends,
167-
backend: m.backend.backend_type.into(),
205+
backend: m.backend.into(),
168206
}
169207
}
170208
}
@@ -202,8 +240,8 @@ impl PyMetadata {
202240
}
203241

204242
#[getter]
205-
fn backend(&self) -> PyBackend {
206-
self.backend
243+
fn backend(&self) -> PyBackendInfo {
244+
self.backend.clone()
207245
}
208246

209247
fn __repr__(&self) -> String {
@@ -221,6 +259,7 @@ impl PyMetadata {
221259
#[pyo3::pymodule(name = "kernels_data")]
222260
fn kernels_data_py(m: &PyBound<'_, PyModule>) -> PyResult<()> {
223261
m.add_class::<PyBackend>()?;
262+
m.add_class::<PyBackendInfo>()?;
224263
m.add_class::<PyKernelName>()?;
225264
m.add_class::<PyMetadata>()?;
226265
m.add_class::<PyVersion>()?;

kernels-data/bindings/python/tests/test_kernels_data.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def test_metadata_load_full(tmp_path):
9494
"license": "Apache-2.0",
9595
"upstream": "https://github.com/example/kernel",
9696
"python-depends": ["torch"],
97-
"backend": {"type": "cuda"},
97+
"backend": {"type": "cuda", "archs": ["9.0", "10.0"]},
9898
}
9999
)
100100
)
@@ -103,28 +103,25 @@ def test_metadata_load_full(tmp_path):
103103
assert m.license == "Apache-2.0"
104104
assert m.upstream == "https://github.com/example/kernel"
105105
assert m.python_depends == ["torch"]
106-
assert m.backend == Backend.CUDA
106+
assert m.backend.backend_type == Backend.CUDA
107+
assert m.backend.archs == ["9.0", "10.0"]
107108

108109

109110
def test_metadata_load_minimal(tmp_path):
110111
path = tmp_path / "metadata.json"
111-
path.write_text(
112-
json.dumps({"python-depends": [], "backend": {"type": "cpu"}})
113-
)
112+
path.write_text(json.dumps({"python-depends": [], "backend": {"type": "cpu"}}))
114113
m = Metadata.load(path)
115114
assert m.version is None
116115
assert m.license is None
117116
assert m.upstream is None
118117
assert m.python_depends == []
119-
assert m.backend == Backend.CPU
118+
assert m.backend.backend_type == Backend.CPU
120119

121120

122121
def test_metadata_load_cann(tmp_path):
123122
path = tmp_path / "metadata.json"
124-
path.write_text(
125-
json.dumps({"python-depends": [], "backend": {"type": "cann"}})
126-
)
127-
assert Metadata.load(path).backend == Backend.CANN
123+
path.write_text(json.dumps({"python-depends": [], "backend": {"type": "cann"}}))
124+
assert Metadata.load(path).backend.backend_type == Backend.CANN
128125

129126

130127
def test_metadata_load_unknown_field_rejected(tmp_path):
@@ -155,7 +152,7 @@ def test_metadata_load(tmp_path):
155152
**{"python-depends": ["torch"], "backend": {"type": "cuda"}},
156153
)
157154
m = Metadata.load(path)
158-
assert m.backend == Backend.CUDA
155+
assert m.backend.backend_type == Backend.CUDA
159156

160157

161158
def test_metadata_load_missing_file(tmp_path):

kernels-data/src/metadata.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use crate::config::Backend;
1010
pub struct BackendInfo {
1111
#[serde(rename = "type")]
1212
pub backend_type: Backend,
13+
pub archs: Option<Vec<String>>,
1314
}
1415

1516
#[derive(Debug, Deserialize, Serialize)]

0 commit comments

Comments
 (0)