1212from quadrants .lang import impl
1313from quadrants .types import Template
1414from quadrants .types .primitive_types import (
15+ PrimitiveBase ,
1516 all_types ,
1617 f16 ,
18+ f16_cxx ,
1719 f32 ,
20+ f32_cxx ,
1821 f64 ,
22+ f64_cxx ,
1923 i8 ,
24+ i8_cxx ,
2025 i16 ,
26+ i16_cxx ,
2127 i32 ,
28+ i32_cxx ,
2229 i64 ,
30+ i64_cxx ,
2331 u1 ,
32+ u1_cxx ,
2433 u8 ,
34+ u8_cxx ,
2535 u16 ,
36+ u16_cxx ,
2637 u32 ,
38+ u32_cxx ,
2739 u64 ,
40+ u64_cxx ,
2841)
2942
30- MAP_TYPE_IDS = {id (dtype ): dtype for dtype in all_types }
43+ MAP_TYPE_IDS : dict [int , Any ] = {id (dtype ): dtype for dtype in all_types }
44+ _all_cxx_objs = (
45+ f16_cxx ,
46+ f32_cxx ,
47+ f64_cxx ,
48+ i8_cxx ,
49+ i16_cxx ,
50+ i32_cxx ,
51+ i64_cxx ,
52+ u1_cxx ,
53+ u8_cxx ,
54+ u16_cxx ,
55+ u32_cxx ,
56+ u64_cxx ,
57+ )
58+ for _cxx in _all_cxx_objs :
59+ MAP_TYPE_IDS [id (_cxx )] = _cxx
60+
61+ # Pre-computed id-based cache for cook_dtype hot path.
62+ # Maps id(Python class) and id(DataTypeCxx) to the DataTypeCxx result.
63+ _cook_cache : dict [int , _qd_core .DataTypeCxx ] = {}
64+ for _cls in (f16 , f32 , f64 , i8 , i16 , i32 , i64 , u1 , u8 , u16 , u32 , u64 ):
65+ _cook_cache [id (_cls )] = _cls .cxx
66+ for _cxx in _all_cxx_objs :
67+ _cook_cache [id (_cxx )] = _cxx
3168
3269
3370def has_pytorch ():
@@ -177,71 +214,74 @@ def to_quadrants_type(dt):
177214 dt (DataType): The desired data type to convert.
178215
179216 Returns:
180- DataType : The counterpart data type in quadrants.
217+ DataTypeCxx : The counterpart data type in quadrants (always returns DataTypeCxx) .
181218
182219 """
183220 _type = type (dt )
184221 if _type is int :
185- return MAP_TYPE_IDS [dt ]
222+ return cook_dtype (MAP_TYPE_IDS [dt ])
223+
224+ if isinstance (dt , type ) and issubclass (dt , PrimitiveBase ):
225+ return dt .cxx
186226
187227 if issubclass (_type , _qd_core .DataTypeCxx ):
188228 return dt
189229
190230 if dt == np .float32 :
191- return f32
231+ return f32 . cxx
192232 if dt == np .float64 :
193- return f64
233+ return f64 . cxx
194234 if dt == np .int32 :
195- return i32
235+ return i32 . cxx
196236 if dt == np .int64 :
197- return i64
237+ return i64 . cxx
198238 if dt == np .int8 :
199- return i8
239+ return i8 . cxx
200240 if dt == np .int16 :
201- return i16
241+ return i16 . cxx
202242 if dt == np .bool_ :
203- return u1
243+ return u1 . cxx
204244 if dt == np .uint8 :
205- return u8
245+ return u8 . cxx
206246 if dt == np .uint16 :
207- return u16
247+ return u16 . cxx
208248 if dt == np .uint32 :
209- return u32
249+ return u32 . cxx
210250 if dt == np .uint64 :
211- return u64
251+ return u64 . cxx
212252 if dt == np .half :
213- return f16
253+ return f16 . cxx
214254
215255 if has_pytorch ():
216256 import torch # pylint: disable=C0415
217257
218258 # pylint: disable=E1101
219259 if dt == torch .float32 :
220- return f32
260+ return f32 . cxx
221261 if dt == torch .float64 :
222- return f64
262+ return f64 . cxx
223263 if dt == torch .int32 :
224- return i32
264+ return i32 . cxx
225265 if dt == torch .int64 :
226- return i64
266+ return i64 . cxx
227267 if dt == torch .int8 :
228- return i8
268+ return i8 . cxx
229269 if dt == torch .int16 :
230- return i16
270+ return i16 . cxx
231271 if dt == torch .bool :
232- return u1
272+ return u1 . cxx
233273 if dt == torch .uint8 :
234- return u8
274+ return u8 . cxx
235275 if dt == torch .float16 :
236- return f16
276+ return f16 . cxx
237277
238278 if hasattr (torch , "uint16" ):
239279 if dt == torch .uint16 :
240- return u16
280+ return u16 . cxx
241281 if dt == torch .uint32 :
242- return u32
282+ return u32 . cxx
243283 if dt == torch .uint64 :
244- return u64
284+ return u64 . cxx
245285
246286 raise RuntimeError (f"PyTorch doesn't support { dt .to_string ()} data type before version 2.3.0." )
247287
@@ -264,8 +304,17 @@ def __hash__(self):
264304
265305
266306def cook_dtype (dtype : Any ) -> _qd_core .DataTypeCxx :
267- # Convert Python dtype to CPP dtype
307+ """Convert Python dtype to C++ DataTypeCxx.
308+
309+ Handles PrimitiveBase classes, raw DataTypeCxx instances, Type instances,
310+ and Python builtins (float, int, bool). Uses id-based cache for hot paths.
311+ """
312+ cached = _cook_cache .get (id (dtype ))
313+ if cached is not None :
314+ return cached
268315 _type = type (dtype )
316+ if isinstance (dtype , type ) and issubclass (dtype , PrimitiveBase ):
317+ return dtype .cxx
269318 if issubclass (_type , _qd_core .DataTypeCxx ):
270319 return dtype
271320 if issubclass (_type , _qd_core .Type ):
@@ -275,7 +324,7 @@ def cook_dtype(dtype: Any) -> _qd_core.DataTypeCxx:
275324 if dtype is int :
276325 return impl .get_runtime ().default_ip
277326 if dtype is bool :
278- return u1
327+ return u1 . cxx
279328 raise ValueError (f"Invalid data type { dtype } " )
280329
281330
0 commit comments