Skip to content

Commit c4d8812

Browse files
committed
tmp
1 parent d278062 commit c4d8812

48 files changed

Lines changed: 1071 additions & 679 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

infini_train/include/common/common.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,21 @@
77

88
#include "infini_train/include/datatype.h"
99

10+
/**
 * General Utility Macros
 */
// Forces one extra round of macro expansion on its argument; useful when an
// argument is itself a macro that must be expanded before further processing.
#define EXPAND(X) X
// This macro lets you pass an arbitrary expression that may contain internal
// commas to another macro without having the commas causing the expression
// to be interpreted as being multiple arguments
// Basically an alternative for __VA_OPTS__ before C++20
// ref: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/Dispatch_v2.h
#define WRAP(...) __VA_ARGS__
// Token pasting with one level of indirection: CAT expands its arguments
// first, then CAT_ pastes them (a##b alone would paste the unexpanded tokens).
#define CAT(a, b) CAT_(a, b)
#define CAT_(a, b) a##b

// Integer ceiling division: smallest integer >= x / y for positive operands.
// NOTE: (y) is evaluated twice — avoid arguments with side effects.
#define CEIL_DIV(x, y) (((x) + (y)-1) / (y))

// Logs MSG at the given glog severity, annotated with the call-site file:line.
#define LOG_LOC(LEVEL, MSG) LOG(LEVEL) << MSG << " at " << __FILE__ << ":" << __LINE__
12-
#define LOG_UNSUPPORTED_DTYPE(DTYPE, CONTEXT_IDENTIFIER) \
13-
LOG_LOC(FATAL, WRAP(CONTEXT_IDENTIFIER << ": Unsupported data type: " \
14-
+ kDataTypeToDesc.at(static_cast<infini_train::DataType>(dtype))))
1525

1626
inline std::vector<int64_t> ComputeStrides(const std::vector<int64_t> &dims) {
1727
std::vector<int64_t> strides(dims.size(), 1);

infini_train/include/common/cpu/common_cpu.h

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,41 @@
33
#include <type_traits>
44
#include <utility>
55

6+
#include "infini_train/include/datatype.h"
7+
68
namespace infini_train::common::cpu {
9+
10+
namespace detail {
11+
12+
// FP16/BF16 don't support implicit conversion, so we route through float.
13+
template <typename DST, typename SRC> DST CastImpl(SRC &&x) {
14+
using SrcBase = std::remove_cvref_t<SRC>;
15+
if constexpr (std::is_same_v<DST, SrcBase>) {
16+
return x;
17+
} else if constexpr (std::is_same_v<DST, FP16> || std::is_same_v<DST, BF16>) {
18+
// Destination is a framework 16-bit type: convert via float
19+
return DST(static_cast<float>(std::forward<SRC>(x)));
20+
} else if constexpr (std::is_same_v<SrcBase, FP16> || std::is_same_v<SrcBase, BF16>) {
21+
// Source is a framework 16-bit type: widen to float first
22+
return static_cast<DST>(static_cast<float>(x));
23+
} else {
24+
return static_cast<DST>(std::forward<SRC>(x));
25+
}
26+
}
27+
28+
} // namespace detail
29+
730
/**
8-
* Converts a value between arbitrary types. This offers perfect
9-
* forwarding which preserves value categories (lvalues/rvalues)
31+
* Converts a value between arbitrary types, including framework FP16/BF16.
1032
*
11-
* @tparam DST Destination type (deduced)
33+
* @tparam DST Destination type
1234
* @tparam SRC Source type (deduced)
13-
* @param x Input value (preserves const/volatile and value category)
35+
* @param x Input value
1436
* @return Value converted to DST type
1537
*/
1638
template <typename DST, typename SRC> DST Cast(SRC &&x) {
1739
static_assert(!std::is_reference_v<DST>, "Cast cannot return reference types");
18-
19-
// TODO(lzm): add cpu-version fp16 and bf16
20-
return (DST)(std::forward<SRC>(x));
40+
return detail::CastImpl<DST>(std::forward<SRC>(x));
2141
}
42+
2243
} // namespace infini_train::common::cpu

infini_train/include/core/device_guard.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class DeviceGuardImpl {
6767
// Device management
6868
// ----------------------------------------------------------------------
6969

70+
// FIXME(dcj): impl should only bind with device type
7071
virtual Device GetDevice() const = 0;
7172

7273
virtual void SetDevice(Device device) const;
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
#pragma once
2+
3+
#include <cstdint>
4+
#include <cstring>
5+
#include <type_traits>
6+
7+
#include "infini_train/include/datatype.h"
8+
#include "infini_train/include/device.h"
9+
10+
namespace infini_train::core {
11+
12+
/**
13+
* Dtype bridge
14+
*
15+
* Purpose:
16+
* - Define the backend-agnostic mapping protocol from framework scalar types
17+
* (e.g. infini_train::FP16/BF16) to backend-native scalar types
18+
* (e.g. __half / __nv_bfloat16 / vendor fp16/bf16 types).
19+
*
20+
* Design notes:
21+
* - This header MUST remain backend-agnostic.
22+
* - Framework public code should only depend on infini_train::FP16/BF16.
23+
* - Backend code provides specializations of NativeScalar<Dev, Scalar>.
24+
* - ScalarConvert provides optional value-level conversion helpers.
25+
*/
26+
27+
// -----------------------------------------------------------------------------
// NativeScalar: framework scalar -> backend native scalar mapping
// -----------------------------------------------------------------------------
// Primary template intentionally undefined.
// Each backend specializes the scalar types it supports.
template <Device::DeviceType Dev, typename Scalar> struct NativeScalar;

// Shorthand for the mapped backend-native type of (Dev, Scalar).
template <Device::DeviceType Dev, typename Scalar> using NativeScalar_t = typename NativeScalar<Dev, Scalar>::type;

// Optional convenience alias for CUDA call sites.
// Keep only one copy here; backend files should NOT redefine it.
template <typename Scalar> using NativeScalarCUDA_t = NativeScalar_t<Device::DeviceType::kCUDA, Scalar>;
39+
40+
// -----------------------------------------------------------------------------
// Bitcast utilities
// -----------------------------------------------------------------------------
// Reinterprets the object representation of `src` as a value of type `To`.
// Same spirit as C++20 std::bit_cast, but implemented with memcpy so it is
// valid C++17 and free of type-punning undefined behavior.
template <typename To, typename From> inline To Bitcast(const From &src) noexcept {
    static_assert(sizeof(To) == sizeof(From), "Bitcast requires same size");
    static_assert(std::is_trivially_copyable_v<To>, "Bitcast To must be trivially copyable");
    static_assert(std::is_trivially_copyable_v<From>, "Bitcast From must be trivially copyable");

    To result{};
    std::memcpy(&result, &src, sizeof(result));
    return result;
}
52+
53+
// -----------------------------------------------------------------------------
// HasNativeScalar: detect whether a NativeScalar specialization exists
// -----------------------------------------------------------------------------
// Classic void_t detection idiom: the partial specialization below is only
// viable when NativeScalar<Dev, Scalar>::type names a type, i.e. when a
// backend has specialized NativeScalar for this (device, scalar) pair.
template <Device::DeviceType Dev, typename Scalar, typename = void> struct HasNativeScalar : std::false_type {};

template <Device::DeviceType Dev, typename Scalar>
struct HasNativeScalar<Dev, Scalar, std::void_t<typename NativeScalar<Dev, Scalar>::type>> : std::true_type {};

// Convenience variable template following the standard-library *_v convention.
template <Device::DeviceType Dev, typename Scalar>
inline constexpr bool HasNativeScalar_v = HasNativeScalar<Dev, Scalar>::value;
63+
64+
// -----------------------------------------------------------------------------
// ScalarConvert: framework scalar <-> backend native scalar conversion glue
// -----------------------------------------------------------------------------
// Primary template intentionally undefined by default.
// Backends may specialize this if simple bitcast is insufficient.
// The trailing Enable parameter leaves room for SFINAE-constrained
// specializations covering a family of scalars at once.
template <Device::DeviceType Dev, typename Scalar, typename Enable = void> struct ScalarConvert;
71+
// Default FP16 conversion: preserve raw 16-bit bit pattern.
72+
template <Device::DeviceType Dev> struct ScalarConvert<Dev, infini_train::FP16, void> {
73+
static_assert(HasNativeScalar_v<Dev, infini_train::FP16>,
74+
"Missing NativeScalar specialization for FP16 on this backend");
75+
76+
using Native = NativeScalar_t<Dev, infini_train::FP16>;
77+
78+
static inline Native ToNative(infini_train::FP16 v) noexcept {
79+
static_assert(sizeof(Native) == sizeof(uint16_t), "Native FP16 must be 16-bit");
80+
return Bitcast<Native>(v.x);
81+
}
82+
83+
static inline infini_train::FP16 FromNative(Native v) noexcept {
84+
infini_train::FP16 out{};
85+
static_assert(sizeof(Native) == sizeof(uint16_t), "Native FP16 must be 16-bit");
86+
out.x = Bitcast<uint16_t>(v);
87+
return out;
88+
}
89+
};
90+
91+
// Default BF16 conversion: preserve raw 16-bit bit pattern.
92+
template <Device::DeviceType Dev> struct ScalarConvert<Dev, infini_train::BF16, void> {
93+
static_assert(HasNativeScalar_v<Dev, infini_train::BF16>,
94+
"Missing NativeScalar specialization for BF16 on this backend");
95+
96+
using Native = NativeScalar_t<Dev, infini_train::BF16>;
97+
98+
static inline Native ToNative(infini_train::BF16 v) noexcept {
99+
static_assert(sizeof(Native) == sizeof(uint16_t), "Native BF16 must be 16-bit");
100+
return Bitcast<Native>(v.x);
101+
}
102+
103+
static inline infini_train::BF16 FromNative(Native v) noexcept {
104+
infini_train::BF16 out{};
105+
static_assert(sizeof(Native) == sizeof(uint16_t), "Native BF16 must be 16-bit");
106+
out.x = Bitcast<uint16_t>(v);
107+
return out;
108+
}
109+
};
110+
111+
// -----------------------------------------------------------------------------
112+
// Convenience wrappers
113+
// -----------------------------------------------------------------------------
114+
template <Device::DeviceType Dev, typename Scalar> inline NativeScalar_t<Dev, Scalar> ToNative(Scalar v) noexcept {
115+
return ScalarConvert<Dev, Scalar>::ToNative(v);
116+
}
117+
118+
template <Device::DeviceType Dev, typename Scalar> inline Scalar FromNative(NativeScalar_t<Dev, Scalar> v) noexcept {
119+
return ScalarConvert<Dev, Scalar>::FromNative(v);
120+
}
121+
122+
} // namespace infini_train::core

0 commit comments

Comments
 (0)