|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include <cstdint> |
| 4 | +#include <tuple> |
| 5 | +#include <variant> |
| 6 | +#include <string> |
| 7 | +#include <type_traits> |
| 8 | +#include <stdexcept> |
| 9 | +#include <limits> |
| 10 | + |
| 11 | +namespace vllm |
| 12 | +{ |
| 13 | + |
| 14 | + class ScalarType |
| 15 | + { |
| 16 | + public: |
| 17 | + enum NanRepr : uint8_t |
| 18 | + { |
| 19 | + NAN_NONE = 0, |
| 20 | + NAN_IEEE_754 = 1, |
| 21 | + NAN_EXTD_RANGE_MAX_MIN = 2, |
| 22 | + |
| 23 | + NAN_REPR_ID_MAX |
| 24 | + }; |
| 25 | + |
| 26 | + constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_, |
| 27 | + int32_t bias, bool finite_values_only = false, |
| 28 | + NanRepr nan_repr = NAN_IEEE_754) |
| 29 | + : exponent(exponent), |
| 30 | + mantissa(mantissa), |
| 31 | + signed_(signed_), |
| 32 | + bias(bias), |
| 33 | + finite_values_only(finite_values_only), |
| 34 | + nan_repr(nan_repr) {} |
| 35 | + |
| 36 | + // ----------------------- |
| 37 | + // Integer |
| 38 | + // ----------------------- |
| 39 | + static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) |
| 40 | + { |
| 41 | + return ScalarType(0, size_bits - 1, true, bias); |
| 42 | + } |
| 43 | + |
| 44 | + static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) |
| 45 | + { |
| 46 | + return ScalarType(0, size_bits, false, bias); |
| 47 | + } |
| 48 | + |
| 49 | + // ----------------------- |
| 50 | + // Floating point(constexpr安全:不做检查) |
| 51 | + // ----------------------- |
| 52 | + static constexpr ScalarType float_IEEE754(uint8_t exponent, |
| 53 | + uint8_t mantissa) |
| 54 | + { |
| 55 | + return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); |
| 56 | + } |
| 57 | + |
| 58 | + static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, |
| 59 | + bool finite_values_only, |
| 60 | + NanRepr nan_repr) |
| 61 | + { |
| 62 | + return ScalarType(exponent, mantissa, true, 0, |
| 63 | + finite_values_only, nan_repr); |
| 64 | + } |
| 65 | + |
| 66 | + // ----------------------- |
| 67 | + // Runtime checked(可选) |
| 68 | + // ----------------------- |
| 69 | + static inline ScalarType float_checked(uint8_t exponent, |
| 70 | + uint8_t mantissa, |
| 71 | + bool finite_values_only, |
| 72 | + NanRepr nan_repr) |
| 73 | + { |
| 74 | + if (!(nan_repr < NAN_REPR_ID_MAX)) |
| 75 | + throw std::runtime_error("Invalid NanRepr"); |
| 76 | + |
| 77 | + if (!(mantissa > 0 && exponent > 0)) |
| 78 | + throw std::runtime_error("mantissa/exponent must > 0"); |
| 79 | + |
| 80 | + if (nan_repr == NAN_IEEE_754) |
| 81 | + throw std::runtime_error("use float_IEEE754"); |
| 82 | + |
| 83 | + return float_(exponent, mantissa, finite_values_only, nan_repr); |
| 84 | + } |
| 85 | + |
| 86 | + uint8_t const exponent; |
| 87 | + uint8_t const mantissa; |
| 88 | + bool const signed_; |
| 89 | + int32_t const bias; |
| 90 | + |
| 91 | + bool const finite_values_only; |
| 92 | + NanRepr const nan_repr; |
| 93 | + |
| 94 | + using Id = int64_t; |
| 95 | + |
| 96 | + private: |
| 97 | + template <typename T_> |
| 98 | + static constexpr size_t member_id_field_width() |
| 99 | + { |
| 100 | + using T = std::decay_t<T_>; |
| 101 | + return std::is_same<T, bool>::value ? 1 : sizeof(T) * 8; |
| 102 | + } |
| 103 | + |
| 104 | + template <typename Fn, typename Init, typename Member, typename... Rest> |
| 105 | + static constexpr auto reduce_members_helper(Fn f, Init val, Member member, |
| 106 | + Rest... rest) |
| 107 | + { |
| 108 | + auto new_val = f(val, member); |
| 109 | + if constexpr (sizeof...(rest) > 0) |
| 110 | + { |
| 111 | + return reduce_members_helper(f, new_val, rest...); |
| 112 | + } |
| 113 | + else |
| 114 | + { |
| 115 | + return new_val; |
| 116 | + } |
| 117 | + } |
| 118 | + |
| 119 | + template <typename Fn, typename Init> |
| 120 | + constexpr auto reduce_members(Fn f, Init init) const |
| 121 | + { |
| 122 | + return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, |
| 123 | + finite_values_only, nan_repr); |
| 124 | + } |
| 125 | + |
| 126 | + template <typename Fn, typename Init> |
| 127 | + static constexpr auto reduce_member_types(Fn f, Init init) |
| 128 | + { |
| 129 | + constexpr auto dummy = ScalarType(0, 0, false, 0, false, NAN_NONE); |
| 130 | + return dummy.reduce_members(f, init); |
| 131 | + } |
| 132 | + |
| 133 | + static constexpr auto id_size_bits() |
| 134 | + { |
| 135 | + return reduce_member_types( |
| 136 | + [](int acc, auto member) -> int |
| 137 | + { |
| 138 | + return acc + member_id_field_width<decltype(member)>(); |
| 139 | + }, |
| 140 | + 0); |
| 141 | + } |
| 142 | + |
| 143 | + public: |
| 144 | + constexpr Id id() const |
| 145 | + { |
| 146 | + static_assert(id_size_bits() <= sizeof(Id) * 8, |
| 147 | + "ScalarType id too large"); |
| 148 | + |
| 149 | + auto fn = [](std::pair<Id, uint32_t> result, auto member) |
| 150 | + { |
| 151 | + auto [id, offset] = result; |
| 152 | + constexpr auto bits = member_id_field_width<decltype(member)>(); |
| 153 | + return std::pair<Id, uint32_t>{ |
| 154 | + id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) << offset, |
| 155 | + offset + bits}; |
| 156 | + }; |
| 157 | + |
| 158 | + return reduce_members(fn, std::pair<Id, uint32_t>{}).first; |
| 159 | + } |
| 160 | + |
| 161 | + static constexpr ScalarType from_id(Id id) |
| 162 | + { |
| 163 | + auto fn = [id](auto result, auto member) |
| 164 | + { |
| 165 | + using T = decltype(member); |
| 166 | + auto [tuple, offset] = result; |
| 167 | + constexpr auto bits = member_id_field_width<T>(); |
| 168 | + auto val = static_cast<T>((id >> offset) & ((uint64_t(1) << bits) - 1)); |
| 169 | + return std::pair{std::tuple_cat(tuple, std::make_tuple(val)), offset + bits}; |
| 170 | + }; |
| 171 | + |
| 172 | + auto [args, _] = |
| 173 | + reduce_member_types(fn, std::pair<std::tuple<>, int>{}); |
| 174 | + |
| 175 | + return std::apply([](auto... xs) |
| 176 | + { return ScalarType(xs...); }, args); |
| 177 | + } |
| 178 | + |
| 179 | + constexpr int64_t size_bits() const |
| 180 | + { |
| 181 | + return mantissa + exponent + (signed_ ? 1 : 0); |
| 182 | + } |
| 183 | + |
| 184 | + constexpr bool is_signed() const { return signed_; } |
| 185 | + constexpr bool is_integer() const { return exponent == 0; } |
| 186 | + constexpr bool is_floating_point() const { return exponent > 0; } |
| 187 | + |
| 188 | + constexpr bool is_ieee_754() const |
| 189 | + { |
| 190 | + return is_floating_point() && !finite_values_only && |
| 191 | + nan_repr == NAN_IEEE_754; |
| 192 | + } |
| 193 | + |
| 194 | + constexpr bool has_nans() const |
| 195 | + { |
| 196 | + return is_floating_point() && nan_repr != NAN_NONE; |
| 197 | + } |
| 198 | + |
| 199 | + constexpr bool has_infs() const |
| 200 | + { |
| 201 | + return is_floating_point() && !finite_values_only; |
| 202 | + } |
| 203 | + |
| 204 | + constexpr bool has_bias() const { return bias != 0; } |
| 205 | + |
| 206 | + std::string str() const |
| 207 | + { |
| 208 | + if (is_floating_point()) |
| 209 | + { |
| 210 | + auto ret = "float" + std::to_string(size_bits()) + "_e" + |
| 211 | + std::to_string(exponent) + "m" + std::to_string(mantissa); |
| 212 | + |
| 213 | + if (!is_ieee_754()) |
| 214 | + { |
| 215 | + if (finite_values_only) |
| 216 | + ret += "f"; |
| 217 | + if (nan_repr != NAN_NONE) |
| 218 | + ret += "n"; |
| 219 | + } |
| 220 | + return ret; |
| 221 | + } |
| 222 | + else |
| 223 | + { |
| 224 | + auto ret = (signed_ ? "int" : "uint") + |
| 225 | + std::to_string(size_bits()); |
| 226 | + if (has_bias()) |
| 227 | + ret += "b" + std::to_string(bias); |
| 228 | + return ret; |
| 229 | + } |
| 230 | + } |
| 231 | + |
| 232 | + constexpr bool operator==(ScalarType const &other) const |
| 233 | + { |
| 234 | + return mantissa == other.mantissa && |
| 235 | + exponent == other.exponent && |
| 236 | + bias == other.bias && |
| 237 | + signed_ == other.signed_ && |
| 238 | + finite_values_only == other.finite_values_only && |
| 239 | + nan_repr == other.nan_repr; |
| 240 | + } |
| 241 | + }; |
| 242 | + |
| 243 | + using ScalarTypeId = ScalarType::Id; |
| 244 | + |
| 245 | + // ----------------------- |
| 246 | + // 原始常量(完全保留) |
| 247 | + // ----------------------- |
| 248 | + |
| 249 | + static inline constexpr auto kS4 = ScalarType::int_(4); |
| 250 | + static inline constexpr auto kU4 = ScalarType::uint(4); |
| 251 | + static inline constexpr auto kU4B8 = ScalarType::uint(4, 8); |
| 252 | + static inline constexpr auto kS8 = ScalarType::int_(8); |
| 253 | + static inline constexpr auto kU8 = ScalarType::uint(8); |
| 254 | + static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); |
| 255 | + |
| 256 | + static inline constexpr auto kFE2M1f = |
| 257 | + ScalarType::float_(2, 1, true, ScalarType::NAN_NONE); |
| 258 | + static inline constexpr auto kFE3M2f = |
| 259 | + ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); |
| 260 | + static inline constexpr auto kFE4M3fn = |
| 261 | + ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); |
| 262 | + static inline constexpr auto kFE8M0fnu = |
| 263 | + ScalarType(8, 0, false, 0, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); |
| 264 | + static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); |
| 265 | + static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); |
| 266 | + static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); |
| 267 | + |
| 268 | + // 🔥 关键:alias(不能丢!) |
| 269 | + |
| 270 | + static inline constexpr auto kInt4 = kS4; |
| 271 | + static inline constexpr auto kUint4 = kU4; |
| 272 | + static inline constexpr auto kUint4b8 = kU4B8; |
| 273 | + static inline constexpr auto kInt8 = kS8; |
| 274 | + static inline constexpr auto kUint8 = kU8; |
| 275 | + static inline constexpr auto kUint8b128 = kU8B128; |
| 276 | + |
| 277 | + static inline constexpr auto kFloat4_e2m1f = kFE2M1f; |
| 278 | + static inline constexpr auto kFloat6_e3m2f = kFE3M2f; |
| 279 | + static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; |
| 280 | + static inline constexpr auto kFloat8_e5m2 = kFE5M2; |
| 281 | + static inline constexpr auto kFloat16_e8m7 = kFE8M7; |
| 282 | + static inline constexpr auto kFloat16_e5m10 = kFE5M10; |
| 283 | + |
| 284 | + // ⭐ 这些就是你报错缺失的 |
| 285 | + static inline constexpr auto kHalf = kFE5M10; |
| 286 | + static inline constexpr auto kFloat16 = kHalf; |
| 287 | + static inline constexpr auto kBFloat16 = kFE8M7; |
| 288 | + |
| 289 | + static inline constexpr auto kFloat16Id = kFloat16.id(); |
| 290 | + |
| 291 | +} // namespace vllm |
0 commit comments