Skip to content

Commit f0835b2

Browse files
committed
test: migrate test_transformer_architecture to ctest framework
1 parent dbdf569 commit f0835b2

7 files changed

Lines changed: 178 additions & 646 deletions

File tree

infini_train/include/datatype.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,18 +100,18 @@ enum class DataType : int8_t {
100100
};
101101

102102
inline const std::unordered_map<DataType, size_t> kDataTypeToSize = {
103-
{DataType::kBOOL, 1},
104-
{DataType::kUINT8, 1}, {DataType::kINT8, 1}, {DataType::kUINT16, 2}, {DataType::kINT16, 2},
105-
{DataType::kUINT32, 4}, {DataType::kINT32, 4}, {DataType::kUINT64, 8}, {DataType::kINT64, 8},
106-
{DataType::kBFLOAT16, 2}, {DataType::kFLOAT16, 2}, {DataType::kFLOAT32, 4}, {DataType::kFLOAT64, 8},
103+
{DataType::kBOOL, 1}, {DataType::kUINT8, 1}, {DataType::kINT8, 1}, {DataType::kUINT16, 2},
104+
{DataType::kINT16, 2}, {DataType::kUINT32, 4}, {DataType::kINT32, 4}, {DataType::kUINT64, 8},
105+
{DataType::kINT64, 8}, {DataType::kBFLOAT16, 2}, {DataType::kFLOAT16, 2}, {DataType::kFLOAT32, 4},
106+
{DataType::kFLOAT64, 8},
107107
};
108108

109109
inline const std::unordered_map<DataType, std::string> kDataTypeToDesc = {
110-
{DataType::kBOOL, "bool"},
111-
{DataType::kUINT8, "uint8"}, {DataType::kINT8, "int8"}, {DataType::kUINT16, "uint16"},
112-
{DataType::kINT16, "int16"}, {DataType::kUINT32, "uint32"}, {DataType::kINT32, "int32"},
113-
{DataType::kUINT64, "uint64"}, {DataType::kINT64, "int64"}, {DataType::kBFLOAT16, "bf16"},
114-
{DataType::kFLOAT16, "fp16"}, {DataType::kFLOAT32, "fp32"}, {DataType::kFLOAT64, "fp64"},
110+
{DataType::kBOOL, "bool"}, {DataType::kUINT8, "uint8"}, {DataType::kINT8, "int8"},
111+
{DataType::kUINT16, "uint16"}, {DataType::kINT16, "int16"}, {DataType::kUINT32, "uint32"},
112+
{DataType::kINT32, "int32"}, {DataType::kUINT64, "uint64"}, {DataType::kINT64, "int64"},
113+
{DataType::kBFLOAT16, "bf16"}, {DataType::kFLOAT16, "fp16"}, {DataType::kFLOAT32, "fp32"},
114+
{DataType::kFLOAT64, "fp64"},
115115
};
116116

117117
// =============================================================================

infini_train/src/kernels/cpu/gather.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
namespace infini_train::kernels::cpu {
1111
std::shared_ptr<Tensor> GatherForward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &index,
12-
int64_t dim) {
12+
int64_t dim) {
1313
const auto &in_dims = input->Dims();
1414
const auto &idx_dims = index->Dims();
1515
CHECK_EQ(in_dims.size(), idx_dims.size());
@@ -100,9 +100,8 @@ std::shared_ptr<Tensor> GatherForward(const std::shared_ptr<Tensor> &input, cons
100100
return out;
101101
}
102102

103-
std::shared_ptr<Tensor> GatherBackward(const std::shared_ptr<Tensor> &grad_output,
104-
const std::shared_ptr<Tensor> &index, int64_t dim,
105-
const std::vector<int64_t> &input_dims) {
103+
std::shared_ptr<Tensor> GatherBackward(const std::shared_ptr<Tensor> &grad_output, const std::shared_ptr<Tensor> &index,
104+
int64_t dim, const std::vector<int64_t> &input_dims) {
106105
const auto &in_dims = input_dims;
107106
const auto &idx_dims = index->Dims();
108107
CHECK_EQ(in_dims.size(), idx_dims.size());

infini_train/src/kernels/cpu/scatter.cc

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010

1111
namespace infini_train::kernels::cpu {
1212

13-
std::shared_ptr<Tensor> ScatterForward(const std::shared_ptr<Tensor> &values,
14-
const std::shared_ptr<Tensor> &indices,
13+
std::shared_ptr<Tensor> ScatterForward(const std::shared_ptr<Tensor> &values, const std::shared_ptr<Tensor> &indices,
1514
const std::vector<int64_t> &output_dims) {
1615
CHECK(indices->Dtype() == DataType::kINT64) << "CPU ScatterForward expects int64 indices";
1716
CHECK(values->Dims() == indices->Dims());
@@ -39,8 +38,8 @@ std::shared_ptr<Tensor> ScatterForward(const std::shared_ptr<Tensor> &values,
3938
const int64_t expert_idx = idx[row * topk + selected];
4039
CHECK_GE(expert_idx, 0);
4140
CHECK_LT(expert_idx, num_experts);
42-
std::memcpy(dst + (row * num_experts + expert_idx) * elem_size,
43-
src + (row * topk + selected) * elem_size, elem_size);
41+
std::memcpy(dst + (row * num_experts + expert_idx) * elem_size, src + (row * topk + selected) * elem_size,
42+
elem_size);
4443
}
4544
}
4645

@@ -68,8 +67,8 @@ std::shared_ptr<Tensor> ScatterBackward(const std::shared_ptr<Tensor> &grad_outp
6867
const int64_t expert_idx = idx[row * topk + selected];
6968
CHECK_GE(expert_idx, 0);
7069
CHECK_LT(expert_idx, num_experts);
71-
std::memcpy(dst + (row * topk + selected) * elem_size,
72-
src + (row * num_experts + expert_idx) * elem_size, elem_size);
70+
std::memcpy(dst + (row * topk + selected) * elem_size, src + (row * num_experts + expert_idx) * elem_size,
71+
elem_size);
7372
}
7473
}
7574

@@ -78,7 +77,7 @@ std::shared_ptr<Tensor> ScatterBackward(const std::shared_ptr<Tensor> &grad_outp
7877

7978
} // namespace infini_train::kernels::cpu
8079

81-
#define REGISTER_CPU_SCATTER_KERNEL(kernel_name) \
80+
#define REGISTER_CPU_SCATTER_KERNEL(kernel_name) \
8281
REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name)
8382

8483
REGISTER_CPU_SCATTER_KERNEL(ScatterForward)

infini_train/src/nn/functional.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
#include "infini_train/include/autograd/activations.h"
88
#include "infini_train/include/autograd/elementwise.h"
9-
#include "infini_train/include/autograd/misc.h"
109
#include "infini_train/include/autograd/reduction.h"
1110
#include "infini_train/include/autograd/softmax.h"
1211
#include "infini_train/include/autograd/transform.h"

infini_train/src/tensor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
#include "infini_train/include/autograd/elementwise.h"
1414
#include "infini_train/include/autograd/function.h"
1515
#include "infini_train/include/autograd/function_hook.h"
16-
#include "infini_train/include/autograd/matmul.h"
1716
#include "infini_train/include/autograd/indexing.h"
17+
#include "infini_train/include/autograd/matmul.h"
1818
#include "infini_train/include/autograd/no_op.h"
1919
#include "infini_train/include/autograd/outer.h"
2020
#include "infini_train/include/autograd/reduction.h"

0 commit comments

Comments
 (0)