|
4 | 4 | #include <cmath> |
5 | 5 | #include <cstdint> |
6 | 6 | #include <string> |
7 | | -#include <type_traits> |
8 | 7 | #include <unordered_map> |
9 | 8 |
|
10 | 9 | namespace infini_train { |
@@ -303,110 +302,9 @@ template <> struct DataTypeMap<BF16> { |
303 | 302 | static constexpr DataType value = DataType::kBFLOAT16; |
304 | 303 | }; |
305 | 304 |
|
306 | | -// ----------------------------------------------------------------------------- |
307 | | -// Type-trait extensions: std::is_floating_point / std::is_arithmetic that also report true for BF16 and FP16 (framework fallback scalar semantics) |
308 | | -// ----------------------------------------------------------------------------- |
309 | | -template <typename T> struct is_floating_point_ext : std::is_floating_point<T> {}; |
310 | | - |
311 | | -template <typename T> struct is_arithmetic_ext : std::is_arithmetic<T> {}; |
312 | | - |
313 | | -template <> struct is_floating_point_ext<BF16> : std::true_type {}; |
314 | | -template <> struct is_arithmetic_ext<BF16> : std::true_type {}; |
315 | | - |
316 | | -template <> struct is_floating_point_ext<FP16> : std::true_type {}; |
317 | | -template <> struct is_arithmetic_ext<FP16> : std::true_type {}; |
318 | | - |
319 | | -// ----------------------------------------------------------------------------- |
320 | | -// Promotion helpers: compile-time building blocks consumed by the framework-level WidestType |
321 | | -// ----------------------------------------------------------------------------- |
322 | | -namespace detail { |
323 | | - |
324 | | -template <typename T1, typename T2> struct LargerType { |
325 | | - static constexpr size_t size1 = sizeof(T1); |
326 | | - static constexpr size_t size2 = sizeof(T2); |
327 | | - using type = std::conditional_t<(size1 >= size2), T1, T2>; |
328 | | -}; |
329 | | - |
330 | | -template <> struct LargerType<BF16, FP16> { |
331 | | - using type = float; |
332 | | -}; |
333 | | - |
334 | | -template <> struct LargerType<FP16, BF16> { |
335 | | - using type = float; |
336 | | -}; |
337 | | - |
338 | | -/** |
339 | | - * @brief Selects the first type in the parameter pack for which Predicate<T>::value is true. |
340 | | - * If no type matches, the result is the last type in the pack (single-type base case), which is returned without checking the predicate. |
341 | | - */ |
342 | | -template <template <typename> class Predicate, typename... Ts> struct FirstMatchingType; |
343 | | - |
344 | | -template <template <typename> class Predicate, typename T> struct FirstMatchingType<Predicate, T> { |
345 | | - using type = T; |
346 | | -}; |
347 | | - |
348 | | -template <template <typename> class Predicate, typename T, typename... Ts> |
349 | | -struct FirstMatchingType<Predicate, T, Ts...> { |
350 | | - using type = std::conditional_t<Predicate<T>::value, T, typename FirstMatchingType<Predicate, Ts...>::type>; |
351 | | -}; |
352 | | - |
353 | | -/** |
354 | | - * @brief Recursively folds the pack into CurrentMax: every T that satisfies the predicate is combined with CurrentMax via LargerType (sizeof comparison, ties keep CurrentMax, a BF16/FP16 pair yields float). |
355 | | - * Types that do not satisfy the predicate are skipped and leave CurrentMax unchanged. |
356 | | - */ |
357 | | -template <template <typename> class Predicate, typename CurrentMax, typename... Ts> struct WidestTypeImpl; |
358 | | - |
359 | | -template <template <typename> class Predicate, typename CurrentMax> struct WidestTypeImpl<Predicate, CurrentMax> { |
360 | | - using type = CurrentMax; |
361 | | -}; |
362 | | - |
363 | | -template <template <typename> class Predicate, typename CurrentMax, typename T, typename... Ts> |
364 | | -struct WidestTypeImpl<Predicate, CurrentMax, T, Ts...> { |
365 | | - using new_max = std::conditional_t<Predicate<T>::value, typename LargerType<CurrentMax, T>::type, CurrentMax>; |
366 | | - using type = typename WidestTypeImpl<Predicate, new_max, Ts...>::type; |
367 | | -}; |
368 | | - |
369 | | -template <template <typename> class Predicate, typename... Ts> struct MaxTypeBySizeWithPredicate { |
370 | | - using first = typename FirstMatchingType<Predicate, Ts...>::type; |
371 | | - using type = typename WidestTypeImpl<Predicate, first, Ts...>::type; |
372 | | -}; |
373 | | - |
374 | | -} // namespace detail |
375 | | - |
376 | | -/** |
377 | | - * @brief Finds the widest/largest type according to a PyTorch-like dtype promotion rule among a pack of arithmetic |
378 | | - * types. |
379 | | - * |
380 | | - * - If any floating-point type is present, selects the largest (by sizeof) floating-point type; integral types are then ignored. |
381 | | - * - Otherwise selects the largest (by sizeof) integral type. |
382 | | - * - If several candidate types have the same size, the one listed first wins (except the FP16/BF16 pair, see Note). |
383 | | - * |
384 | | - * Note: |
385 | | - * - FP16/BF16 are treated as floating-point. |
386 | | - * - Mixed FP16 and BF16 promotes to float (32-bit). |
387 | | - */ |
388 | | -template <typename... Ts> struct WidestType { |
389 | | - static_assert(sizeof...(Ts) > 0, "At least one type is required"); |
390 | | - static_assert((is_arithmetic_ext<Ts>::value && ...), |
391 | | - "All types must be arithmetic or framework floating-point types (FP16/BF16)"); |
392 | | - |
393 | | - static constexpr bool has_float = (is_floating_point_ext<Ts>::value || ...); |
394 | | - |
395 | | - using type = |
396 | | - typename std::conditional_t<has_float, detail::MaxTypeBySizeWithPredicate<is_floating_point_ext, Ts...>, |
397 | | - detail::MaxTypeBySizeWithPredicate<std::is_integral, Ts...>>::type; |
398 | | -}; |
399 | | - |
400 | | -// Convenience alias: WidestType_t<Ts...> is shorthand for typename WidestType<Ts...>::type |
401 | | -template <typename... Ts> using WidestType_t = typename WidestType<Ts...>::type; |
402 | | - |
403 | 305 | // ============================================================================= |
404 | 306 | // DataType-level promotion (pure enum → enum, no concrete/backend types) |
405 | 307 | // ============================================================================= |
406 | | -// These facilities replace `DataTypeMap_v<WidestType_t<Ta, Tb>>` in CUDA |
407 | | -// kernels, so that backend kernels never need to know about __half / |
408 | | -// __nv_bfloat16 at promotion time. |
409 | | -// |
410 | 308 | // Rules (priority order): |
411 | 309 | // 1. FP16 + BF16 → FLOAT32 (neither is a lossless superset of the other) |
412 | 310 | // 2. Any float dominates any integer → keep the float type |
|
0 commit comments