Skip to content

Commit 53042ef

Browse files
committed
fix: remove unnecessary changes
1 parent d4b8fc1 commit 53042ef

File tree

11 files changed

+30
-137
lines changed

11 files changed

+30
-137
lines changed

infini_train/include/datatype.h

Lines changed: 10 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -98,62 +98,17 @@ enum class DataType : int8_t {
9898
kFLOAT64,
9999
};
100100

101-
size_t DTypeSize(DataType data_type);
102-
103-
extern const std::unordered_map<DataType, std::string> kDataTypeToDesc;
104-
105-
// -----------------------------------------------------------------------------
106-
// Compile-time type mapping infrastructure
107-
// -----------------------------------------------------------------------------
108-
// Baseline framework scalar/storage mapping.
109-
// This is the single source of truth for:
110-
// - framework DataType -> C++ type mapping
111-
// - CPU default type mapping
112-
// - backend type-map fallback for dtypes without backend-native overrides
113-
template <DataType DType> struct TypeMap;
114-
115-
template <DataType DType> using TypeMap_t = typename TypeMap<DType>::type;
116-
117-
// -----------------------------------------------------------------------------
118-
// Compile-time reverse mapping: framework C++ type -> DataType
119-
// -----------------------------------------------------------------------------
120-
template <typename T> struct DataTypeMap;
121-
122-
template <typename T> inline constexpr DataType DataTypeMap_v = DataTypeMap<T>::value;
123-
124-
// Macro to define baseline mapping + reverse mapping
125-
#define DEFINE_DEFAULT_DATA_TYPE_MAPPING(ENUM_VALUE, CPP_TYPE) \
126-
template <> struct TypeMap<DataType::ENUM_VALUE> { \
127-
using type = CPP_TYPE; \
128-
}; \
129-
template <> struct DataTypeMap<CPP_TYPE> { \
130-
static constexpr DataType value = DataType::ENUM_VALUE; \
131-
};
132-
133-
DEFINE_DEFAULT_DATA_TYPE_MAPPING(kUINT8, uint8_t)
134-
DEFINE_DEFAULT_DATA_TYPE_MAPPING(kINT8, int8_t)
135-
DEFINE_DEFAULT_DATA_TYPE_MAPPING(kUINT16, uint16_t)
136-
DEFINE_DEFAULT_DATA_TYPE_MAPPING(kINT16, int16_t)
137-
DEFINE_DEFAULT_DATA_TYPE_MAPPING(kUINT32, uint32_t)
138-
DEFINE_DEFAULT_DATA_TYPE_MAPPING(kINT32, int32_t)
139-
DEFINE_DEFAULT_DATA_TYPE_MAPPING(kUINT64, uint64_t)
140-
DEFINE_DEFAULT_DATA_TYPE_MAPPING(kINT64, int64_t)
141-
DEFINE_DEFAULT_DATA_TYPE_MAPPING(kFLOAT32, float)
142-
DEFINE_DEFAULT_DATA_TYPE_MAPPING(kFLOAT64, double)
143-
144-
#undef DEFINE_DEFAULT_DATA_TYPE_MAPPING
145-
146-
// ---------------------------------------------------------------------------
147-
// Low-precision types: reverse mapping ONLY (DataTypeMap).
148-
// TypeMap<kFLOAT16> / TypeMap<kBFLOAT16> are intentionally NOT defined here.
149-
// Backend TypeMaps must explicitly provide these mappings; the default TypeMap
150-
// will static_assert at compile time if dispatch reaches an unmapped dtype.
151-
// ---------------------------------------------------------------------------
152-
template <> struct DataTypeMap<FP16> {
153-
static constexpr DataType value = DataType::kFLOAT16;
101+
inline const std::unordered_map<DataType, size_t> kDataTypeToSize = {
102+
{DataType::kUINT8, 1}, {DataType::kINT8, 1}, {DataType::kUINT16, 2}, {DataType::kINT16, 2},
103+
{DataType::kUINT32, 4}, {DataType::kINT32, 4}, {DataType::kUINT64, 8}, {DataType::kINT64, 8},
104+
{DataType::kBFLOAT16, 2}, {DataType::kFLOAT16, 2}, {DataType::kFLOAT32, 4}, {DataType::kFLOAT64, 8},
154105
};
155-
template <> struct DataTypeMap<BF16> {
156-
static constexpr DataType value = DataType::kBFLOAT16;
106+
107+
inline const std::unordered_map<DataType, std::string> kDataTypeToDesc = {
108+
{DataType::kUINT8, "uint8"}, {DataType::kINT8, "int8"}, {DataType::kUINT16, "uint16"},
109+
{DataType::kINT16, "int16"}, {DataType::kUINT32, "uint32"}, {DataType::kINT32, "int32"},
110+
{DataType::kUINT64, "uint64"}, {DataType::kINT64, "int64"}, {DataType::kBFLOAT16, "bf16"},
111+
{DataType::kFLOAT16, "fp16"}, {DataType::kFLOAT32, "fp32"}, {DataType::kFLOAT64, "fp64"},
157112
};
158113

159114
// =============================================================================

infini_train/include/dispatcher.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88

99
#include "infini_train/include/autocast.h"
1010
#include "infini_train/include/device.h"
11-
// FIXME(dcj): should not include this
12-
#include "infini_train/include/dtype_dispatch.h"
1311
#ifdef PROFILE_MODE
1412
#include "infini_train/include/profiler.h"
1513
#endif

infini_train/include/dtype_dispatch.h

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,7 @@ auto DispatchByTypeMap(DataType dtype, Functor &&func, std::string_view context_
234234
if constexpr (IsDataTypeInList_v<DType, DataTypeList<AllowedDTypes...>>) { \
235235
static_assert(HasMappedType_v<TypeMap, DType>, \
236236
"TypeMap does not provide explicit mapping for this dtype. " \
237-
"If this is a backend dispatch, register the dtype in the backend TypeMap; " \
238-
"if this is DispatchFunc, the dtype is not supported by the default TypeMap."); \
237+
"Register the dtype in the backend TypeMap (e.g., CpuTypeMap / CudaTypeMap)."); \
239238
return std::forward<Functor>(func).template operator()<MappedType_t<TypeMap, DType>>( \
240239
std::forward<Args>(args)...); \
241240
} else { \
@@ -283,8 +282,7 @@ struct TypeMapDispatcher {
283282
if constexpr (IsDataTypeInList_v<DType, CurrentList>) { \
284283
static_assert(HasMappedType_v<TypeMap, DType>, \
285284
"TypeMap does not provide explicit mapping for this dtype. " \
286-
"If this is a backend dispatch, register the dtype in the backend TypeMap; " \
287-
"if this is DispatchFunc, the dtype is not supported by the default TypeMap."); \
285+
"Register the dtype in the backend TypeMap (e.g., CpuTypeMap / CudaTypeMap)."); \
288286
using T = MappedType_t<TypeMap, DType>; \
289287
return TypeMapDispatcher<TypeMap, Index + 1, AllowedListTuple, ResolvedTypes..., T>::call( \
290288
dtypes, std::forward<Functor>(func), context_identifier, std::forward<Args>(args)...); \
@@ -334,24 +332,4 @@ auto DispatchByTypeMap(const std::vector<DataType> &dtypes, Functor &&func, std:
334332
dtypes, std::forward<Functor>(func), context_identifier, std::forward<Args>(args)...);
335333
}
336334

337-
// -----------------------------------------------------------------------------
338-
// Default framework dispatch using TypeMap
339-
// -----------------------------------------------------------------------------
340-
// TypeMap only covers standard types (int/uint/float32/float64).
341-
// Low-precision types (FP16/BF16) are intentionally unmapped — use a
342-
// backend-specific dispatch (DispatchCudaFunc, DispatchCpuFunc, …) instead.
343-
// -----------------------------------------------------------------------------
344-
template <DataType... AllowedDTypes, typename Functor, typename... Args>
345-
auto DispatchFunc(DataType dtype, Functor &&func, std::string_view context_identifier = "", Args &&...args) {
346-
return DispatchByTypeMap<TypeMap, AllowedDTypes...>(dtype, std::forward<Functor>(func), context_identifier,
347-
std::forward<Args>(args)...);
348-
}
349-
350-
template <typename... AllowedTypeLists, typename Functor, typename... Args>
351-
auto DispatchFunc(const std::vector<DataType> &dtypes, Functor &&func, std::string_view context_identifier = "",
352-
Args &&...args) {
353-
return DispatchByTypeMap<TypeMap, AllowedTypeLists...>(dtypes, std::forward<Functor>(func), context_identifier,
354-
std::forward<Args>(args)...);
355-
}
356-
357335
} // namespace infini_train

infini_train/src/core/runtime/cuda/cuda_dispatch.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
#pragma once
22

3-
#include <cuda_bf16.h>
4-
#include <cuda_fp16.h>
5-
63
#include <utility>
74
#include <vector>
85

6+
#include <cuda_bf16.h>
7+
#include <cuda_fp16.h>
8+
99
#include "infini_train/include/core/backend_type_map.h"
1010
#include "infini_train/include/dtype_dispatch.h"
1111

infini_train/src/datatype.cc

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -163,46 +163,6 @@ BF16 &BF16::operator++() {
163163
return *this;
164164
}
165165

166-
// -----------------------------------------------------------------------------
167-
// DataType metadata
168-
// -----------------------------------------------------------------------------
169-
size_t DTypeSize(DataType data_type) {
170-
switch (data_type) {
171-
case DataType::kUINT8:
172-
return 1;
173-
case DataType::kINT8:
174-
return 1;
175-
case DataType::kUINT16:
176-
return 2;
177-
case DataType::kINT16:
178-
return 2;
179-
case DataType::kUINT32:
180-
return 4;
181-
case DataType::kINT32:
182-
return 4;
183-
case DataType::kUINT64:
184-
return 8;
185-
case DataType::kINT64:
186-
return 8;
187-
case DataType::kBFLOAT16:
188-
return 2;
189-
case DataType::kFLOAT16:
190-
return 2;
191-
case DataType::kFLOAT32:
192-
return 4;
193-
case DataType::kFLOAT64:
194-
return 8;
195-
}
196-
return 0; // unreachable
197-
}
198-
199-
const std::unordered_map<DataType, std::string> kDataTypeToDesc = {
200-
{DataType::kUINT8, "uint8"}, {DataType::kINT8, "int8"}, {DataType::kUINT16, "uint16"},
201-
{DataType::kINT16, "int16"}, {DataType::kUINT32, "uint32"}, {DataType::kINT32, "int32"},
202-
{DataType::kUINT64, "uint64"}, {DataType::kINT64, "int64"}, {DataType::kBFLOAT16, "bf16"},
203-
{DataType::kFLOAT16, "fp16"}, {DataType::kFLOAT32, "fp32"}, {DataType::kFLOAT64, "fp64"},
204-
};
205-
206166
// -----------------------------------------------------------------------------
207167
// DataType-level promotion
208168
// -----------------------------------------------------------------------------
@@ -234,7 +194,7 @@ DataType PromoteDataTypes(DataType a, DataType b) {
234194
}
235195

236196
// Rule 3: same category — wider wins
237-
return DTypeSize(a) >= DTypeSize(b) ? a : b;
197+
return kDataTypeToSize.at(a) >= kDataTypeToSize.at(b) ? a : b;
238198
}
239199

240200
} // namespace infini_train

infini_train/src/kernels/cuda/concat.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ std::shared_ptr<Tensor> ConcatForward(const std::vector<std::shared_ptr<Tensor>>
9090
const int64_t num_inputs = static_cast<int64_t>(inputs.size());
9191
const int64_t K_total = out_dims[dim];
9292

93+
// offsets records the sum of Ks
9394
// offsets[i] = sum_{j < i} K_j
9495
std::vector<int64_t> host_offsets(num_inputs + 1, 0);
9596
for (int64_t i = 0; i < num_inputs; ++i) { host_offsets[i + 1] = host_offsets[i] + Ks[i]; }

infini_train/src/kernels/cuda/elementwise.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "infini_train/include/common/cuda/kernel_helper.cuh"
77
#include "infini_train/include/core/runtime/device_guard.h"
88
#include "infini_train/include/dispatcher.h"
9+
#include "infini_train/include/dtype_dispatch.h"
910
#include "infini_train/include/tensor.h"
1011

1112
#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h"

infini_train/src/nn/parallel/ddp/distributed_optimizer.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ void DistributedOptimizer::BuildShardParamsAndBindGrads() {
6464
const size_t piece_numel = local_end - local_start;
6565
CHECK_GT(piece_numel, 0);
6666

67-
const size_t param_piece_offset_bytes = local_start * DTypeSize(bucket_param->Dtype());
68-
const size_t grad_piece_offset_bytes = local_start * DTypeSize(bucket_grad->Dtype());
67+
const size_t param_piece_offset_bytes = local_start * kDataTypeToSize.at(bucket_param->Dtype());
68+
const size_t grad_piece_offset_bytes = local_start * kDataTypeToSize.at(bucket_grad->Dtype());
6969

7070
auto param_piece = std::make_shared<Tensor>(*bucket_param, param_piece_offset_bytes,
7171
std::vector<int64_t>{static_cast<int64_t>(piece_numel)});

infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ std::shared_ptr<Tensor> AllocateFlatBuffer(size_t num_elements, DataType data_ty
3636

3737
std::shared_ptr<Tensor> GetBufferView(const std::shared_ptr<Tensor> buffer, size_t start_in_elements,
3838
const std::vector<int64_t> &dims) {
39-
return std::make_shared<Tensor>(*buffer, start_in_elements * DTypeSize(buffer->Dtype()), dims);
39+
return std::make_shared<Tensor>(*buffer, start_in_elements * kDataTypeToSize.at(buffer->Dtype()), dims);
4040
};
4141

4242
std::vector<std::shared_ptr<Tensor>> ShardBuffer(const std::shared_ptr<Tensor> buffer, size_t ddp_world_size) {
@@ -451,7 +451,7 @@ void ParamAndGradBuffer::BuildBuckets(DataType param_dtype, DataType grad_dtype)
451451
// Remap param/grad pointers
452452
if (param_buffer_) {
453453
// FIXME(zbl): change tensor buffer
454-
param->SetData(*param_buffer_, param_start_index * DTypeSize(param_buffer_->Dtype()), true);
454+
param->SetData(*param_buffer_, param_start_index * kDataTypeToSize.at(param_buffer_->Dtype()), true);
455455
}
456456

457457
auto grad_view = GetBufferView(grad_buffer_, param_start_index, param->Dims());

infini_train/src/nn/parallel/ddp/reducer.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace {
1818
void CopyGradToBucket(const std::shared_ptr<Tensor> &grad, const std::shared_ptr<Tensor> &flat,
1919
size_t dst_elem_offset) {
2020
CHECK(grad && flat);
21-
const size_t element_size_in_bytes = DTypeSize(grad->Dtype());
21+
const size_t element_size_in_bytes = kDataTypeToSize.at(grad->Dtype());
2222
const size_t bytes = grad->NumElements() * element_size_in_bytes;
2323
char *dst = static_cast<char *>(flat->DataPtr()) + dst_elem_offset * element_size_in_bytes;
2424
const void *src = grad->DataPtr();
@@ -33,7 +33,7 @@ void CopyGradToBucket(const std::shared_ptr<Tensor> &grad, const std::shared_ptr
3333
void CopyBucketToGrad(const std::shared_ptr<Tensor> &flat, const std::shared_ptr<Tensor> &grad,
3434
size_t src_elem_offset) {
3535
CHECK(grad && flat);
36-
const size_t element_size_in_bytes = DTypeSize(grad->Dtype());
36+
const size_t element_size_in_bytes = kDataTypeToSize.at(grad->Dtype());
3737
const size_t bytes = grad->NumElements() * element_size_in_bytes;
3838
const char *src = static_cast<const char *>(flat->DataPtr()) + src_elem_offset * element_size_in_bytes;
3939
void *dst = grad->DataPtr();
@@ -48,7 +48,7 @@ void CopyBucketToGrad(const std::shared_ptr<Tensor> &flat, const std::shared_ptr
4848
std::shared_ptr<Tensor> MakeGradView(const std::shared_ptr<Tensor> &contents, size_t offset_elems,
4949
const std::vector<int64_t> &dims) {
5050
// Return a view of contents (same chunk of memory)
51-
auto view = std::make_shared<Tensor>(*contents, offset_elems * DTypeSize(contents->Dtype()), dims);
51+
auto view = std::make_shared<Tensor>(*contents, offset_elems * kDataTypeToSize.at(contents->Dtype()), dims);
5252
return view;
5353
}
5454
} // namespace
@@ -118,7 +118,7 @@ std::vector<std::vector<size_t>> ComputeBucketAssignmentBySize(const std::vector
118118
}
119119
auto &state = it->second;
120120

121-
const size_t element_size_in_bytes = DTypeSize(tensor->Dtype());
121+
const size_t element_size_in_bytes = kDataTypeToSize.at(tensor->Dtype());
122122
const size_t bytes = tensor->NumElements() * element_size_in_bytes;
123123
const size_t cap = bucket_size_limits[state.limit_idx];
124124

0 commit comments

Comments
 (0)