Skip to content

Commit fda7ccb

Browse files
committed
docs: add device_guard_design & dtype_registry_design docs
1 parent 12a359a commit fda7ccb

File tree

2 files changed

+352
-0
lines changed

2 files changed

+352
-0
lines changed

docs/device_guard_design.md

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
# Device Guard Design
2+
device 注册初版基建 pr:https://github.com/InfiniTensor/InfiniTrain/pull/103
3+
4+
## 1. 设计背景与目标
5+
6+
### 1.1 背景
7+
8+
InfiniTrain 需要长期支持:
9+
10+
- 多种设备类型(CPU/CUDA/国产芯片)
11+
- 多种运行时能力(stream、memory、blas、通信等)
12+
- 在不侵入上层逻辑的前提下进行后端扩展与替换
13+
14+
在实际工程中,如果设备相关逻辑散落在框架各个模块,会导致:
15+
16+
- `#ifdef USE_CUDA/USE_MUSA/...` 泛滥
17+
- 新硬件接入需要修改大量框架核心代码
18+
- 设备切换与资源管理缺乏统一语义
19+
20+
### 1.2 设计目标
21+
22+
InfiniTrain 的 device 注册机制设计目标是:
23+
24+
1. 统一抽象:将所有与设备相关的运行时行为抽象到一个统一接口中。
25+
2. 后端可插拔:新设备后端可通过注册机制接入,无需修改框架核心逻辑。
26+
3. RAII 语义清晰:设备切换、资源恢复具备严格的作用域。
27+
4. 最小上层侵入:上层模块(Tensor/Autograd/Module)只感知 DeviceGuard/DeviceGuardImpl,不感知具体后端实现。
28+
29+
## 2. 核心组件
30+
31+
InfiniTrain 的 device 机制由三类核心组件构成:
32+
33+
```C++
34+
+-------------------+
35+
| DeviceGuard | ← 对外 RAII 接口(public)
36+
+-------------------+
37+
|
38+
v
39+
+-------------------+
40+
| DeviceGuardImpl | ← 后端抽象接口(virtual
41+
+-------------------+
42+
^
43+
|
44+
+-------------------+
45+
| DeviceGuardImpl |
46+
| Registry | ← 全局注册表(singleton)
47+
+-------------------+
48+
```
49+
50+
其中 DeviceGuard 与 DeviceGuardImpl 的关系是:
51+
52+
| 组件 | 职责 |
53+
| --------------- | ------------------------------------------------------------ |
54+
| DeviceGuard | 管理 “当前在哪个 device 上” 的上下文语义(RAII),语义与 device index 绑定;负责 device 的保存/切换/恢复,并将具体 runtime 操作转发给对应的 DeviceGuardImpl。 |
55+
| DeviceGuardImpl | 管理 “在该类 device 上如何执行 runtime 操作”,语义与 device type 绑定;对外提供 设备管理查询、stream、blas、同步、内存 等运行时能力接口。 |
56+
57+
### 2.1 DeviceGuardImpl:运行时能力抽象(对外暴露)
58+
59+
DeviceGuardImpl 是 InfiniTrain 中 device runtime 能力的统一抽象接口,并且是框架内部对外暴露的能力接口,封装了所有与 device 相关的行为(待补充 event 相关接口):
60+
61+
```C++
62+
// ----------------------------------------------------------------------
63+
// Device management
64+
// ----------------------------------------------------------------------
65+
66+
virtual Device GetDevice() const = 0;
67+
68+
virtual void SetDevice(Device device) const;
69+
70+
virtual int8_t DeviceCount() const;
71+
72+
virtual Device::DeviceType Type() const = 0;
73+
74+
// ----------------------------------------------------------------------
75+
// Stream management
76+
// ----------------------------------------------------------------------
77+
78+
virtual Stream *GetStream(Device) const;
79+
80+
// ----------------------------------------------------------------------
81+
// Synchronization
82+
// ----------------------------------------------------------------------
83+
84+
virtual void SynchronizeDevice(Device) const;
85+
86+
virtual void SynchronizeStream(Stream *) const;
87+
88+
// ----------------------------------------------------------------------
89+
// BLAS handle
90+
// ----------------------------------------------------------------------
91+
92+
virtual BlasHandle *GetBlasHandle(Device) const;
93+
94+
// ----------------------------------------------------------------------
95+
// Memory operations
96+
// ----------------------------------------------------------------------
97+
98+
virtual void Malloc(void **dev_ptr, size_t size) = 0;
99+
100+
virtual void MallocAsync(void **dev_ptr, size_t size, Stream *stream);
101+
102+
virtual void Free(void *dev_ptr) = 0;
103+
104+
virtual void FreeAsync(void *dev_ptr, Stream *stream);
105+
106+
virtual void Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) = 0;
107+
108+
virtual void MemcpyAsync(void *dst, const void *src, size_t count, MemcpyKind kind, Stream *stream);
109+
110+
virtual void ResetMemPoolHighWatermarks(Device device) const;
111+
112+
virtual std::pair<size_t, size_t> GetMemPoolPeakMB(Device device) const;
113+
```
114+
115+
### 2.2 DeviceGuard:RAII 前端接口
116+
117+
DeviceGuard 是设备上下文的 RAII 管理器,其职责严格限定为:
118+
119+
- 保存当前 device
120+
- 切换到目标 device
121+
- 在作用域结束时恢复原 device
122+
123+
DeviceGuard 不直接提供任何运行时能力接口。
124+
125+
使用示例:
126+
127+
```C++
128+
{
129+
DeviceGuard guard(Device(DeviceType::kCUDA, 1));
130+
// 当前线程的 device 上下文被切换到 CUDA:1
131+
// 所有 runtime 操作将发生在 CUDA:1
132+
}
133+
// 离开作用域后,自动恢复进入前的 device
134+
```
135+
136+
### 2.3 DeviceGuardImplRegistry:全局注册表
137+
138+
`DeviceGuardImplRegistry`是 InfiniTrain 中用于管理 device runtime 后端实现的全局注册表,采用 singleton 模式,生命周期覆盖整个进程。
139+
140+
其核心职责是维护`DeviceType -> DeviceGuardImpl`的一对一映射关系:
141+
142+
```C++
143+
std::unordered_map<Device::DeviceType, std::unique_ptr<DeviceGuardImpl>> impls_;
144+
```
145+
146+
## 3. Runtime Capability 获取与使用范式
147+
148+
### 3.1 获取入口
149+
150+
```C++
151+
DeviceGuardImpl* GetDeviceGuardImpl(Device::DeviceType type);
152+
```
153+
154+
- 返回指定`DeviceType`的 DeviceGuardImpl
155+
- 若未注册对应 backend,直接报错
156+
157+
### 3.2 推荐使用模式(标准范式)
158+
159+
```C++
160+
auto device = tensor->GetDevice();
161+
const int64_t num_elements = tensor->NumElements();
162+
std::vector<float> buffer(num_elements);
163+
164+
{
165+
// 1. 切换 device 上下文(RAII scope)
166+
core::DeviceGuard guard(device);
167+
168+
// 2. 获取 runtime capability
169+
auto* impl = core::GetDeviceGuardImpl(device.type());
170+
171+
// 3. 执行 runtime 操作
172+
const core::MemcpyKind kind =
173+
device.type() == Device::DeviceType::kCPU
174+
? core::MemcpyKind::kD2D // CPU: host-host memcpy
175+
: core::MemcpyKind::kH2D; // Device: host-device copy
176+
177+
impl->MemcpyAsync(
178+
tensor->DataPtr(), // dst
179+
buffer.data(), // src
180+
num_elements * sizeof(float), // count
181+
kind, // kind(说明:在 CPU backend 中,kD2D 对应普通 memcpy)
182+
impl->GetStream(device) // stream
183+
);
184+
} // <-- DeviceGuard 在此处析构,device 上下文被恢复
185+
```
186+
187+
## 4. Backend 注册机制(静态注册)
188+
189+
### 4.1 注册宏
190+
191+
```C++
192+
#define INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(device_type, class_impl) \
193+
static const bool __infini_train_device_guard_registered##__COUNTER__ = []() { \
194+
infini_train::core::DeviceGuardImplRegistry::Instance().Register(device_type, std::make_unique<class_impl>()); \
195+
return true; \
196+
}();
197+
```
198+
199+
采用静态变量 + lambda 在程序启动阶段完成注册。
200+
201+
### 4.2 使用示例(CUDA Backend)
202+
203+
```C++
204+
class CudaGuardImpl : public DeviceGuardImpl {
205+
...
206+
};
207+
208+
INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(Device::DeviceType::kCUDA, CudaGuardImpl)
209+
```
210+

docs/dtype_registry_design.md

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# Low-Precision DType Abstraction & Backend Registration Design
2+
统一低精度类型抽象与后端显式注册 pr:https://github.com/InfiniTensor/InfiniTrain/pull/114
3+
4+
## 1. 背景与动机
5+
6+
InfiniTrain 在引入 BF16 / FP16 之前,框架层并没有低精度类型的统一抽象,所有关于 16-bit 浮点的语义都直接绑定在 CUDA 原生类型 `__half` / `__nv_bfloat16` 上。这
7+
导致几个问题:
8+
9+
1. **框架代码被 `#ifdef USE_CUDA` 污染。**
10+
`infini_train/include/datatype.h``infini_train/src/nn/init.cc` 等通用模块都需要
11+
写出 `#ifdef USE_CUDA … #else …` 来在「有 CUDA」和「没有 CUDA」两个版本之间
12+
切换 16-bit 类型映射;非 CUDA 路径只能退化成 `uint16_t`,而 `uint16_t` 又会与
13+
`kUINT16` 的反向映射产生歧义。
14+
2. **`TypeMap<DType>` 是「全后端共享」的单点表。**
15+
`TypeMap` 把所有标量类型直接映射到 C++ 类型。CPU 与 CUDA 共享同一个表,
16+
意味着不可能在不同后端把 `kFLOAT16` 映射到不同的本地标量;要扩展新硬件必须改框架头文件。
17+
3. **类型提升耦合具体后端类型。**
18+
旧的 `WidestType_t<T1, T2>` 在 C++ 模板层面做提升,需要每个调用点先 dispatch 出
19+
一对具体的标量类型(例如 `nv_bfloat16` + `float`),再交给元函数做选择。这把
20+
「类型提升」这一纯 dtype 级别的逻辑跟「后端具体标量」捆死了。
21+
4. **静默 fallback 容易掩盖错误。**
22+
一旦某个后端忘记注册 BF16/FP16,旧实现会沉默地走到 `uint16_t` 路径,得到一个
23+
语义错误的内核,而不是显式报错。
24+
25+
本工作的目标是:
26+
27+
> **把 FP16/BF16 抽象成框架级类型**,让框架代码不再直接接触任何后端原生
28+
> 16-bit 类型;同时把后端 dtype → 本地标量的映射改成**显式注册**机制,未注册的类型在编译期就被拦截。
29+
30+
## 2. Design In One Diagram
31+
32+
```
33+
framework code ──► FP16 / BF16 (datatype.h, 纯软件实现,提供基本转换操作)
34+
PromoteDataTypes(DataType, DataType)
35+
36+
kernel code ──► DispatchCpuFunc / DispatchCudaFunc / DispatchXxxFunc
37+
38+
39+
BackendTypeMap<Dev, DType> (主模板只声明不定义)
40+
41+
├─ kFLOAT16 / kBFLOAT16 → 后端在 *_dispatch.h 显式特化后注册
42+
│ └── CUDA: __half / __nv_bfloat16
43+
│ └── CPU : FP16 / BF16
44+
└─ 其它 10 个标量 dtype 使用默认注册 → INFINI_REGISTER_STANDARD_BACKEND_TYPES(DEV)
45+
```
46+
47+
要点:
48+
49+
- 框架层不提供任何「DataType → C++ 类型」映射路径;所有具体类型绑定均在后端通过 `BackendTypeMap<Dev, DType>` 完成。
50+
- `BackendTypeMap<Dev, DType>` 主模板**只声明不定义**,只有后端显式特化并完成注册的 dtype 组合才允许参与 kernel dispatch;未注册组合会在模板实例化阶段被 `static_assert` 于编译期拦截。
51+
52+
## 3. Core API
53+
54+
| API | 位置 | 说明 |
55+
| --- | --- | --- |
56+
| `struct FP16 / BF16` | [datatype.h](../infini_train/include/datatype.h) | 16-bit 软件包装(IEEE-754 half / truncated bf16),承担框架身份、存储布局、fallback 转换;不承担后端高性能算术语义。 |
57+
| `PromoteDataTypes(DataType, DataType)` | [datatype.h](../infini_train/include/datatype.h) | 纯枚举到枚举的类型提升。规则:FP16+BF16→FP32;浮点优先于整数;同类按字节宽取大。 |
58+
| `BackendTypeMap<Dev, DType>` | [core/backend_type_map.h](../infini_train/include/core/backend_type_map.h) | 主模板**只声明不定义**;后端通过显式特化提供 `::type`|
59+
| `INFINI_REGISTER_STANDARD_BACKEND_TYPES(DEV)` | [core/backend_type_map.h](../infini_train/include/core/backend_type_map.h) | 一次性注册 10 个非低精度 dtype(`kUINT8…kFLOAT64`)到对应 C++ 标量。 |
60+
| `DispatchCpuFunc / DispatchCudaFunc<AllowedDTypes...>` | `src/core/runtime/{cpu,cuda}/{cpu,cuda}_dispatch.h` | 后端 dispatch 入口,底层转发到 `DispatchByTypeMap<TypeMap, AllowedDTypes...>`|
61+
62+
## 4. Scalar:框架层标量载体
63+
64+
`BackendTypeMap` 解决「DataType → 后端 C++ 类型」,但框架 API 还需要一种
65+
**DataType 无关** 的方式接收标量参数:目标 tensor 的 DataType 运行期才确定,API 不可能
66+
为每种数值类型都写重载,更不能把后端原生类型暴露给调用方。
67+
68+
为此引入 `Scalar`[scalar.h](../infini_train/include/scalar.h)):
69+
70+
- 固定存储:`double / int64_t / uint64_t` + `Kind` tag(`kBool / kDouble / kInt64 / kUInt64`)。
71+
- 隐式构造覆盖所有框架标量:整数按符号分入 `kInt64 / kUInt64`,全部浮点(含 `FP16 / BF16`)归一到 `kDouble``bool` 独立。
72+
- 唯一出口 `Scalar::to<T>()`,通过 `common::cpu::Cast<T>` 把存储值转换到 dispatch 选出的后端标量类型。
73+
74+
与其它抽象的边界:`BackendTypeMap` 管「DataType → 后端 C++ 类型」,`PromoteDataTypes`
75+
「DataType → DataType」,`Scalar` 管「数值 → 后端 C++ 类型」,三者正交;`Scalar` 本身不参与类型提升决策。
76+
77+
### 4.1 使用模式
78+
79+
`Tensor::Fill(Scalar)` 是这套抽象的第一个落地点。kernel 侧使用模式如下:
80+
81+
```cpp
82+
// kernels/cpu/fill.cc
83+
void Fill(std::shared_ptr<Tensor> tensor, Scalar scalar) {
84+
core::cpu::DispatchCpuFunc<INFINI_ALL_TYPES>(
85+
tensor->Dtype(),
86+
[=]<typename T>() {
87+
auto data = reinterpret_cast<T *>(tensor->DataPtr());
88+
const T v = scalar.to<T>(); // Scalar 在此完成「数值 → 后端 C++ 类型」映射
89+
std::fill(data, data + tensor->NumElements(), v);
90+
},
91+
"CPU Fill");
92+
}
93+
```
94+
95+
`DispatchCpuFunc` 经 `BackendTypeMap` 把 `DataType` 解析为 `T`;`Scalar::to<T>()`
96+
把用户传入值转换到该 `T`。
97+
98+
## 5. How To Add A New Backend
99+
100+
按以下清单操作,**不需要**修改 `infini_train/include/` 下的任何框架头文件,也不需要 `#ifdef`:
101+
102+
1. 在后端的 `*_dispatch.h` 里 include `core/backend_type_map.h` 与 `dtype_dispatch.h`。
103+
2. 调用 `INFINI_REGISTER_STANDARD_BACKEND_TYPES(Device::DeviceType::kXxx)` 注册 10 个标准 dtype。
104+
3. 若硬件支持低精度,显式特化 `BackendTypeMap<kXxx, kFLOAT16>` / `BackendTypeMap<kXxx, kBFLOAT16>`
105+
指向后端本地 16-bit 标量类型;不支持则直接跳过,调用方一旦 dispatch 到未注册的 dtype 会在
106+
编译期触发 `static_assert`。
107+
4. 定义 `XxxTypeMap<DType>` 转发/继承到 `BackendTypeMap<kXxx, DType>`。
108+
5. 提供 `DispatchXxxFunc` 入口,转发到 `DispatchByTypeMap<XxxTypeMap, AllowedDTypes...>`。
109+
110+
### 最小示例
111+
112+
```cpp
113+
// xxx_dispatch.h
114+
#include "infini_train/include/core/backend_type_map.h"
115+
#include "infini_train/include/dtype_dispatch.h"
116+
117+
namespace infini_train::core {
118+
// 若硬件支持低精度,显式特化 FP16/BF16
119+
template <> struct BackendTypeMap<Device::DeviceType::kXxx, DataType::kFLOAT16> { using type = xxx_half; };
120+
template <> struct BackendTypeMap<Device::DeviceType::kXxx, DataType::kBFLOAT16> { using type = xxx_bfloat; };
121+
} // namespace infini_train::core
122+
123+
INFINI_REGISTER_STANDARD_BACKEND_TYPES(infini_train::Device::DeviceType::kXxx)
124+
125+
namespace infini_train::core::xxx {
126+
template <DataType DType>
127+
struct XxxTypeMap : BackendTypeMap<Device::DeviceType::kXxx, DType> {};
128+
129+
template <DataType... AllowedDTypes, typename Functor, typename... Args>
130+
auto DispatchXxxFunc(DataType dtype, Functor &&f, std::string_view ctx = "", Args &&...a) {
131+
return DispatchByTypeMap<XxxTypeMap, AllowedDTypes...>(
132+
dtype, std::forward<Functor>(f), ctx, std::forward<Args>(a)...);
133+
}
134+
} // namespace infini_train::core::xxx
135+
```
136+
137+
## 6. Failure Modes
138+
139+
| 情形 | 表现 |
140+
| --- | --- |
141+
| 后端未注册某个 dtype(`BackendTypeMap<Dev, DType>` 无特化),但被 dispatch 命中 | 编译期 `static_assert` 触发,错误信息指向 `BackendTypeMap` 的显式注册要求。 |
142+
| dispatch 的 dtype 不在调用点 `AllowedDTypes...` 白名单内 | 运行期 `LOG_UNSUPPORTED_DTYPE` 报错。 |

0 commit comments

Comments
 (0)