diff --git a/paddle/phi/api/include/compat/c10/core/ScalarType.h b/paddle/phi/api/include/compat/c10/core/ScalarType.h index 97267e23089a4d..8d12bc13ab1ac1 100644 --- a/paddle/phi/api/include/compat/c10/core/ScalarType.h +++ b/paddle/phi/api/include/compat/c10/core/ScalarType.h @@ -80,6 +80,33 @@ inline bool isComplexType(ScalarType t) { t == ScalarType::ComplexDouble); } +inline bool isBitsType(ScalarType t) { + return t == ScalarType::Bits1x8 || t == ScalarType::Bits2x4 || + t == ScalarType::Bits4x2 || t == ScalarType::Bits8 || + t == ScalarType::Bits16; +} + +inline bool isBarebonesUnsignedType(ScalarType t) { + return t == ScalarType::UInt1 || t == ScalarType::UInt2 || + t == ScalarType::UInt3 || t == ScalarType::UInt4 || + t == ScalarType::UInt5 || t == ScalarType::UInt6 || + t == ScalarType::UInt7 || t == ScalarType::UInt16 || + t == ScalarType::UInt32 || t == ScalarType::UInt64; +} + +inline ScalarType toQIntType(ScalarType t) { + switch (t) { + case ScalarType::Byte: + return ScalarType::QUInt8; + case ScalarType::Char: + return ScalarType::QInt8; + case ScalarType::Int: + return ScalarType::QInt32; + default: + return t; + } +} + inline bool isSignedType(ScalarType t) { #define CASE_ISSIGNED(name) \ case ScalarType::name: \ @@ -177,6 +204,57 @@ inline bool isSignedType(ScalarType t) { return false; // Unreachable, but satisfies compiler } +inline bool isUnderlying(ScalarType type, ScalarType qtype) { + return type == toUnderlying(qtype); +} + +inline ScalarType toRealValueType(ScalarType t) { + switch (t) { + case ScalarType::ComplexHalf: + return ScalarType::Half; + case ScalarType::ComplexFloat: + return ScalarType::Float; + case ScalarType::ComplexDouble: + return ScalarType::Double; + default: + return t; + } +} + +inline ScalarType toComplexType(ScalarType t) { + switch (t) { + case ScalarType::BFloat16: + return ScalarType::ComplexFloat; + case ScalarType::Half: + return ScalarType::ComplexHalf; + case ScalarType::Float: + return ScalarType::ComplexFloat; + case ScalarType::Double: + return ScalarType::ComplexDouble; + case ScalarType::ComplexHalf: + return ScalarType::ComplexHalf; + case ScalarType::ComplexFloat: + return ScalarType::ComplexFloat; + case ScalarType::ComplexDouble: + return ScalarType::ComplexDouble; + default: + TORCH_CHECK(false, "Unknown Complex ScalarType for ", t); + } +} + +inline bool canCast(const ScalarType from, const ScalarType to) { + if (isComplexType(from) && !isComplexType(to)) { + return false; + } + if (isFloatingType(from) && isIntegralType(to, false)) { + return false; + } + if (from != ScalarType::Bool && to == ScalarType::Bool) { + return false; + } + return true; +} + } // namespace c10 namespace at { diff --git a/paddle/phi/api/include/compat/torch/headeronly/core/ScalarType.h b/paddle/phi/api/include/compat/torch/headeronly/core/ScalarType.h index b51842d83fb963..7e6fd9670e1ec0 100644 --- a/paddle/phi/api/include/compat/torch/headeronly/core/ScalarType.h +++ b/paddle/phi/api/include/compat/torch/headeronly/core/ScalarType.h @@ -313,4 +313,25 @@ inline std::ostream& operator<<(std::ostream& stream, ScalarType scalar_type) { return stream << toString(scalar_type); } +inline bool isQIntType(ScalarType t) { + return t == ScalarType::QInt8 || t == ScalarType::QUInt8 || + t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 || + t == ScalarType::QUInt2x4; +} + +inline ScalarType toUnderlying(ScalarType t) { + switch (t) { + case ScalarType::QUInt8: + case ScalarType::QUInt4x2: + case ScalarType::QUInt2x4: + return ScalarType::Byte; + case ScalarType::QInt8: + return ScalarType::Char; + case ScalarType::QInt32: + return ScalarType::Int; + default: + return t; + } +} + } // namespace c10 diff --git a/test/cpp/compat/c10_ScalarType_test.cc b/test/cpp/compat/c10_ScalarType_test.cc index 6a3bbc9b77fff9..d85326b7726a57 100644 --- a/test/cpp/compat/c10_ScalarType_test.cc +++ b/test/cpp/compat/c10_ScalarType_test.cc @@ -151,3 +151,78 @@ TEST(ScalarTypeTest, RestoredCompatScalarTypesKeepSourceLevelSemantics) { EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e8m0fnu)); EXPECT_TRUE(c10::isReducedFloatingType(c10::ScalarType::Float4_e2m1fn_x2)); } + +TEST(ScalarTypeTest, HelperPredicatesAndConversionsMatchPyTorchBehavior) { + EXPECT_TRUE(c10::isQIntType(c10::ScalarType::QInt8)); + EXPECT_TRUE(c10::isQIntType(c10::ScalarType::QUInt8)); + EXPECT_TRUE(c10::isQIntType(c10::ScalarType::QInt32)); + EXPECT_TRUE(c10::isQIntType(c10::ScalarType::QUInt4x2)); + EXPECT_TRUE(c10::isQIntType(c10::ScalarType::QUInt2x4)); + EXPECT_FALSE(c10::isQIntType(c10::ScalarType::Float)); + + EXPECT_TRUE(c10::isBitsType(c10::ScalarType::Bits1x8)); + EXPECT_TRUE(c10::isBitsType(c10::ScalarType::Bits2x4)); + EXPECT_TRUE(c10::isBitsType(c10::ScalarType::Bits4x2)); + EXPECT_TRUE(c10::isBitsType(c10::ScalarType::Bits8)); + EXPECT_TRUE(c10::isBitsType(c10::ScalarType::Bits16)); + EXPECT_FALSE(c10::isBitsType(c10::ScalarType::Int)); + + EXPECT_TRUE(c10::isBarebonesUnsignedType(c10::ScalarType::UInt1)); + EXPECT_TRUE(c10::isBarebonesUnsignedType(c10::ScalarType::UInt7)); + EXPECT_TRUE(c10::isBarebonesUnsignedType(c10::ScalarType::UInt16)); + EXPECT_TRUE(c10::isBarebonesUnsignedType(c10::ScalarType::UInt64)); + EXPECT_FALSE(c10::isBarebonesUnsignedType(c10::ScalarType::Byte)); + EXPECT_FALSE(c10::isBarebonesUnsignedType(c10::ScalarType::Int)); + + EXPECT_EQ(c10::toQIntType(c10::ScalarType::Byte), c10::ScalarType::QUInt8); + EXPECT_EQ(c10::toQIntType(c10::ScalarType::Char), c10::ScalarType::QInt8); + EXPECT_EQ(c10::toQIntType(c10::ScalarType::Int), c10::ScalarType::QInt32); + EXPECT_EQ(c10::toQIntType(c10::ScalarType::Float), c10::ScalarType::Float); + + EXPECT_EQ(c10::toUnderlying(c10::ScalarType::QUInt8), c10::ScalarType::Byte); + EXPECT_EQ(c10::toUnderlying(c10::ScalarType::QUInt4x2), + c10::ScalarType::Byte); + EXPECT_EQ(c10::toUnderlying(c10::ScalarType::QUInt2x4), + c10::ScalarType::Byte); + EXPECT_EQ(c10::toUnderlying(c10::ScalarType::QInt8), c10::ScalarType::Char); + EXPECT_EQ(c10::toUnderlying(c10::ScalarType::QInt32), c10::ScalarType::Int); + EXPECT_EQ(c10::toUnderlying(c10::ScalarType::Float), c10::ScalarType::Float); + + EXPECT_TRUE( + c10::isUnderlying(c10::ScalarType::Byte, c10::ScalarType::QUInt8)); + EXPECT_TRUE(c10::isUnderlying(c10::ScalarType::Char, c10::ScalarType::QInt8)); + EXPECT_TRUE(c10::isUnderlying(c10::ScalarType::Int, c10::ScalarType::QInt32)); + EXPECT_FALSE( + c10::isUnderlying(c10::ScalarType::Byte, c10::ScalarType::QInt8)); + + EXPECT_EQ(c10::toRealValueType(c10::ScalarType::ComplexHalf), + c10::ScalarType::Half); + EXPECT_EQ(c10::toRealValueType(c10::ScalarType::ComplexFloat), + c10::ScalarType::Float); + EXPECT_EQ(c10::toRealValueType(c10::ScalarType::ComplexDouble), + c10::ScalarType::Double); + EXPECT_EQ(c10::toRealValueType(c10::ScalarType::Int), c10::ScalarType::Int); + + EXPECT_EQ(c10::toComplexType(c10::ScalarType::Half), + c10::ScalarType::ComplexHalf); + EXPECT_EQ(c10::toComplexType(c10::ScalarType::Float), + c10::ScalarType::ComplexFloat); + EXPECT_EQ(c10::toComplexType(c10::ScalarType::Double), + c10::ScalarType::ComplexDouble); + EXPECT_EQ(c10::toComplexType(c10::ScalarType::BFloat16), + c10::ScalarType::ComplexFloat); + EXPECT_EQ(c10::toComplexType(c10::ScalarType::ComplexFloat), + c10::ScalarType::ComplexFloat); + + EXPECT_TRUE(c10::canCast(c10::ScalarType::Int, c10::ScalarType::Long)); + EXPECT_TRUE(c10::canCast(c10::ScalarType::Float, c10::ScalarType::Double)); + EXPECT_TRUE(c10::canCast(c10::ScalarType::ComplexFloat, + c10::ScalarType::ComplexDouble)); + EXPECT_TRUE(c10::canCast(c10::ScalarType::Bool, c10::ScalarType::Int)); + + EXPECT_FALSE( + c10::canCast(c10::ScalarType::ComplexFloat, c10::ScalarType::Float)); + EXPECT_FALSE(c10::canCast(c10::ScalarType::Float, c10::ScalarType::Int)); + EXPECT_FALSE(c10::canCast(c10::ScalarType::Double, c10::ScalarType::Long)); + EXPECT_FALSE(c10::canCast(c10::ScalarType::Int, c10::ScalarType::Bool)); +}