Skip to content

Commit cb8f4a0

Browse files
committed
[Refactor] Make primitive dtypes Python classes wrapping DataTypeCxx
Convert primitive dtypes (f32, i32, etc.) from bare DataTypeCxx module-level variables into Python classes with a PrimitiveMeta metaclass. Each class has a .cxx attribute holding the underlying DataTypeCxx, and the metaclass delegates __eq__, __hash__, __getattr__ for backward compatibility. Update cook_dtype, to_quadrants_type, MAP_TYPE_IDS, and type utility functions to handle the new class-based types. Add PrimitiveBase checks in expr_init and quant.py.
1 parent 9b5e1a8 commit cb8f4a0

6 files changed

Lines changed: 272 additions & 144 deletions

File tree

python/quadrants/lang/impl.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
from quadrants.types.enums import SNodeGradType
6363
from quadrants.types.ndarray_type import NdarrayType
6464
from quadrants.types.primitive_types import (
65+
PrimitiveBase,
6566
all_types,
6667
f16,
6768
f32,
@@ -110,6 +111,8 @@ def expr_init(rhs):
110111
return dict((key, expr_init(val)) for key, val in rhs.items())
111112
if isinstance(rhs, _qd_core.DataTypeCxx):
112113
return rhs
114+
if isinstance(rhs, type) and issubclass(rhs, PrimitiveBase):
115+
return rhs.cxx
113116
if isinstance(rhs, _qd_core.Arch):
114117
return rhs
115118
if isinstance(rhs, _Ndrange):

python/quadrants/lang/util.py

Lines changed: 78 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,59 @@
1212
from quadrants.lang import impl
1313
from quadrants.types import Template
1414
from quadrants.types.primitive_types import (
15+
PrimitiveBase,
1516
all_types,
1617
f16,
18+
f16_cxx,
1719
f32,
20+
f32_cxx,
1821
f64,
22+
f64_cxx,
1923
i8,
24+
i8_cxx,
2025
i16,
26+
i16_cxx,
2127
i32,
28+
i32_cxx,
2229
i64,
30+
i64_cxx,
2331
u1,
32+
u1_cxx,
2433
u8,
34+
u8_cxx,
2535
u16,
36+
u16_cxx,
2637
u32,
38+
u32_cxx,
2739
u64,
40+
u64_cxx,
2841
)
2942

30-
MAP_TYPE_IDS = {id(dtype): dtype for dtype in all_types}
43+
MAP_TYPE_IDS: dict[int, Any] = {id(dtype): dtype for dtype in all_types}
44+
_all_cxx_objs = (
45+
f16_cxx,
46+
f32_cxx,
47+
f64_cxx,
48+
i8_cxx,
49+
i16_cxx,
50+
i32_cxx,
51+
i64_cxx,
52+
u1_cxx,
53+
u8_cxx,
54+
u16_cxx,
55+
u32_cxx,
56+
u64_cxx,
57+
)
58+
for _cxx in _all_cxx_objs:
59+
MAP_TYPE_IDS[id(_cxx)] = _cxx
60+
61+
# Pre-computed id-based cache for cook_dtype hot path.
62+
# Maps id(Python class) and id(DataTypeCxx) to the DataTypeCxx result.
63+
_cook_cache: dict[int, _qd_core.DataTypeCxx] = {}
64+
for _cls in (f16, f32, f64, i8, i16, i32, i64, u1, u8, u16, u32, u64):
65+
_cook_cache[id(_cls)] = _cls.cxx
66+
for _cxx in _all_cxx_objs:
67+
_cook_cache[id(_cxx)] = _cxx
3168

3269

3370
def has_pytorch():
@@ -177,71 +214,74 @@ def to_quadrants_type(dt):
177214
dt (DataType): The desired data type to convert.
178215
179216
Returns:
180-
DataType: The counterpart data type in quadrants.
217+
DataTypeCxx: The counterpart data type in quadrants (always returns DataTypeCxx).
181218
182219
"""
183220
_type = type(dt)
184221
if _type is int:
185-
return MAP_TYPE_IDS[dt]
222+
return cook_dtype(MAP_TYPE_IDS[dt])
223+
224+
if isinstance(dt, type) and issubclass(dt, PrimitiveBase):
225+
return dt.cxx
186226

187227
if issubclass(_type, _qd_core.DataTypeCxx):
188228
return dt
189229

190230
if dt == np.float32:
191-
return f32
231+
return f32.cxx
192232
if dt == np.float64:
193-
return f64
233+
return f64.cxx
194234
if dt == np.int32:
195-
return i32
235+
return i32.cxx
196236
if dt == np.int64:
197-
return i64
237+
return i64.cxx
198238
if dt == np.int8:
199-
return i8
239+
return i8.cxx
200240
if dt == np.int16:
201-
return i16
241+
return i16.cxx
202242
if dt == np.bool_:
203-
return u1
243+
return u1.cxx
204244
if dt == np.uint8:
205-
return u8
245+
return u8.cxx
206246
if dt == np.uint16:
207-
return u16
247+
return u16.cxx
208248
if dt == np.uint32:
209-
return u32
249+
return u32.cxx
210250
if dt == np.uint64:
211-
return u64
251+
return u64.cxx
212252
if dt == np.half:
213-
return f16
253+
return f16.cxx
214254

215255
if has_pytorch():
216256
import torch # pylint: disable=C0415
217257

218258
# pylint: disable=E1101
219259
if dt == torch.float32:
220-
return f32
260+
return f32.cxx
221261
if dt == torch.float64:
222-
return f64
262+
return f64.cxx
223263
if dt == torch.int32:
224-
return i32
264+
return i32.cxx
225265
if dt == torch.int64:
226-
return i64
266+
return i64.cxx
227267
if dt == torch.int8:
228-
return i8
268+
return i8.cxx
229269
if dt == torch.int16:
230-
return i16
270+
return i16.cxx
231271
if dt == torch.bool:
232-
return u1
272+
return u1.cxx
233273
if dt == torch.uint8:
234-
return u8
274+
return u8.cxx
235275
if dt == torch.float16:
236-
return f16
276+
return f16.cxx
237277

238278
if hasattr(torch, "uint16"):
239279
if dt == torch.uint16:
240-
return u16
280+
return u16.cxx
241281
if dt == torch.uint32:
242-
return u32
282+
return u32.cxx
243283
if dt == torch.uint64:
244-
return u64
284+
return u64.cxx
245285

246286
raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.")
247287

@@ -264,8 +304,17 @@ def __hash__(self):
264304

265305

266306
def cook_dtype(dtype: Any) -> _qd_core.DataTypeCxx:
267-
# Convert Python dtype to CPP dtype
307+
"""Convert Python dtype to C++ DataTypeCxx.
308+
309+
Handles PrimitiveBase classes, raw DataTypeCxx instances, Type instances,
310+
and Python builtins (float, int, bool). Uses id-based cache for hot paths.
311+
"""
312+
cached = _cook_cache.get(id(dtype))
313+
if cached is not None:
314+
return cached
268315
_type = type(dtype)
316+
if isinstance(dtype, type) and issubclass(dtype, PrimitiveBase):
317+
return dtype.cxx
269318
if issubclass(_type, _qd_core.DataTypeCxx):
270319
return dtype
271320
if issubclass(_type, _qd_core.Type):
@@ -275,7 +324,7 @@ def cook_dtype(dtype: Any) -> _qd_core.DataTypeCxx:
275324
if dtype is int:
276325
return impl.get_runtime().default_ip
277326
if dtype is bool:
278-
return u1
327+
return u1.cxx
279328
raise ValueError(f"Invalid data type {dtype}")
280329

281330

0 commit comments

Comments
 (0)