-
Notifications
You must be signed in to change notification settings - Fork 81
Expand file tree
/
Copy path test_kernels_data.py
More file actions
159 lines (121 loc) · 4.51 KB
/
test_kernels_data.py
File metadata and controls
159 lines (121 loc) · 4.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import json
import pytest
from kernels_data import Backend, KernelName, Metadata, Version
def _write_metadata(path, **fields):
    """Write ``fields`` to ``path`` as a JSON object and return ``path``.

    Missing parent directories are created first, so callers may target a
    nested location such as ``tmp_path / "variant" / "metadata.json"``.
    Keyword names become JSON keys verbatim; pass ``**{"python-depends": ...}``
    for keys that are not valid Python identifiers.
    """
    target_dir = path.parent
    target_dir.mkdir(exist_ok=True, parents=True)
    payload = json.dumps(fields)
    path.write_text(payload)
    return path
def test_version_parse_and_normalize():
    """Parsed versions render without trailing ``.0`` components."""
    cases = [
        ("12.8.0", "12.8"),
        ("1", "1"),
        ("1.2.3", "1.2.3"),
    ]
    for raw, rendered in cases:
        assert str(Version.from_str(raw)) == rendered, raw
def test_version_ordering_and_hash():
    """Equal versions compare and hash alike; ordering is strict and one-way.

    ``1.2`` and ``1.2.0`` are expected to be the same version, so neither may
    sort before the other. ``1.3`` must sort after ``1.2`` and never before it.
    """
    v1 = Version.from_str("1.2")
    v2 = Version.from_str("1.2.0")
    v3 = Version.from_str("1.3")
    assert v1 == v2
    assert v1 < v3
    # A `<` that always returned True would satisfy the check above, so also
    # pin the reverse direction and irreflexivity between equal versions.
    assert not v3 < v1
    assert not v1 < v2
    assert not v2 < v1
    assert hash(v1) == hash(v2)
    assert {v1, v2, v3} == {v1, v3}
def test_version_invalid():
    """Non-numeric and empty strings are rejected with ValueError."""
    for bad in ("abc", ""):
        with pytest.raises(ValueError):
            Version.from_str(bad)
def test_kernel_name_valid():
    """A lower-case hyphenated name round-trips via str; python_name uses underscores."""
    name = KernelName("my-kernel")
    rendered = str(name)
    assert rendered == "my-kernel"
    assert name.python_name == "my_kernel"
def test_kernel_name_hash_and_eq():
    """Equal names compare and hash alike; different names stay distinct."""
    assert KernelName("flash-attention") == KernelName("flash-attention")
    assert hash(KernelName("flash-attention")) == hash(KernelName("flash-attention"))
    assert {KernelName("a1"), KernelName("a1")} == {KernelName("a1")}
    # Negative cases: an __eq__ that ignored the name would pass every check above.
    assert KernelName("a1") != KernelName("flash-attention")
    assert len({KernelName("a1"), KernelName("flash-attention")}) == 2
def test_kernel_name_invalid():
    """Upper-case letters, a leading digit and a leading hyphen are rejected."""
    rejected = ("My-Kernel", "1kernel", "-kernel")
    for candidate in rejected:
        with pytest.raises(ValueError):
            KernelName(candidate)
def test_backend_from_str_and_repr():
    """from_str accepts any casing; str is lower case and repr is qualified."""
    for spelling in ("cuda", "CUDA"):
        assert Backend.from_str(spelling) == Backend.CUDA, spelling
    cuda = Backend.CUDA
    assert str(cuda) == "cuda"
    assert repr(cuda) == "Backend.CUDA"
def test_backend_hash():
    """Backend variants work as distinct dict keys, including parsed values."""
    d = {Backend.CUDA: 1, Backend.CPU: 2}
    assert len(d) == 2
    assert d[Backend.CUDA] == 1
    assert d[Backend.CPU] == 2
    # A value produced by from_str must hash and compare like the constant,
    # otherwise it cannot be used to look up entries keyed by the constant.
    assert d[Backend.from_str("cuda")] == 1
def test_backend_unknown():
    """An unsupported backend name is rejected with ValueError."""
    unsupported = "tpu"
    with pytest.raises(ValueError):
        Backend.from_str(unsupported)
def test_backend_all_variants_and_casing():
    """Each variant keeps its mixed-case repr; str is lower case; from_str ignores case."""
    # (variant, expected str, expected repr)
    str_and_repr = [
        (Backend.Metal, "metal", "Backend.Metal"),
        (Backend.Neuron, "neuron", "Backend.Neuron"),
        (Backend.ROCm, "rocm", "Backend.ROCm"),
    ]
    for variant, as_str, as_repr in str_and_repr:
        assert str(variant) == as_str
        assert repr(variant) == as_repr

    repr_only = [
        (Backend.XPU, "Backend.XPU"),
        (Backend.CANN, "Backend.CANN"),
    ]
    for variant, as_repr in repr_only:
        assert repr(variant) == as_repr

    parsed = [
        ("cann", Backend.CANN),
        ("ROCM", Backend.ROCm),
        ("metal", Backend.Metal),
    ]
    for text, variant in parsed:
        assert Backend.from_str(text) == variant, text
def test_metadata_load_full(tmp_path):
    """Every supported top-level field is read back with its JSON value."""
    meta_path = _write_metadata(
        tmp_path / "metadata.json",
        **{
            "version": 1,
            "license": "Apache-2.0",
            "upstream": "https://github.com/example/kernel",
            "python-depends": ["torch"],
            "backend": {"type": "cuda", "archs": ["9.0", "10.0"]},
        },
    )
    meta = Metadata.load(meta_path)

    # JSON key "python-depends" surfaces as the attribute python_depends.
    expected_attrs = {
        "version": 1,
        "license": "Apache-2.0",
        "upstream": "https://github.com/example/kernel",
        "python_depends": ["torch"],
    }
    for attr, value in expected_attrs.items():
        assert getattr(meta, attr) == value, attr

    backend = meta.backend
    assert backend.backend_type == Backend.CUDA
    assert backend.archs == ["9.0", "10.0"]
def test_metadata_load_minimal(tmp_path):
    """With only python-depends and backend present, optional fields load as None."""
    meta_path = _write_metadata(
        tmp_path / "metadata.json",
        **{"python-depends": [], "backend": {"type": "cpu"}},
    )
    meta = Metadata.load(meta_path)
    for optional in ("version", "license", "upstream"):
        assert getattr(meta, optional) is None, optional
    assert meta.python_depends == []
    assert meta.backend.backend_type == Backend.CPU
def test_metadata_load_cann(tmp_path):
    """A backend type of "cann" loads as Backend.CANN."""
    meta_path = _write_metadata(
        tmp_path / "metadata.json",
        **{"python-depends": [], "backend": {"type": "cann"}},
    )
    loaded = Metadata.load(meta_path)
    assert loaded.backend.backend_type == Backend.CANN
def test_metadata_load_unknown_field_accepted(tmp_path):
    """Unrecognized top-level keys are tolerated and do not disturb known ones.

    The extra key's value is arbitrary test data. The test checks that loading
    succeeds and that the recognized fields are still parsed as usual.
    """
    path = tmp_path / "metadata.json"
    path.write_text(
        json.dumps(
            {
                "python-depends": [],
                "backend": {"type": "cpu"},
                "surprise": "not allowed",
            }
        )
    )
    m = Metadata.load(path)
    # Previously the result was discarded, so only "no exception" was checked.
    assert m.python_depends == []
    assert m.backend.backend_type == Backend.CPU
def test_metadata_load_malformed(tmp_path):
    """Content that is not valid JSON makes Metadata.load raise ValueError.

    Any ValueError subclass satisfies the check.
    """
    bad_file = tmp_path / "metadata.json"
    bad_file.write_text("{not json")
    with pytest.raises(ValueError):
        Metadata.load(bad_file)
def test_metadata_load(tmp_path):
    """Metadata stored in a nested variant directory loads its backend type."""
    variant_dir = tmp_path / "variant"
    fields = {"python-depends": ["torch"], "backend": {"type": "cuda"}}
    meta_path = _write_metadata(variant_dir / "metadata.json", **fields)
    loaded = Metadata.load(meta_path)
    assert loaded.backend.backend_type == Backend.CUDA
def test_metadata_load_missing_file(tmp_path):
    """A path that does not exist is reported as ValueError.

    This is not FileNotFoundError, which is an OSError and would not satisfy
    this check.
    """
    absent = tmp_path / "does-not-exist.json"
    with pytest.raises(ValueError):
        Metadata.load(absent)