Skip to content

Commit ebd9a98

Browse files
committed
feat: introduce Scalar abstraction to support multiple scalar types
1 parent 53042ef commit ebd9a98

File tree

6 files changed

+67
-12
lines changed

6 files changed

+67
-12
lines changed

infini_train/include/dtype_dispatch.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
#define LOG_UNSUPPORTED_DTYPE(DTYPE, CONTEXT_IDENTIFIER) \
1919
LOG_LOC(FATAL, std::string(CONTEXT_IDENTIFIER) \
20-
+ ": Unsupported data type: " + kDataTypeToDesc.at(static_cast<infini_train::DataType>(dtype)))
20+
+ ": Unsupported data type: " + kDataTypeToDesc.at(static_cast<infini_train::DataType>(DTYPE)))
2121

2222
// Helper macros to count the number of arguments
2323
#define PP_NARG(...) PP_NARG_(__VA_ARGS__, PP_RSEQ_N())

infini_train/include/scalar.h

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#pragma once
2+
3+
#include <cstdint>
4+
#include <type_traits>
5+
6+
#include "glog/logging.h"
7+
8+
#include "infini_train/include/common/cpu/common_cpu.h"
9+
10+
namespace infini_train {
11+
12+
struct Scalar {
13+
enum class Kind : uint8_t { kBool, kDouble, kInt64, kUInt64 };
14+
15+
Scalar() : kind(Kind::kInt64), i(0) {}
16+
Scalar(bool v) : kind(Kind::kBool), u(v ? 1 : 0) {}
17+
18+
template <typename T, typename std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
19+
Scalar(T v) : kind(Kind::kDouble), d(static_cast<double>(v)) {}
20+
21+
template <typename T,
22+
typename std::enable_if_t<std::is_integral_v<T> && std::is_signed_v<T> && !std::is_same_v<T, bool>, int>
23+
= 0>
24+
Scalar(T v) : kind(Kind::kInt64), i(static_cast<int64_t>(v)) {}
25+
26+
template <typename T,
27+
typename std::enable_if_t<std::is_integral_v<T> && std::is_unsigned_v<T> && !std::is_same_v<T, bool>, int>
28+
= 0>
29+
Scalar(T v) : kind(Kind::kUInt64), u(static_cast<uint64_t>(v)) {}
30+
31+
Scalar(FP16 v) : kind(Kind::kDouble), d(static_cast<float>(v)) {}
32+
Scalar(BF16 v) : kind(Kind::kDouble), d(static_cast<float>(v)) {}
33+
34+
template <typename T> T to() const {
35+
switch (kind) {
36+
case Kind::kBool:
37+
return common::cpu::Cast<T>(u != 0);
38+
case Kind::kDouble:
39+
return common::cpu::Cast<T>(d);
40+
case Kind::kInt64:
41+
return common::cpu::Cast<T>(i);
42+
case Kind::kUInt64:
43+
return common::cpu::Cast<T>(u);
44+
default:
45+
LOG(FATAL) << "Unknown scalar kind";
46+
}
47+
48+
std::abort();
49+
}
50+
51+
Kind kind;
52+
union {
53+
double d;
54+
int64_t i;
55+
uint64_t u;
56+
};
57+
};
58+
59+
} // namespace infini_train

infini_train/include/tensor.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "infini_train/include/datatype.h"
1414
#include "infini_train/include/device.h"
15+
#include "infini_train/include/scalar.h"
1516

1617
namespace infini_train {
1718
namespace autograd {
@@ -78,8 +79,7 @@ class Tensor : public std::enable_shared_from_this<Tensor> {
7879
size_t NumElements() const;
7980
DataType Dtype() const;
8081

81-
// Fill tensor with a scalar value (accepts double, automatically converts to tensor's dtype)
82-
void Fill(double value);
82+
void Fill(Scalar value);
8383

8484
Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> EigenMatrix();
8585
Eigen::Map<Eigen::Matrix<float, 1, Eigen::Dynamic, Eigen::RowMajor>> EigenVector();

infini_train/src/kernels/cpu/fill.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
11
#include "glog/logging.h"
22

3-
#include "infini_train/include/common/cpu/common_cpu.h"
43
#include "infini_train/include/dispatcher.h"
54
#include "infini_train/include/dtype_dispatch.h"
65
#include "infini_train/include/tensor.h"
76

87
#include "infini_train/src/core/runtime/cpu/cpu_dispatch.h"
98

109
namespace infini_train::kernels::cpu {
11-
void Fill(std::shared_ptr<Tensor> tensor, double value) {
10+
void Fill(std::shared_ptr<Tensor> tensor, Scalar scalar) {
1211
core::cpu::DispatchCpuFunc<INFINI_ALL_TYPES>(
1312
tensor->Dtype(),
1413
[=]<typename T>() {
1514
auto data = reinterpret_cast<T *>(tensor->DataPtr());
16-
T casted_value = common::cpu::Cast<T>(value);
15+
const T casted_value = scalar.to<T>();
1716
std::fill(data, data + tensor->NumElements(), casted_value);
1817
},
1918
"CPU Fill");

infini_train/src/kernels/cuda/fill.cu

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#include <cstddef>
22
#include <memory>
33

4-
#include "infini_train/include/common/cpu/common_cpu.h"
54
#include "infini_train/include/core/runtime/device_guard.h"
65
#include "infini_train/include/device.h"
76
#include "infini_train/include/dispatcher.h"
@@ -20,7 +19,7 @@ template <typename T> __global__ void FillKernel(T *data, T value, size_t size)
2019
}
2120

2221
// TODO(dcj): refactor Fill kernel with elementwise template
23-
void Fill(std::shared_ptr<Tensor> tensor, double value) {
22+
void Fill(std::shared_ptr<Tensor> tensor, Scalar scalar) {
2423
const int num_tokens = tensor->NumElements();
2524
const int threads_per_block = 256;
2625
const int num_blocks = (num_tokens + threads_per_block - 1) / threads_per_block;
@@ -32,7 +31,7 @@ void Fill(std::shared_ptr<Tensor> tensor, double value) {
3231
core::cuda::DispatchCudaFunc<INFINI_ALL_TYPES>(
3332
tensor->Dtype(),
3433
[=]<typename T>() {
35-
T casted_value = common::cpu::Cast<T>(value);
34+
const T casted_value = scalar.to<T>();
3635
FillKernel<T><<<num_blocks, threads_per_block, 0, cuda_stream>>>(static_cast<T *>(tensor->DataPtr()),
3736
casted_value, tensor->NumElements());
3837
},

infini_train/src/tensor.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,10 @@
1818
#include "infini_train/include/autograd/outer.h"
1919
#include "infini_train/include/autograd/reduction.h"
2020
#include "infini_train/include/autograd/transform.h"
21-
#include "infini_train/include/common/cpu/common_cpu.h"
2221
#include "infini_train/include/core/runtime/device_guard.h"
2322
#include "infini_train/include/datatype.h"
2423
#include "infini_train/include/device.h"
2524
#include "infini_train/include/dispatcher.h"
26-
#include "infini_train/include/dtype_dispatch.h"
2725
#include "infini_train/include/nn/init.h"
2826

2927
namespace infini_train {
@@ -104,7 +102,7 @@ size_t Tensor::NumElements() const { return num_elements_; }
104102

105103
DataType Tensor::Dtype() const { return dtype_; }
106104

107-
void Tensor::Fill(double value) {
105+
void Tensor::Fill(Scalar value) {
108106
auto device = GetDevice();
109107
core::DeviceGuard guard(device);
110108
auto kernel = Dispatcher::Instance().GetKernel({device.type(), "Fill"});

0 commit comments

Comments
 (0)