Skip to content

Commit e63cb38

Browse files
committed
refactor: remove old promotion codes implemented by WidestType_t
1 parent e791b1d commit e63cb38

File tree

1 file changed

+0
-102
lines changed

1 file changed

+0
-102
lines changed

infini_train/include/datatype.h

Lines changed: 0 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include <cmath>
55
#include <cstdint>
66
#include <string>
7-
#include <type_traits>
87
#include <unordered_map>
98

109
namespace infini_train {
@@ -303,110 +302,9 @@ template <> struct DataTypeMap<BF16> {
303302
static constexpr DataType value = DataType::kBFLOAT16;
304303
};
305304

306-
// -----------------------------------------------------------------------------
307-
// Type traits extensions (framework fallback scalar semantics)
308-
// -----------------------------------------------------------------------------
309-
template <typename T> struct is_floating_point_ext : std::is_floating_point<T> {};
310-
311-
template <typename T> struct is_arithmetic_ext : std::is_arithmetic<T> {};
312-
313-
template <> struct is_floating_point_ext<BF16> : std::true_type {};
314-
template <> struct is_arithmetic_ext<BF16> : std::true_type {};
315-
316-
template <> struct is_floating_point_ext<FP16> : std::true_type {};
317-
template <> struct is_arithmetic_ext<FP16> : std::true_type {};
318-
319-
// -----------------------------------------------------------------------------
320-
// Promotion helpers (framework-level WidestType)
321-
// -----------------------------------------------------------------------------
322-
namespace detail {
323-
324-
template <typename T1, typename T2> struct LargerType {
325-
static constexpr size_t size1 = sizeof(T1);
326-
static constexpr size_t size2 = sizeof(T2);
327-
using type = std::conditional_t<(size1 >= size2), T1, T2>;
328-
};
329-
330-
template <> struct LargerType<BF16, FP16> {
331-
using type = float;
332-
};
333-
334-
template <> struct LargerType<FP16, BF16> {
335-
using type = float;
336-
};
337-
338-
/**
339-
* @brief Finds the first type in a parameter pack that satisfies the given predicate.
340-
* If no type matches, returns the last type in the pack (base case).
341-
*/
342-
template <template <typename> class Predicate, typename... Ts> struct FirstMatchingType;
343-
344-
template <template <typename> class Predicate, typename T> struct FirstMatchingType<Predicate, T> {
345-
using type = T;
346-
};
347-
348-
template <template <typename> class Predicate, typename T, typename... Ts>
349-
struct FirstMatchingType<Predicate, T, Ts...> {
350-
using type = std::conditional_t<Predicate<T>::value, T, typename FirstMatchingType<Predicate, Ts...>::type>;
351-
};
352-
353-
/**
354-
* @brief Recursively finds the widest type among those that satisfy a predicate.
355-
* Types not satisfying the predicate are ignored and don't affect the current maximum.
356-
*/
357-
template <template <typename> class Predicate, typename CurrentMax, typename... Ts> struct WidestTypeImpl;
358-
359-
template <template <typename> class Predicate, typename CurrentMax> struct WidestTypeImpl<Predicate, CurrentMax> {
360-
using type = CurrentMax;
361-
};
362-
363-
template <template <typename> class Predicate, typename CurrentMax, typename T, typename... Ts>
364-
struct WidestTypeImpl<Predicate, CurrentMax, T, Ts...> {
365-
using new_max = std::conditional_t<Predicate<T>::value, typename LargerType<CurrentMax, T>::type, CurrentMax>;
366-
using type = typename WidestTypeImpl<Predicate, new_max, Ts...>::type;
367-
};
368-
369-
template <template <typename> class Predicate, typename... Ts> struct MaxTypeBySizeWithPredicate {
370-
using first = typename FirstMatchingType<Predicate, Ts...>::type;
371-
using type = typename WidestTypeImpl<Predicate, first, Ts...>::type;
372-
};
373-
374-
} // namespace detail
375-
376-
/**
377-
* @brief Finds the widest/largest type according to a PyTorch-like dtype promotion rule among a pack of arithmetic
378-
* types.
379-
*
380-
* - If floating-point types are present, selects the largest floating-point type;
381-
* - Otherwise selects the largest integral type.
382-
* - If multiple integral types have the same size, precedence follows the list order.
383-
*
384-
* Note:
385-
* - FP16/BF16 are treated as floating-point.
386-
* - Mixed FP16 and BF16 promotes to float (32-bit).
387-
*/
388-
template <typename... Ts> struct WidestType {
389-
static_assert(sizeof...(Ts) > 0, "At least one type is required");
390-
static_assert((is_arithmetic_ext<Ts>::value && ...),
391-
"All types must be arithmetic or framework floating-point types (FP16/BF16)");
392-
393-
static constexpr bool has_float = (is_floating_point_ext<Ts>::value || ...);
394-
395-
using type =
396-
typename std::conditional_t<has_float, detail::MaxTypeBySizeWithPredicate<is_floating_point_ext, Ts...>,
397-
detail::MaxTypeBySizeWithPredicate<std::is_integral, Ts...>>::type;
398-
};
399-
400-
// Convenience alias
401-
template <typename... Ts> using WidestType_t = typename WidestType<Ts...>::type;
402-
403305
// =============================================================================
404306
// DataType-level promotion (pure enum → enum, no concrete/backend types)
405307
// =============================================================================
406-
// These facilities replace `DataTypeMap_v<WidestType_t<Ta, Tb>>` in CUDA
407-
// kernels, so that backend kernels never need to know about __half /
408-
// __nv_bfloat16 at promotion time.
409-
//
410308
// Rules (priority order):
411309
// 1. FP16 + BF16 → FLOAT32 (neither is a lossless superset of the other)
412310
// 2. Any float dominates any integer → keep the float type

0 commit comments

Comments
 (0)