Skip to content

Commit b092c8b

Browse files
committed
refactor: remove NativeScalar instead of using BackendTypeMap directly
1 parent 6490093 commit b092c8b

File tree

3 files changed

+17
-53
lines changed

3 files changed

+17
-53
lines changed

infini_train/include/core/backend_type_map.h

Lines changed: 6 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,6 @@ namespace infini_train::core {
88
/**
99
* Backend type mapping: DataType -> backend-native dispatch type
1010
*
11-
* NativeScalar — maps framework low-precision scalar types (FP16/BF16) to
12-
* backend-native scalar types (__half / __nv_bfloat16).
13-
* Primary template intentionally undefined.
14-
* Each backend specializes only the types it supports.
15-
*
1611
* BackendTypeMap — maps DataType to the C++ type used by kernels/dispatch.
1712
* Primary template intentionally undefined — there is NO
1813
* default fallback to the framework TypeMap<DType>.
@@ -22,53 +17,20 @@ namespace infini_train::core {
2217
* call INFINI_REGISTER_STANDARD_BACKEND_TYPES(Dev)
2318
* once at file scope in the backend's dispatch header.
2419
* - Low-precision types (FP16, BF16):
25-
* specialize NativeScalar<Dev, infini_train::FP16/BF16>.
26-
* The generic partial specializations below then resolve
27-
* automatically via SFINAE-safe helper.
20+
* directly specialize BackendTypeMap<Dev, kFLOAT16/kBFLOAT16>
21+
* in the backend's dispatch header (the native scalar type
22+
* differs per backend, e.g. __half on CUDA).
2823
*
2924
* If a backend does not register a dtype, HasMappedType_v returns false and
3025
* DispatchByTypeMap fires a clear static_assert at compile time.
3126
*/
3227

33-
// -----------------------------------------------------------------------------
34-
// NativeScalar: framework scalar -> backend native scalar
35-
// Primary template intentionally undefined.
36-
// -----------------------------------------------------------------------------
37-
template <Device::DeviceType Dev, typename Scalar> struct NativeScalar;
38-
39-
template <Device::DeviceType Dev, typename Scalar> using NativeScalar_t = typename NativeScalar<Dev, Scalar>::type;
40-
4128
// -----------------------------------------------------------------------------
4229
// BackendTypeMap: DataType -> backend dispatch type
4330
// Primary template intentionally undefined — no TypeMap<DType> fallback.
4431
// -----------------------------------------------------------------------------
4532
template <Device::DeviceType Dev, DataType DType> struct BackendTypeMap;
4633

47-
// -----------------------------------------------------------------------------
48-
// SFINAE-safe helper for low-precision type routing.
49-
// When NativeScalar<Dev, Scalar> is undefined, this struct has no `type`
50-
// member, making HasMappedType_v<..., kFLOAT16/kBFLOAT16> return false and
51-
// triggering the static_assert in dispatch rather than an opaque hard error.
52-
// -----------------------------------------------------------------------------
53-
namespace detail {
54-
55-
template <Device::DeviceType Dev, typename Scalar, typename = void>
56-
struct BackendLowPrecisionTypeHelper {}; // no `type` member when NativeScalar absent
57-
58-
template <Device::DeviceType Dev, typename Scalar>
59-
struct BackendLowPrecisionTypeHelper<Dev, Scalar, std::void_t<typename NativeScalar<Dev, Scalar>::type>> {
60-
using type = typename NativeScalar<Dev, Scalar>::type;
61-
};
62-
63-
} // namespace detail
64-
65-
// Low-precision partial specializations: generic over Dev, resolved via NativeScalar.
66-
template <Device::DeviceType Dev>
67-
struct BackendTypeMap<Dev, DataType::kFLOAT16> : detail::BackendLowPrecisionTypeHelper<Dev, infini_train::FP16> {};
68-
69-
template <Device::DeviceType Dev>
70-
struct BackendTypeMap<Dev, DataType::kBFLOAT16> : detail::BackendLowPrecisionTypeHelper<Dev, infini_train::BF16> {};
71-
7234
} // namespace infini_train::core
7335

7436
// -----------------------------------------------------------------------------
@@ -80,7 +42,9 @@ struct BackendTypeMap<Dev, DataType::kBFLOAT16> : detail::BackendLowPrecisionTyp
8042
//
8143
// INFINI_REGISTER_STANDARD_BACKEND_TYPES(Device::DeviceType::kCUDA)
8244
//
83-
// FP16 and BF16 are NOT registered here — they are handled via NativeScalar.
45+
// FP16 and BF16 are NOT registered here — backends must specialize
46+
// BackendTypeMap<DEV, kFLOAT16/kBFLOAT16> directly with their native scalar
47+
// type (e.g. __half / __nv_bfloat16 on CUDA).
8448
// -----------------------------------------------------------------------------
8549
#define INFINI_REGISTER_STANDARD_BACKEND_TYPES(DEV) \
8650
namespace infini_train::core { \

infini_train/src/core/runtime/cpu/cpu_dispatch.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,22 @@
77
#include "infini_train/include/dtype_dispatch.h"
88

99
// -----------------------------------------------------------------------------
10-
// CPU NativeScalar specializations: FP16 -> FP16, BF16 -> BF16
10+
// CPU low-precision BackendTypeMap specializations:
11+
// FP16 -> infini_train::FP16, BF16 -> infini_train::BF16
1112
// CPU uses the framework wrapper types directly (host-side conversion).
1213
// -----------------------------------------------------------------------------
1314
namespace infini_train::core {
14-
template <> struct NativeScalar<Device::DeviceType::kCPU, infini_train::FP16> {
15+
template <> struct BackendTypeMap<Device::DeviceType::kCPU, DataType::kFLOAT16> {
1516
using type = infini_train::FP16;
1617
};
1718

18-
template <> struct NativeScalar<Device::DeviceType::kCPU, infini_train::BF16> {
19+
template <> struct BackendTypeMap<Device::DeviceType::kCPU, DataType::kBFLOAT16> {
1920
using type = infini_train::BF16;
2021
};
2122
} // namespace infini_train::core
2223

2324
// Register all standard (non-low-precision) dtypes for the CPU backend.
24-
// FP16/BF16 are handled above via NativeScalar specializations +
25-
// BackendTypeMap partial specializations in backend_type_map.h.
25+
// FP16/BF16 are registered explicitly above.
2626
INFINI_REGISTER_STANDARD_BACKEND_TYPES(infini_train::Device::DeviceType::kCPU)
2727

2828
namespace infini_train::core::cpu {

infini_train/src/core/runtime/cuda/cuda_dispatch.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,21 @@
1010
#include "infini_train/include/dtype_dispatch.h"
1111

1212
// -----------------------------------------------------------------------------
13-
// CUDA NativeScalar specializations: FP16 -> __half, BF16 -> __nv_bfloat16
13+
// CUDA low-precision BackendTypeMap specializations:
14+
// FP16 -> __half, BF16 -> __nv_bfloat16
1415
// -----------------------------------------------------------------------------
1516
namespace infini_train::core {
16-
template <> struct NativeScalar<Device::DeviceType::kCUDA, infini_train::FP16> {
17+
template <> struct BackendTypeMap<Device::DeviceType::kCUDA, DataType::kFLOAT16> {
1718
using type = __half;
1819
};
1920

20-
template <> struct NativeScalar<Device::DeviceType::kCUDA, infini_train::BF16> {
21+
template <> struct BackendTypeMap<Device::DeviceType::kCUDA, DataType::kBFLOAT16> {
2122
using type = __nv_bfloat16;
2223
};
2324
} // namespace infini_train::core
2425

2526
// Register all standard (non-low-precision) dtypes for the CUDA backend.
26-
// FP16/BF16 are handled above via NativeScalar specializations +
27-
// BackendTypeMap partial specializations in backend_type_map.h.
27+
// FP16/BF16 are registered explicitly above with their CUDA-native scalar types.
2828
INFINI_REGISTER_STANDARD_BACKEND_TYPES(infini_train::Device::DeviceType::kCUDA)
2929

3030
namespace infini_train::core::cuda {
@@ -38,7 +38,7 @@ template <DataType DType> struct CudaTypeMap;
3838

3939
// Register all supported dtypes by delegating to BackendTypeMap<kCUDA, DType>.
4040
// Standard types come from INFINI_REGISTER_STANDARD_BACKEND_TYPES above;
41-
// FP16/BF16 come from BackendTypeMap partial specializations + NativeScalar.
41+
// FP16/BF16 come from the explicit BackendTypeMap specializations above.
4242
#define INFINI_REGISTER_CUDA_TYPEMAP(DTYPE) \
4343
template <> \
4444
struct CudaTypeMap<DataType::DTYPE> \

0 commit comments

Comments
 (0)