Skip to content

Commit 2db8496

Browse files
committed
test: add dtype_dispatch tests
1 parent 8604d2d commit 2db8496

File tree

7 files changed

+552
-109
lines changed

7 files changed

+552
-109
lines changed

CMakeLists.txt

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,3 +204,38 @@ link_infini_train_exe(test_precision_check)
204204
add_executable(test_lora test/lora/test_lora.cc)
205205
link_infini_train_exe(test_lora)
206206

207+
add_executable(test_scalar test/scalar/test_scalar.cc)
208+
link_infini_train_exe(test_scalar)
209+
210+
add_executable(test_dtype_dispatch test/dispatch/test_dtype_dispatch.cc)
211+
link_infini_train_exe(test_dtype_dispatch)
212+
213+
# Negative compile test: missing dtype registration must fail at compile time.
214+
set(DTYPE_DISPATCH_COMPILE_FAIL_SOURCE
215+
${PROJECT_SOURCE_DIR}/test/dispatch/test_dtype_dispatch_compile_fail.cc)
216+
217+
try_compile(DTYPE_DISPATCH_COMPILE_UNEXPECTEDLY_SUCCEEDED
218+
${CMAKE_BINARY_DIR}/CMakeFiles/try_compile_dtype_dispatch_missing_map
219+
SOURCES ${DTYPE_DISPATCH_COMPILE_FAIL_SOURCE}
220+
CMAKE_FLAGS
221+
"-DCMAKE_CXX_STANDARD=${CMAKE_CXX_STANDARD}"
222+
"-DCMAKE_CXX_STANDARD_REQUIRED=ON"
223+
"-DCMAKE_CXX_EXTENSIONS=OFF"
224+
"-DCMAKE_CXX_FLAGS=-I${PROJECT_SOURCE_DIR}"
225+
OUTPUT_VARIABLE DTYPE_DISPATCH_TRY_COMPILE_OUTPUT
226+
)
227+
228+
if(DTYPE_DISPATCH_COMPILE_UNEXPECTEDLY_SUCCEEDED)
229+
message(FATAL_ERROR
230+
"dtype dispatch compile-fail test unexpectedly succeeded.\n"
231+
"Source: ${DTYPE_DISPATCH_COMPILE_FAIL_SOURCE}\n"
232+
"Output:\n${DTYPE_DISPATCH_TRY_COMPILE_OUTPUT}")
233+
endif()
234+
235+
add_custom_target(test_dtype_dispatch_compile_fail
236+
COMMAND ${CMAKE_COMMAND} -E echo
237+
"dtype dispatch compile-fail check passed (missing dtype registration correctly fails to compile)."
238+
VERBATIM
239+
)
240+
241+
add_dependencies(test_dtype_dispatch test_dtype_dispatch_compile_fail)

docs/device_guard_design.md

Lines changed: 11 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,7 @@
11
# Device Guard Design
2-
device 注册初版基建 pr:https://github.com/InfiniTensor/InfiniTrain/pull/103
2+
Device 注册机制是 InfiniTrain 面向多硬件后端的统一运行时抽象与插件化接入基础设施。
33

4-
## 1. 设计背景与目标
5-
6-
### 1.1 背景
7-
8-
InfiniTrain 需要长期支持:
9-
10-
- 多种设备类型(CPU/CUDA/国产芯片)
11-
- 多种运行时能力(stream、memory、blas、通信等)
12-
- 在不侵入上层逻辑的前提下进行后端扩展与替换
13-
14-
在实际工程中,如果设备相关逻辑散落在框架各个模块,会导致:
15-
16-
- `#ifdef USE_CUDA/USE_MUSA/...` 泛滥
17-
- 新硬件接入需要修改大量框架核心代码
18-
- 设备切换与资源管理缺乏统一语义
19-
20-
### 1.2 设计目标
21-
22-
InfiniTrain 的 device 注册机制设计目标是:
23-
24-
1. 统一抽象:将所有与设备相关的运行时行为抽象到一个统一接口中。
25-
2. 后端可插拔:新设备后端可通过注册机制接入,无需修改框架核心逻辑。
26-
3. RAII 语义清晰:设备切换、资源恢复具备严格的作用域。
27-
4. 最小上层侵入:上层模块(Tensor/Autograd/Module)只感知 DeviceGuard/DeviceGuardImpl,不感知具体后端实现。
28-
29-
## 2. 核心组件
4+
## 1. 核心组件
305

316
InfiniTrain 的 device 机制由三类核心组件构成:
327

@@ -54,7 +29,7 @@ InfiniTrain 的 device 机制由三类核心组件构成:
5429
| DeviceGuard | 管理 “当前在哪个 device 上” 的上下文语义(RAII),语义与 device index 绑定;负责 device 的保存/切换/恢复,并将具体 runtime 操作转发给对应的 DeviceGuardImpl。 |
5530
| DeviceGuardImpl | 管理 “在该类 device 上如何执行 runtime 操作”,语义与 device type 绑定;对外提供 设备管理查询、stream、blas、同步、内存 等运行时能力接口。 |
5631

57-
### 2.1 DeviceGuardImpl:运行时能力抽象(对外暴露)
32+
### 1.1 DeviceGuardImpl:运行时能力抽象(对外暴露)
5833

5934
DeviceGuardImpl 是 InfiniTrain 中 device runtime 能力的统一抽象接口,并且是框架内部对外暴露的能力接口,封装了所有与 device 相关的行为(待补充 event 相关接口):
6035

@@ -112,7 +87,7 @@ virtual void ResetMemPoolHighWatermarks(Device device) const;
11287
virtual std::pair<size_t, size_t> GetMemPoolPeakMB(Device device) const;
11388
```
11489
115-
### 2.2 DeviceGuard:RAII 前端接口
90+
### 1.2 DeviceGuard:RAII 前端接口
11691
11792
DeviceGuard 是设备上下文的 RAII 管理器,其职责严格限定为:
11893
@@ -133,7 +108,7 @@ DeviceGuard 不直接提供任何运行时能力接口。
133108
// 离开作用域后,自动恢复进入前的 device
134109
```
135110

136-
### 2.3 DeviceGuardImplRegistry:全局注册表
111+
### 1.3 DeviceGuardImplRegistry:全局注册表
137112

138113
`DeviceGuardImplRegistry`是 InfiniTrain 中用于管理 device runtime 后端实现的全局注册表,采用 singleton 模式,生命周期覆盖整个进程。
139114

@@ -143,9 +118,9 @@ DeviceGuard 不直接提供任何运行时能力接口。
143118
std::unordered_map<Device::DeviceType, std::unique_ptr<DeviceGuardImpl>> impls_;
144119
```
145120

146-
## 3. Runtime Capability 获取与使用范式
121+
## 2. Runtime Capability 获取与使用范式
147122

148-
### 3.1 获取入口
123+
### 2.1 获取入口
149124

150125
```C++
151126
DeviceGuardImpl* GetDeviceGuardImpl(Device::DeviceType type);
@@ -154,7 +129,7 @@ DeviceGuardImpl* GetDeviceGuardImpl(Device::DeviceType type);
154129
- 返回指定`DeviceType`的 DeviceGuardImpl
155130
- 若未注册对应 backend,直接报错
156131
157-
### 3.2 推荐使用模式(标准范式)
132+
### 2.2 推荐使用模式(标准范式)
158133
159134
```C++
160135
auto device = tensor->GetDevice();
@@ -184,9 +159,9 @@ std::vector<float> buffer(num_elements);
184159
} // <-- DeviceGuard 在此处析构,device 上下文被恢复
185160
```
186161

187-
## 4. Backend 注册机制(静态注册)
162+
## 3. Backend 注册机制(静态注册)
188163

189-
### 4.1 注册宏
164+
### 3.1 注册宏
190165

191166
```C++
192167
#define INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(device_type, class_impl) \
@@ -198,7 +173,7 @@ std::vector<float> buffer(num_elements);
198173
199174
采用静态变量 + lambda 在程序启动阶段完成注册。
200175
201-
### 4.2 使用示例(CUDA Backend)
176+
### 3.2 使用示例(CUDA Backend)
202177
203178
```C++
204179
class CudaGuardImpl : public DeviceGuardImpl {

docs/dtype_registry_design.md

Lines changed: 9 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,7 @@
11
# Low-Precision DType Abstraction & Backend Registration Design
2-
统一低精度类型抽象与后端显式注册 pr:https://github.com/InfiniTensor/InfiniTrain/pull/114
2+
低精度 dtype 抽象是 InfiniTrain 面向多后端的统一类型语义与显式注册基础设施。
33

4-
## 1. 背景与动机
5-
6-
InfiniTrain 在引入 BF16 / FP16 之前,框架层并没有低精度类型的统一抽象,所有关于 16-bit 浮点的语义都直接绑定在 CUDA 原生类型 `__half` / `__nv_bfloat16` 上。这
7-
导致几个问题:
8-
9-
1. **框架代码被 `#ifdef USE_CUDA` 污染。**
10-
`infini_train/include/datatype.h``infini_train/src/nn/init.cc` 等通用模块都需要
11-
写出 `#ifdef USE_CUDA … #else …` 来在「有 CUDA」和「没有 CUDA」两个版本之间
12-
切换 16-bit 类型映射;非 CUDA 路径只能退化成 `uint16_t`,而 `uint16_t` 又会与
13-
`kUINT16` 的反向映射产生歧义。
14-
2. **`TypeMap<DType>` 是「全后端共享」的单点表。**
15-
`TypeMap` 把所有标量类型直接映射到 C++ 类型。CPU 与 CUDA 共享同一个表,
16-
意味着不可能在不同后端把 `kFLOAT16` 映射到不同的本地标量;要扩展新硬件必须改框架头文件。
17-
3. **类型提升耦合具体后端类型。**
18-
旧的 `WidestType_t<T1, T2>` 在 C++ 模板层面做提升,需要每个调用点先 dispatch 出
19-
一对具体的标量类型(例如 `nv_bfloat16` + `float`),再交给元函数做选择。这把
20-
「类型提升」这一纯 dtype 级别的逻辑跟「后端具体标量」捆死了。
21-
4. **静默 fallback 容易掩盖错误。**
22-
一旦某个后端忘记注册 BF16/FP16,旧实现会沉默地走到 `uint16_t` 路径,得到一个
23-
语义错误的内核,而不是显式报错。
24-
25-
本工作的目标是:
26-
27-
> **把 FP16/BF16 抽象成框架级类型**,让框架代码不再直接接触任何后端原生
28-
> 16-bit 类型;同时把后端 dtype → 本地标量的映射改成**显式注册**机制,未注册的类型在编译期就被拦截。
29-
30-
## 2. Design In One Diagram
4+
## 1. Design In One Diagram
315

326
```
337
framework code ──► FP16 / BF16 (datatype.h, 纯软件实现,提供基本转换操作)
@@ -46,10 +20,10 @@ kernel code ──► DispatchCpuFunc / DispatchCudaFunc / DispatchXxxFunc
4620

4721
要点:
4822

49-
- 框架层不提供任何「DataType → C++ 类型」映射路径;所有具体类型绑定均在后端通过 `BackendTypeMap<Dev, DType>` 完成。
50-
- `BackendTypeMap<Dev, DType>` 主模板**只声明不定义**只有后端显式特化并完成注册的 dtype 组合才允许参与 kernel dispatch;未注册组合会在模板实例化阶段被 `static_assert` 于编译期拦截。
23+
- 框架层不提供任何「DataType → 后端 C++ 类型」映射路径;所有具体类型绑定均在后端通过 `BackendTypeMap<Dev, DType>` 完成。
24+
- `BackendTypeMap<Dev, DType>` 主模板**只声明不定义**只有后端显式特化并完成注册的组合才允许参与 kernel dispatch;未注册组合会在模板实例化阶段被 `static_assert` 于编译期拦截。
5125

52-
## 3. Core API
26+
## 2. Core API
5327

5428
| API | 位置 | 说明 |
5529
| --- | --- | --- |
@@ -59,55 +33,17 @@ kernel code ──► DispatchCpuFunc / DispatchCudaFunc / DispatchXxxFunc
5933
| `INFINI_REGISTER_STANDARD_BACKEND_TYPES(DEV)` | [core/backend_type_map.h](../infini_train/include/core/backend_type_map.h) | 一次性注册 10 个非低精度 dtype(`kUINT8…kFLOAT64`)到对应 C++ 标量。 |
6034
| `DispatchCpuFunc / DispatchCudaFunc<AllowedDTypes...>` | `src/core/runtime/{cpu,cuda}/{cpu,cuda}_dispatch.h` | 后端 dispatch 入口,底层转发到 `DispatchByTypeMap<TypeMap, AllowedDTypes...>`|
6135

62-
## 4. Scalar:框架层标量载体
63-
64-
`BackendTypeMap` 解决「DataType → 后端 C++ 类型」,但框架 API 还需要一种
65-
**DataType 无关** 的方式接收标量参数:目标 tensor 的 DataType 运行期才确定,API 不可能
66-
为每种数值类型都写重载,更不能把后端原生类型暴露给调用方。
67-
68-
为此引入 `Scalar`[scalar.h](../infini_train/include/scalar.h)):
69-
70-
- 固定存储:`double / int64_t / uint64_t` + `Kind` tag(`kBool / kDouble / kInt64 / kUInt64`)。
71-
- 隐式构造覆盖所有框架标量:整数按符号分入 `kInt64 / kUInt64`,全部浮点(含 `FP16 / BF16`)归一到 `kDouble``bool` 独立。
72-
- 唯一出口 `Scalar::to<T>()`,通过 `common::cpu::Cast<T>` 把存储值转换到 dispatch 选出的后端标量类型。
73-
74-
与其它抽象的边界:`BackendTypeMap` 管「DataType → 后端 C++ 类型」,`PromoteDataTypes`
75-
「DataType → DataType」,`Scalar` 管「数值 → 后端 C++ 类型」,三者正交;`Scalar` 本身不参与类型提升决策。
76-
77-
### 4.1 使用模式
78-
79-
`Tensor::Fill(Scalar)` 是这套抽象的第一个落地点。kernel 侧使用模式如下:
80-
81-
```cpp
82-
// kernels/cpu/fill.cc
83-
void Fill(std::shared_ptr<Tensor> tensor, Scalar scalar) {
84-
core::cpu::DispatchCpuFunc<INFINI_ALL_TYPES>(
85-
tensor->Dtype(),
86-
[=]<typename T>() {
87-
auto data = reinterpret_cast<T *>(tensor->DataPtr());
88-
const T v = scalar.to<T>(); // Scalar 在此完成「数值 → 后端 C++ 类型」映射
89-
std::fill(data, data + tensor->NumElements(), v);
90-
},
91-
"CPU Fill");
92-
}
93-
```
94-
95-
`DispatchCpuFunc` 经 `BackendTypeMap` 把 `DataType` 解析为 `T`;`Scalar::to<T>()`
96-
把用户传入值转换到该 `T`。
97-
98-
## 5. How To Add A New Backend
36+
## 3. How To Add A New Backend
9937

10038
按以下清单操作,**不需要**修改 `infini_train/include/` 下的任何框架头文件,也不需要 `#ifdef`
10139

10240
1. 在后端的 `*_dispatch.h` 里 include `core/backend_type_map.h``dtype_dispatch.h`
10341
2. 调用 `INFINI_REGISTER_STANDARD_BACKEND_TYPES(Device::DeviceType::kXxx)` 注册 10 个标准 dtype。
104-
3. 若硬件支持低精度,显式特化 `BackendTypeMap<kXxx, kFLOAT16>` / `BackendTypeMap<kXxx, kBFLOAT16>`
105-
指向后端本地 16-bit 标量类型;不支持则直接跳过,调用方一旦 dispatch 到未注册的 dtype 会在
106-
编译期触发 `static_assert`。
42+
3. 若硬件支持低精度,显式特化 `BackendTypeMap<kXxx, kFLOAT16>` / `BackendTypeMap<kXxx, kBFLOAT16>` 指向后端本地 16-bit 标量类型;不支持则直接跳过,调用方一旦 dispatch 到未注册的 dtype 会在编译期触发 `static_assert`
10743
4. 定义 `XxxTypeMap<DType>` 转发/继承到 `BackendTypeMap<kXxx, DType>`
10844
5. 提供 `DispatchXxxFunc` 入口,转发到 `DispatchByTypeMap<XxxTypeMap, AllowedDTypes...>`
10945

110-
### 最小示例
46+
### Example
11147

11248
```cpp
11349
// xxx_dispatch.h
@@ -134,7 +70,7 @@ auto DispatchXxxFunc(DataType dtype, Functor &&f, std::string_view ctx = "", Arg
13470
} // namespace infini_train::core::xxx
13571
```
13672
137-
## 6. Failure Modes
73+
## 4. Failure Modes
13874
13975
| 情形 | 表现 |
14076
| --- | --- |

infini_train/include/scalar.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,31 @@ struct Scalar {
3131
Scalar(FP16 v) : kind(Kind::kDouble), d(static_cast<float>(v)) {}
3232
Scalar(BF16 v) : kind(Kind::kDouble), d(static_cast<float>(v)) {}
3333

34+
// TODO(dcj): Scalar::to<T>() should remain a framework-level conversion API
35+
// and should not directly target backend-native types such as __nv_bfloat16
36+
// or __half.
37+
//
38+
// Today to<T>() delegates to common::cpu::Cast, which only has explicit
39+
// semantics for framework scalar types (e.g. FP16/BF16). When T is a
40+
// backend-native half type, it falls back to raw static_cast, which happens
41+
// to compile on CUDA (via implicit constructors) but is backend-dependent
42+
// and may fail on other platforms (e.g. MACA).
43+
//
44+
// More importantly, this creates inconsistent rounding paths:
45+
// - to<BF16>(): double -> float -> bf16
46+
// - to<__nv_bfloat16>(): double -> bf16
47+
// The two paths may yield different results due to double rounding.
48+
// See `test/dtype/test_scalar.cc` (`TestToHalfPrecisionConversions`) for
49+
// a similar example.
50+
//
51+
// Planned fix:
52+
// 1) keep Scalar::to<T>() restricted to framework/common scalar types
53+
// 2) introduce a standalone convert<To, From> utility for common
54+
// conversion semantics
55+
// 3) let kernel/backend code use a backend-specific scalar_cast<T>
56+
// helper for native types, routing half-precision conversions
57+
// through float to guarantee consistent two-step rounding on all
58+
// backends.
3459
template <typename T> T to() const {
3560
switch (kind) {
3661
case Kind::kBool:

0 commit comments

Comments
 (0)