@@ -8,11 +8,6 @@ namespace infini_train::core {
88/* *
99 * Backend type mapping: DataType -> backend-native dispatch type
1010 *
11- * NativeScalar — maps framework low-precision scalar types (FP16/BF16) to
12- * backend-native scalar types (__half / __nv_bfloat16).
13- * Primary template intentionally undefined.
14- * Each backend specializes only the types it supports.
15- *
1611 * BackendTypeMap — maps DataType to the C++ type used by kernels/dispatch.
1712 * Primary template intentionally undefined — there is NO
1813 * default fallback to the framework TypeMap<DType>.
@@ -22,53 +17,20 @@ namespace infini_train::core {
2217 * call INFINI_REGISTER_STANDARD_BACKEND_TYPES(Dev)
2318 * once at file scope in the backend's dispatch header.
2419 * - Low-precision types (FP16, BF16):
25- * specialize NativeScalar <Dev, infini_train::FP16/BF16>.
26- * The generic partial specializations below then resolve
27- * automatically via SFINAE-safe helper .
20+ * directly specialize BackendTypeMap <Dev, kFLOAT16/kBFLOAT16>
21+ * in the backend's dispatch header (the native scalar type
22+ * differs per backend, e.g. __half on CUDA) .
2823 *
2924 * If a backend does not register a dtype, HasMappedType_v returns false and
3025 * DispatchByTypeMap fires a clear static_assert at compile time.
3126 */
3227
33- // -----------------------------------------------------------------------------
34- // NativeScalar: framework scalar -> backend native scalar
35- // Primary template intentionally undefined.
36- // -----------------------------------------------------------------------------
37- template <Device::DeviceType Dev, typename Scalar> struct NativeScalar ;
38-
39- template <Device::DeviceType Dev, typename Scalar> using NativeScalar_t = typename NativeScalar<Dev, Scalar>::type;
40-
4128// -----------------------------------------------------------------------------
4229// BackendTypeMap: DataType -> backend dispatch type
4330// Primary template intentionally undefined — no TypeMap<DType> fallback.
4431// -----------------------------------------------------------------------------
4532template <Device::DeviceType Dev, DataType DType> struct BackendTypeMap ;
4633
47- // -----------------------------------------------------------------------------
48- // SFINAE-safe helper for low-precision type routing.
49- // When NativeScalar<Dev, Scalar> is undefined, this struct has no `type`
50- // member, making HasMappedType_v<..., kFLOAT16/kBFLOAT16> return false and
51- // triggering the static_assert in dispatch rather than an opaque hard error.
52- // -----------------------------------------------------------------------------
53- namespace detail {
54-
55- template <Device::DeviceType Dev, typename Scalar, typename = void >
56- struct BackendLowPrecisionTypeHelper {}; // no `type` member when NativeScalar absent
57-
58- template <Device::DeviceType Dev, typename Scalar>
59- struct BackendLowPrecisionTypeHelper <Dev, Scalar, std::void_t <typename NativeScalar<Dev, Scalar>::type>> {
60- using type = typename NativeScalar<Dev, Scalar>::type;
61- };
62-
63- } // namespace detail
64-
65- // Low-precision partial specializations: generic over Dev, resolved via NativeScalar.
66- template <Device::DeviceType Dev>
67- struct BackendTypeMap <Dev, DataType::kFLOAT16 > : detail::BackendLowPrecisionTypeHelper<Dev, infini_train::FP16> {};
68-
69- template <Device::DeviceType Dev>
70- struct BackendTypeMap <Dev, DataType::kBFLOAT16 > : detail::BackendLowPrecisionTypeHelper<Dev, infini_train::BF16> {};
71-
7234} // namespace infini_train::core
7335
7436// -----------------------------------------------------------------------------
@@ -80,7 +42,9 @@ struct BackendTypeMap<Dev, DataType::kBFLOAT16> : detail::BackendLowPrecisionTyp
8042//
8143// INFINI_REGISTER_STANDARD_BACKEND_TYPES(Device::DeviceType::kCUDA)
8244//
83- // FP16 and BF16 are NOT registered here — they are handled via NativeScalar.
45+ // FP16 and BF16 are NOT registered here — backends must specialize
46+ // BackendTypeMap<DEV, kFLOAT16/kBFLOAT16> directly with their native scalar
47+ // type (e.g. __half / __nv_bfloat16 on CUDA).
8448// -----------------------------------------------------------------------------
8549#define INFINI_REGISTER_STANDARD_BACKEND_TYPES (DEV ) \
8650 namespace infini_train ::core { \
0 commit comments