Skip to content

Commit fc361c7

Browse files
committed
feat: add bool datatype
1 parent bfb5924 commit fc361c7

14 files changed

Lines changed: 51 additions & 38 deletions

File tree

infini_train/include/core/backend_type_map.h

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,9 @@ template <Device::DeviceType Dev, DataType DType> struct BackendTypeMap;
4848
// -----------------------------------------------------------------------------
4949
#define INFINI_REGISTER_STANDARD_BACKEND_TYPES(DEV) \
5050
namespace infini_train::core { \
51+
template <> struct BackendTypeMap<DEV, DataType::kBOOL> { \
52+
using type = bool; \
53+
}; \
5154
template <> struct BackendTypeMap<DEV, DataType::kUINT8> { \
5255
using type = uint8_t; \
5356
}; \

infini_train/include/datatype.h

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -84,6 +84,7 @@ struct alignas(2) BF16 {
8484
// DataType enum and metadata tables
8585
// -----------------------------------------------------------------------------
8686
enum class DataType : int8_t {
87+
kBOOL,
8788
kUINT8,
8889
kINT8,
8990
kUINT16,
@@ -99,12 +100,14 @@ enum class DataType : int8_t {
99100
};
100101

101102
inline const std::unordered_map<DataType, size_t> kDataTypeToSize = {
103+
{DataType::kBOOL, 1},
102104
{DataType::kUINT8, 1}, {DataType::kINT8, 1}, {DataType::kUINT16, 2}, {DataType::kINT16, 2},
103105
{DataType::kUINT32, 4}, {DataType::kINT32, 4}, {DataType::kUINT64, 8}, {DataType::kINT64, 8},
104106
{DataType::kBFLOAT16, 2}, {DataType::kFLOAT16, 2}, {DataType::kFLOAT32, 4}, {DataType::kFLOAT64, 8},
105107
};
106108

107109
inline const std::unordered_map<DataType, std::string> kDataTypeToDesc = {
110+
{DataType::kBOOL, "bool"},
108111
{DataType::kUINT8, "uint8"}, {DataType::kINT8, "int8"}, {DataType::kUINT16, "uint16"},
109112
{DataType::kINT16, "int16"}, {DataType::kUINT32, "uint32"}, {DataType::kINT32, "int32"},
110113
{DataType::kUINT64, "uint64"}, {DataType::kINT64, "int64"}, {DataType::kBFLOAT16, "bf16"},

infini_train/include/dtype_dispatch.h

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -180,10 +180,11 @@ namespace infini_train {
180180
#define INFINI_FLOATING_TYPES DataType::kFLOAT32, DataType::kFLOAT64
181181
#define INFINI_REDUCED_FLOATING_TYPES DataType::kFLOAT16, DataType::kBFLOAT16
182182
#define INFINI_ALL_FLOATING_TYPES INFINI_FLOATING_TYPES, INFINI_REDUCED_FLOATING_TYPES
183+
#define INFINI_LOGICAL_TYPES DataType::kBOOL
183184
#define INFINI_SIGNED_INTEGRAL_TYPES DataType::kINT8, DataType::kINT16, DataType::kINT32, DataType::kINT64
184185
#define INFINI_UNSIGNED_INTEGRAL_TYPES DataType::kUINT8, DataType::kUINT16, DataType::kUINT32, DataType::kUINT64
185186
#define INFINI_ALL_INTEGRAL_TYPES INFINI_SIGNED_INTEGRAL_TYPES, INFINI_UNSIGNED_INTEGRAL_TYPES
186-
#define INFINI_ALL_TYPES INFINI_ALL_FLOATING_TYPES, INFINI_ALL_INTEGRAL_TYPES
187+
#define INFINI_ALL_NUMERIC_TYPES INFINI_ALL_FLOATING_TYPES, INFINI_ALL_INTEGRAL_TYPES
187188
#define INFINI_8_BIT_TYPES DataType::kINT8, DataType::kUINT8
188189
#define INFINI_16_BIT_TYPES DataType::kINT16, DataType::kUINT16, DataType::kFLOAT16, DataType::kBFLOAT16
189190
#define INFINI_32_BIT_TYPES DataType::kINT32, DataType::kUINT32, DataType::kFLOAT32
@@ -242,6 +243,7 @@ auto DispatchByTypeMap(DataType dtype, Functor &&func, std::string_view context_
242243
} \
243244
}
244245

246+
CASE_FOR_TYPE(DataType::kBOOL)
245247
CASE_FOR_TYPE(DataType::kUINT8)
246248
CASE_FOR_TYPE(DataType::kINT8)
247249
CASE_FOR_TYPE(DataType::kUINT16)
@@ -290,6 +292,7 @@ struct TypeMapDispatcher {
290292
break; \
291293
}
292294

295+
CASE_FOR_TYPE(DataType::kBOOL)
293296
CASE_FOR_TYPE(DataType::kUINT8)
294297
CASE_FOR_TYPE(DataType::kINT8)
295298
CASE_FOR_TYPE(DataType::kUINT16)

infini_train/src/kernels/cpu/cast.cc

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,8 @@ std::shared_ptr<Tensor> Cast(std::shared_ptr<Tensor> input, DataType dtype) {
1313
auto device = input->GetDevice();
1414
auto dst_tensor = std::make_shared<Tensor>(input->Dims(), dtype, device);
1515

16-
core::cpu::DispatchCpuFunc<DataTypeList<INFINI_ALL_TYPES>, DataTypeList<INFINI_ALL_TYPES>>(
16+
core::cpu::DispatchCpuFunc<DataTypeList<INFINI_ALL_NUMERIC_TYPES, INFINI_LOGICAL_TYPES>,
17+
DataTypeList<INFINI_ALL_NUMERIC_TYPES, INFINI_LOGICAL_TYPES>>(
1718
{dtype, input->Dtype()},
1819
[=]<typename Tdst, typename Tsrc>() {
1920
auto dst = static_cast<Tdst *>(dst_tensor->DataPtr());

infini_train/src/kernels/cpu/fill.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88

99
namespace infini_train::kernels::cpu {
1010
void Fill(std::shared_ptr<Tensor> tensor, Scalar scalar) {
11-
core::cpu::DispatchCpuFunc<INFINI_ALL_TYPES>(
11+
core::cpu::DispatchCpuFunc<INFINI_ALL_NUMERIC_TYPES>(
1212
tensor->Dtype(),
1313
[=]<typename T>() {
1414
auto data = reinterpret_cast<T *>(tensor->DataPtr());

infini_train/src/kernels/cuda/cast.cu

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,8 @@ std::shared_ptr<Tensor> Cast(std::shared_ptr<Tensor> input, DataType dtype) {
3434
dim3 grid_dims(CEIL_DIV(num_elements, block_dims.x));
3535
const size_t step = grid_dims.x * block_dims.x;
3636

37-
core::cuda::DispatchCudaFunc<DataTypeList<INFINI_ALL_TYPES>, DataTypeList<INFINI_ALL_TYPES>>(
37+
core::cuda::DispatchCudaFunc<DataTypeList<INFINI_ALL_NUMERIC_TYPES, INFINI_LOGICAL_TYPES>,
38+
DataTypeList<INFINI_ALL_NUMERIC_TYPES, INFINI_LOGICAL_TYPES>>(
3839
{dtype, input->Dtype()},
3940
[=]<typename Tdst, typename Tsrc>() {
4041
auto dst = static_cast<Tdst *>(dst_tensor->DataPtr());

infini_train/src/kernels/cuda/concat.cu

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -103,7 +103,7 @@ std::shared_ptr<Tensor> ConcatForward(const std::vector<std::shared_ptr<Tensor>>
103103
int threads_per_block = 256;
104104
int num_blocks = static_cast<int>((total + threads_per_block - 1) / threads_per_block);
105105

106-
core::cuda::DispatchCudaFunc<INFINI_ALL_TYPES>(
106+
core::cuda::DispatchCudaFunc<INFINI_ALL_NUMERIC_TYPES>(
107107
dtype,
108108
[=, &inputs, &host_offsets]<typename T>() {
109109
std::vector<const T *> host_input_ptrs;
@@ -208,7 +208,7 @@ std::vector<std::shared_ptr<Tensor>> ConcatBackward(const std::shared_ptr<Tensor
208208
int threads_per_block = 256;
209209
int num_blocks = static_cast<int>((total + threads_per_block - 1) / threads_per_block);
210210

211-
core::cuda::DispatchCudaFunc<INFINI_ALL_TYPES>(
211+
core::cuda::DispatchCudaFunc<INFINI_ALL_NUMERIC_TYPES>(
212212
dtype,
213213
[=, &grads, &host_offsets]<typename T>() {
214214
std::vector<T *> host_ptrs;

infini_train/src/kernels/cuda/elementwise.cu

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1018,7 +1018,7 @@ std::shared_ptr<Tensor> EqualsForward(const std::shared_ptr<Tensor> &a, const st
10181018
DISPATCH(a->Dtype(),
10191019
return BinaryForward(a, b,
10201020
[] __device__(auto x, auto y) { return (x == y) ? decltype(x){1} : decltype(x){0}; });
1021-
, INFINI_ALL_TYPES)
1021+
, INFINI_ALL_NUMERIC_TYPES)
10221022
}
10231023

10241024
std::shared_ptr<Tensor> EqualsScalarForward(const std::shared_ptr<Tensor> &a, float scalar) {
@@ -1033,7 +1033,7 @@ std::shared_ptr<Tensor> EqualsScalarForward(const std::shared_ptr<Tensor> &a, fl
10331033
std::shared_ptr<Tensor> LtForward(const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
10341034
DISPATCH(a->Dtype(), return BinaryForward(
10351035
a, b, [] __device__(auto x, auto y) { return x < y ? decltype(x){1} : decltype(x){0}; });
1036-
, INFINI_ALL_TYPES)
1036+
, INFINI_ALL_NUMERIC_TYPES)
10371037
}
10381038

10391039
std::shared_ptr<Tensor> LtScalarForward(const std::shared_ptr<Tensor> &a, float scalar) {
@@ -1042,14 +1042,14 @@ std::shared_ptr<Tensor> LtScalarForward(const std::shared_ptr<Tensor> &a, float
10421042
return (x < static_cast<decltype(x)>(scalar)) ? decltype(x){1}
10431043
: decltype(x){0};
10441044
});
1045-
, INFINI_ALL_TYPES)
1045+
, INFINI_ALL_NUMERIC_TYPES)
10461046
}
10471047

10481048
std::shared_ptr<Tensor> LeForward(const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
10491049
DISPATCH(a->Dtype(),
10501050
return BinaryForward(a, b,
10511051
[] __device__(auto x, auto y) { return (x <= y) ? decltype(x){1} : decltype(x){0}; });
1052-
, INFINI_ALL_TYPES)
1052+
, INFINI_ALL_NUMERIC_TYPES)
10531053
}
10541054

10551055
std::shared_ptr<Tensor> LeScalarForward(const std::shared_ptr<Tensor> &a, float scalar) {
@@ -1058,13 +1058,13 @@ std::shared_ptr<Tensor> LeScalarForward(const std::shared_ptr<Tensor> &a, float
10581058
return (x <= static_cast<decltype(x)>(scalar)) ? decltype(x){1}
10591059
: decltype(x){0};
10601060
});
1061-
, INFINI_ALL_TYPES)
1061+
, INFINI_ALL_NUMERIC_TYPES)
10621062
}
10631063

10641064
std::shared_ptr<Tensor> GtForward(const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
10651065
DISPATCH(a->Dtype(), return BinaryForward(
10661066
a, b, [] __device__(auto x, auto y) { return x > y ? decltype(x){1} : decltype(x){0}; });
1067-
, INFINI_ALL_TYPES)
1067+
, INFINI_ALL_NUMERIC_TYPES)
10681068
}
10691069

10701070
std::shared_ptr<Tensor> GtScalarForward(const std::shared_ptr<Tensor> &a, float scalar) {
@@ -1073,14 +1073,14 @@ std::shared_ptr<Tensor> GtScalarForward(const std::shared_ptr<Tensor> &a, float
10731073
return (x > static_cast<decltype(x)>(scalar)) ? decltype(x){1}
10741074
: decltype(x){0};
10751075
});
1076-
, INFINI_ALL_TYPES)
1076+
, INFINI_ALL_NUMERIC_TYPES)
10771077
}
10781078

10791079
std::shared_ptr<Tensor> GeForward(const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
10801080
DISPATCH(a->Dtype(),
10811081
return BinaryForward(a, b,
10821082
[] __device__(auto x, auto y) { return (x >= y) ? decltype(x){1} : decltype(x){0}; });
1083-
, INFINI_ALL_TYPES)
1083+
, INFINI_ALL_NUMERIC_TYPES)
10841084
}
10851085

10861086
std::shared_ptr<Tensor> GeScalarForward(const std::shared_ptr<Tensor> &a, float scalar) {
@@ -1089,7 +1089,7 @@ std::shared_ptr<Tensor> GeScalarForward(const std::shared_ptr<Tensor> &a, float
10891089
return (x >= static_cast<decltype(x)>(scalar)) ? decltype(x){1}
10901090
: decltype(x){0};
10911091
});
1092-
, INFINI_ALL_TYPES)
1092+
, INFINI_ALL_NUMERIC_TYPES)
10931093
}
10941094

10951095
std::shared_ptr<Tensor> OrForward(const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
@@ -1098,7 +1098,7 @@ std::shared_ptr<Tensor> OrForward(const std::shared_ptr<Tensor> &a, const std::s
10981098
return (x != decltype(x){0} || y != decltype(y){0}) ? decltype(x){1}
10991099
: decltype(x){0};
11001100
});
1101-
, INFINI_ALL_TYPES)
1101+
, INFINI_ALL_NUMERIC_TYPES)
11021102
}
11031103

11041104
std::shared_ptr<Tensor> AndForward(const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
@@ -1107,7 +1107,7 @@ std::shared_ptr<Tensor> AndForward(const std::shared_ptr<Tensor> &a, const std::
11071107
return (x != decltype(x){0} && y != decltype(y){0}) ? decltype(x){1}
11081108
: decltype(x){0};
11091109
});
1110-
, INFINI_ALL_TYPES)
1110+
, INFINI_ALL_NUMERIC_TYPES)
11111111
}
11121112

11131113
std::shared_ptr<Tensor> AddForward(const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
@@ -1125,19 +1125,19 @@ std::pair<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>> AddBackward(const st
11251125
std::shared_ptr<Tensor> AddScalarForward(const std::shared_ptr<Tensor> &a, float scalar) {
11261126
DISPATCH(a->Dtype(),
11271127
return UnaryForward(a, [scalar] __device__(auto x) { return Add(x, static_cast<decltype(x)>(scalar)); });
1128-
, INFINI_ALL_TYPES)
1128+
, INFINI_ALL_NUMERIC_TYPES)
11291129
}
11301130

11311131
std::shared_ptr<Tensor> AddScalarBackward(const std::shared_ptr<Tensor> &grad_output) {
11321132
DISPATCH(grad_output->Dtype(),
11331133
return UnaryBackward(grad_output, nullptr,
11341134
[] __device__(auto x) { return common::cuda::Cast<decltype(x)>(1); });
1135-
, INFINI_ALL_TYPES)
1135+
, INFINI_ALL_NUMERIC_TYPES)
11361136
}
11371137

11381138
std::shared_ptr<Tensor> SubForward(const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
11391139
DISPATCH(a->Dtype(), return BinaryForward(a, b, [] __device__(auto x, auto y) { return Sub(x, y); });
1140-
, INFINI_ALL_TYPES)
1140+
, INFINI_ALL_NUMERIC_TYPES)
11411141
}
11421142

11431143
std::pair<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>> SubBackward(const std::shared_ptr<Tensor> &grad_output,

infini_train/src/kernels/cuda/fill.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ void Fill(std::shared_ptr<Tensor> tensor, Scalar scalar) {
2828
infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
2929
->cuda_stream();
3030

31-
core::cuda::DispatchCudaFunc<INFINI_ALL_TYPES>(
31+
core::cuda::DispatchCudaFunc<INFINI_ALL_NUMERIC_TYPES>(
3232
tensor->Dtype(),
3333
[=]<typename T>() {
3434
const T casted_value = scalar.to<T>();

infini_train/src/kernels/cuda/slice.cu

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -92,7 +92,7 @@ std::shared_ptr<Tensor> SliceForward(const std::shared_ptr<Tensor> &input, const
9292
int threads_per_block = 256;
9393
int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block;
9494

95-
core::cuda::DispatchCudaFunc<INFINI_ALL_TYPES>(
95+
core::cuda::DispatchCudaFunc<INFINI_ALL_NUMERIC_TYPES>(
9696
dtype,
9797
[=]<typename T>() {
9898
SliceForwardKernel<<<num_blocks, threads_per_block, 0, stream>>>(
@@ -185,7 +185,7 @@ std::shared_ptr<Tensor> SliceBackward(const std::shared_ptr<Tensor> &grad_output
185185
int threads_per_block = 256;
186186
int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block;
187187

188-
core::cuda::DispatchCudaFunc<INFINI_ALL_TYPES>(
188+
core::cuda::DispatchCudaFunc<INFINI_ALL_NUMERIC_TYPES>(
189189
grad_output_dtype,
190190
[=]<typename T>() {
191191
SliceBackwardKernel<<<num_blocks, threads_per_block, 0, stream>>>(

0 commit comments

Comments
 (0)