From 10c37488d26135e027502a686c76b55bc594df58 Mon Sep 17 00:00:00 2001 From: youge325 Date: Fri, 17 Apr 2026 23:08:04 +0800 Subject: [PATCH 1/3] align ScalarType related APIs --- .../api/include/compat/c10/core/ScalarType.h | 78 +++++++++++++++++++ .../compat/torch/headeronly/core/ScalarType.h | 21 +++++ test/cpp/compat/c10_ScalarType_test.cc | 75 ++++++++++++++++++ 3 files changed, 174 insertions(+) diff --git a/paddle/phi/api/include/compat/c10/core/ScalarType.h b/paddle/phi/api/include/compat/c10/core/ScalarType.h index 97267e23089a4d..8d12bc13ab1ac1 100644 --- a/paddle/phi/api/include/compat/c10/core/ScalarType.h +++ b/paddle/phi/api/include/compat/c10/core/ScalarType.h @@ -80,6 +80,33 @@ inline bool isComplexType(ScalarType t) { t == ScalarType::ComplexDouble); } +inline bool isBitsType(ScalarType t) { + return t == ScalarType::Bits1x8 || t == ScalarType::Bits2x4 || + t == ScalarType::Bits4x2 || t == ScalarType::Bits8 || + t == ScalarType::Bits16; +} + +inline bool isBarebonesUnsignedType(ScalarType t) { + return t == ScalarType::UInt1 || t == ScalarType::UInt2 || + t == ScalarType::UInt3 || t == ScalarType::UInt4 || + t == ScalarType::UInt5 || t == ScalarType::UInt6 || + t == ScalarType::UInt7 || t == ScalarType::UInt16 || + t == ScalarType::UInt32 || t == ScalarType::UInt64; +} + +inline ScalarType toQIntType(ScalarType t) { + switch (t) { + case ScalarType::Byte: + return ScalarType::QUInt8; + case ScalarType::Char: + return ScalarType::QInt8; + case ScalarType::Int: + return ScalarType::QInt32; + default: + return t; + } +} + inline bool isSignedType(ScalarType t) { #define CASE_ISSIGNED(name) \ case ScalarType::name: \ @@ -177,6 +204,57 @@ inline bool isSignedType(ScalarType t) { return false; // Unreachable, but satisfies compiler } +inline bool isUnderlying(ScalarType type, ScalarType qtype) { + return type == toUnderlying(qtype); +} + +inline ScalarType toRealValueType(ScalarType t) { + switch (t) { + case ScalarType::ComplexHalf: + return ScalarType::Half; + case ScalarType::ComplexFloat: + return ScalarType::Float; + case ScalarType::ComplexDouble: + return ScalarType::Double; + default: + return t; + } +} + +inline ScalarType toComplexType(ScalarType t) { + switch (t) { + case ScalarType::BFloat16: + return ScalarType::ComplexFloat; + case ScalarType::Half: + return ScalarType::ComplexHalf; + case ScalarType::Float: + return ScalarType::ComplexFloat; + case ScalarType::Double: + return ScalarType::ComplexDouble; + case ScalarType::ComplexHalf: + return ScalarType::ComplexHalf; + case ScalarType::ComplexFloat: + return ScalarType::ComplexFloat; + case ScalarType::ComplexDouble: + return ScalarType::ComplexDouble; + default: + TORCH_CHECK(false, "Unknown Complex ScalarType for ", t); + } +} + +inline bool canCast(const ScalarType from, const ScalarType to) { + if (isComplexType(from) && !isComplexType(to)) { + return false; + } + if (isFloatingType(from) && isIntegralType(to, false)) { + return false; + } + if (from != ScalarType::Bool && to == ScalarType::Bool) { + return false; + } + return true; +} + } // namespace c10 namespace at { diff --git a/paddle/phi/api/include/compat/torch/headeronly/core/ScalarType.h b/paddle/phi/api/include/compat/torch/headeronly/core/ScalarType.h index b51842d83fb963..7e6fd9670e1ec0 100644 --- a/paddle/phi/api/include/compat/torch/headeronly/core/ScalarType.h +++ b/paddle/phi/api/include/compat/torch/headeronly/core/ScalarType.h @@ -313,4 +313,25 @@ inline std::ostream& operator<<(std::ostream& stream, ScalarType scalar_type) { return stream << toString(scalar_type); } +inline bool isQIntType(ScalarType t) { + return t == ScalarType::QInt8 || t == ScalarType::QUInt8 || + t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 || + t == ScalarType::QUInt2x4; +} + +inline ScalarType toUnderlying(ScalarType t) { + switch (t) { + case ScalarType::QUInt8: + case ScalarType::QUInt4x2: + case ScalarType::QUInt2x4: + return ScalarType::Byte; + case ScalarType::QInt8: + return ScalarType::Char; + case ScalarType::QInt32: + return ScalarType::Int; + default: + return t; + } +} + } // namespace c10 diff --git a/test/cpp/compat/c10_ScalarType_test.cc b/test/cpp/compat/c10_ScalarType_test.cc index 6a3bbc9b77fff9..d85326b7726a57 100644 --- a/test/cpp/compat/c10_ScalarType_test.cc +++ b/test/cpp/compat/c10_ScalarType_test.cc @@ -151,3 +151,78 @@ TEST(ScalarTypeTest, RestoredCompatScalarTypesKeepSourceLevelSemantics) { EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e8m0fnu)); EXPECT_TRUE(c10::isReducedFloatingType(c10::ScalarType::Float4_e2m1fn_x2)); } + +TEST(ScalarTypeTest, HelperPredicatesAndConversionsMatchPyTorchBehavior) { + EXPECT_TRUE(c10::isQIntType(c10::ScalarType::QInt8)); + EXPECT_TRUE(c10::isQIntType(c10::ScalarType::QUInt8)); + EXPECT_TRUE(c10::isQIntType(c10::ScalarType::QInt32)); + EXPECT_TRUE(c10::isQIntType(c10::ScalarType::QUInt4x2)); + EXPECT_TRUE(c10::isQIntType(c10::ScalarType::QUInt2x4)); + EXPECT_FALSE(c10::isQIntType(c10::ScalarType::Float)); + + EXPECT_TRUE(c10::isBitsType(c10::ScalarType::Bits1x8)); + EXPECT_TRUE(c10::isBitsType(c10::ScalarType::Bits2x4)); + EXPECT_TRUE(c10::isBitsType(c10::ScalarType::Bits4x2)); + EXPECT_TRUE(c10::isBitsType(c10::ScalarType::Bits8)); + EXPECT_TRUE(c10::isBitsType(c10::ScalarType::Bits16)); + EXPECT_FALSE(c10::isBitsType(c10::ScalarType::Int)); + + EXPECT_TRUE(c10::isBarebonesUnsignedType(c10::ScalarType::UInt1)); + EXPECT_TRUE(c10::isBarebonesUnsignedType(c10::ScalarType::UInt7)); + EXPECT_TRUE(c10::isBarebonesUnsignedType(c10::ScalarType::UInt16)); + EXPECT_TRUE(c10::isBarebonesUnsignedType(c10::ScalarType::UInt64)); + EXPECT_FALSE(c10::isBarebonesUnsignedType(c10::ScalarType::Byte)); + EXPECT_FALSE(c10::isBarebonesUnsignedType(c10::ScalarType::Int)); + + EXPECT_EQ(c10::toQIntType(c10::ScalarType::Byte), c10::ScalarType::QUInt8); + EXPECT_EQ(c10::toQIntType(c10::ScalarType::Char), c10::ScalarType::QInt8); + EXPECT_EQ(c10::toQIntType(c10::ScalarType::Int), c10::ScalarType::QInt32); + EXPECT_EQ(c10::toQIntType(c10::ScalarType::Float), c10::ScalarType::Float); + + EXPECT_EQ(c10::toUnderlying(c10::ScalarType::QUInt8), c10::ScalarType::Byte); + EXPECT_EQ(c10::toUnderlying(c10::ScalarType::QUInt4x2), + c10::ScalarType::Byte); + EXPECT_EQ(c10::toUnderlying(c10::ScalarType::QUInt2x4), + c10::ScalarType::Byte); + EXPECT_EQ(c10::toUnderlying(c10::ScalarType::QInt8), c10::ScalarType::Char); + EXPECT_EQ(c10::toUnderlying(c10::ScalarType::QInt32), c10::ScalarType::Int); + EXPECT_EQ(c10::toUnderlying(c10::ScalarType::Float), c10::ScalarType::Float); + + EXPECT_TRUE( + c10::isUnderlying(c10::ScalarType::Byte, c10::ScalarType::QUInt8)); + EXPECT_TRUE(c10::isUnderlying(c10::ScalarType::Char, c10::ScalarType::QInt8)); + EXPECT_TRUE(c10::isUnderlying(c10::ScalarType::Int, c10::ScalarType::QInt32)); + EXPECT_FALSE( + c10::isUnderlying(c10::ScalarType::Byte, c10::ScalarType::QInt8)); + + EXPECT_EQ(c10::toRealValueType(c10::ScalarType::ComplexHalf), + c10::ScalarType::Half); + EXPECT_EQ(c10::toRealValueType(c10::ScalarType::ComplexFloat), + c10::ScalarType::Float); + EXPECT_EQ(c10::toRealValueType(c10::ScalarType::ComplexDouble), + c10::ScalarType::Double); + EXPECT_EQ(c10::toRealValueType(c10::ScalarType::Int), c10::ScalarType::Int); + + EXPECT_EQ(c10::toComplexType(c10::ScalarType::Half), + c10::ScalarType::ComplexHalf); + EXPECT_EQ(c10::toComplexType(c10::ScalarType::Float), + c10::ScalarType::ComplexFloat); + EXPECT_EQ(c10::toComplexType(c10::ScalarType::Double), + c10::ScalarType::ComplexDouble); + EXPECT_EQ(c10::toComplexType(c10::ScalarType::BFloat16), + c10::ScalarType::ComplexFloat); + EXPECT_EQ(c10::toComplexType(c10::ScalarType::ComplexFloat), + c10::ScalarType::ComplexFloat); + + EXPECT_TRUE(c10::canCast(c10::ScalarType::Int, c10::ScalarType::Long)); + EXPECT_TRUE(c10::canCast(c10::ScalarType::Float, c10::ScalarType::Double)); + EXPECT_TRUE(c10::canCast(c10::ScalarType::ComplexFloat, + c10::ScalarType::ComplexDouble)); + EXPECT_TRUE(c10::canCast(c10::ScalarType::Bool, c10::ScalarType::Int)); + + EXPECT_FALSE( + c10::canCast(c10::ScalarType::ComplexFloat, c10::ScalarType::Float)); + EXPECT_FALSE(c10::canCast(c10::ScalarType::Float, c10::ScalarType::Int)); + EXPECT_FALSE(c10::canCast(c10::ScalarType::Double, c10::ScalarType::Long)); + EXPECT_FALSE(c10::canCast(c10::ScalarType::Int, c10::ScalarType::Bool)); +} From 586c99ebd628b7a3c1f134b19f79474e2ad0c1ea Mon Sep 17 00:00:00 2001 From: youge325 Date: Sat, 25 Apr 2026 17:20:49 +0800 Subject: [PATCH 2/3] Add ABI symbol compatibility check --- ci/static_check.sh | 17 ++ tools/check_abi_compatibility.py | 374 ++++++++++++++++++++++++++ tools/test_check_abi_compatibility.py | 188 +++++++++++++ 3 files changed, 579 insertions(+) create mode 100644 tools/check_abi_compatibility.py create mode 100644 tools/test_check_abi_compatibility.py diff --git a/ci/static_check.sh b/ci/static_check.sh index 9682a6ae48da47..98cb60220995c5 100644 --- a/ci/static_check.sh +++ b/ci/static_check.sh @@ -149,6 +149,21 @@ function exec_samplecode_checking() { fi } +function exec_abi_compatibility_check() { + if [ "$(uname -s)" != "Linux" ]; then + echo "Skip ABI compatibility check on non-Linux platform." + return + fi + + python ${PADDLE_ROOT}/tools/check_abi_compatibility.py \ + --base-wheel "${PADDLE_ROOT}/build/dev_whl/*.whl" \ + --pr-wheel "${PADDLE_ROOT}/build/pr_whl/*.whl" + abi_check_error=$? + if [ "$abi_check_error" != "0" ]; then + exit $abi_check_error + fi +} + export PATH=/usr/local/python3.10.0/bin:/usr/local/python3.10.0/include:/usr/local/bin:${PATH} echo "export PATH=${PATH}" >> ~/.bashrc export LD_LIBRARY_PATH=/usr/local/cuda-11.8/compat:$LD_LIBRARY_PATH @@ -158,6 +173,8 @@ ln -sf $(which python${PY_VERSION}) /usr/bin/python ln -sf $(which pip${PY_VERSION}) /usr/local/bin/pip mkdir -p /home/data/cfs/.ccache/static-check +exec_abi_compatibility_check + pip config set global.cache-dir "/home/data/cfs/.cache/pip" pip install --upgrade pip 1>nul pip install -r "${work_dir}/python/requirements.txt" 1>nul diff --git a/tools/check_abi_compatibility.py b/tools/check_abi_compatibility.py new file mode 100644 index 00000000000000..d7bc3497595590 --- /dev/null +++ b/tools/check_abi_compatibility.py @@ -0,0 +1,374 @@ +#!/usr/bin/env python + +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Check Linux wheel ABI compatibility by comparing protected ELF symbols. + +The check is intentionally one-way: symbols added by a PR are allowed, while +protected symbols present in the base wheel must still exist in the PR wheel. +""" + +from __future__ import annotations + +import argparse +import glob +import os +import shutil +import subprocess +import sys +import tempfile +import zipfile +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Iterable + +WHEEL_LIBRARY_PATHS = ( + "paddle/base/libpaddle.so", + "paddle/libs/libphi.so", + "paddle/libs/libphi_core.so", + "paddle/libs/libphi_gpu.so", +) + +DEFINED_DYNAMIC_SYMBOL_TYPES = {"FUNC", "OBJECT"} + +PROTECTED_CXX_PREFIXES = ( + "phi::", + "paddle::", + "c10::", + "at::", + "torch::", +) + +PROTECTED_C_SYMBOL_PREFIXES = ( + "PD_", + "Paddle", + "PyInit_", + "paddle_", +) + +PROTECTED_MANGLED_CXX_PREFIXES = ( + "_ZN2at", + "_ZN3c10", + "_ZN3phi", + "_ZN5torch", + "_ZN6paddle", +) + + +@dataclass(frozen=True) +class DynamicSymbol: + name: str + symbol_type: str + bind: str + section: str + demangled_name: str + + +@dataclass(frozen=True) +class RemovedSymbol: + library: str + name: str + demangled_name: str + + +@dataclass(frozen=True) +class MissingLibrary: + library: str + + +def strip_elf_symbol_version(symbol_name: str) -> str: + if "@@" in symbol_name: + return symbol_name.split("@@", 1)[0] + if "@" in symbol_name: + return symbol_name.split("@", 1)[0] + return symbol_name + + +def parse_readelf_dynamic_symbols(readelf_output: str) -> list[DynamicSymbol]: + symbols = [] + for line in readelf_output.splitlines(): + fields = line.split() + if len(fields) < 8 or not fields[0].endswith(":"): + continue + symbol_type = fields[3] + bind = fields[4] + section = fields[6] + name = fields[7] + if ( + bind != "GLOBAL" + or section == "UND" + or symbol_type not in DEFINED_DYNAMIC_SYMBOL_TYPES + ): + continue + symbols.append( + DynamicSymbol( + name=name, + symbol_type=symbol_type, + bind=bind, + section=section, + demangled_name=strip_elf_symbol_version(name), + ) + ) + return symbols + + +def demangle_symbol_names(symbol_names: Iterable[str]) -> dict[str, str]: + unique_names = sorted( + {strip_elf_symbol_version(name) for name in symbol_names} + ) + if not unique_names: + return {} + cxxfilt = shutil.which("c++filt") + if cxxfilt is None: + return {name: name for name in unique_names} + + try: + result = subprocess.run( + [cxxfilt], + input="\n".join(unique_names), + text=True, + capture_output=True, + check=True, + ) + except (OSError, subprocess.CalledProcessError): + return {name: name for name in unique_names} + + demangled = result.stdout.splitlines() + if len(demangled) != len(unique_names): + return {name: name for name in unique_names} + return dict(zip(unique_names, demangled)) + + +def attach_demangled_names( + symbols: Iterable[DynamicSymbol], +) -> list[DynamicSymbol]: + symbol_list = list(symbols) + demangled_names = demangle_symbol_names( + symbol.name for symbol in symbol_list + ) + return [ + DynamicSymbol( + name=symbol.name, + symbol_type=symbol.symbol_type, + bind=symbol.bind, + section=symbol.section, + demangled_name=demangled_names.get( + strip_elf_symbol_version(symbol.name), symbol.demangled_name + ), + ) + for symbol in symbol_list + ] + + +def is_protected_paddle_abi_symbol(symbol: DynamicSymbol) -> bool: + demangled = symbol.demangled_name + if demangled.startswith(PROTECTED_CXX_PREFIXES): + return True + + raw_name = strip_elf_symbol_version(symbol.name) + return raw_name.startswith( + PROTECTED_C_SYMBOL_PREFIXES + PROTECTED_MANGLED_CXX_PREFIXES + ) + + +def protected_symbols_by_name( + symbols: Iterable[DynamicSymbol], +) -> dict[str, DynamicSymbol]: + return { + symbol.name: symbol + for symbol in symbols + if is_protected_paddle_abi_symbol(symbol) + } + + +def read_dynamic_symbols(library_path: str) -> list[DynamicSymbol]: + try: + result = subprocess.run( + ["readelf", "--dyn-syms", "-W", library_path], + text=True, + capture_output=True, + check=True, + ) + except FileNotFoundError as exc: + raise RuntimeError("readelf is required to check ABI symbols") from exc + except subprocess.CalledProcessError as exc: + raise RuntimeError( + f"Failed to read dynamic symbols from {library_path}:\n{exc.stderr}" + ) from exc + + return attach_demangled_names(parse_readelf_dynamic_symbols(result.stdout)) + + +def extract_wheel_libraries( + wheel_path: str, library_paths: Iterable[str], output_dir: str +) -> dict[str, str]: + extracted_libraries = {} + with zipfile.ZipFile(wheel_path) as wheel: + wheel_entries = set(wheel.namelist()) + for library_path in library_paths: + if library_path not in wheel_entries: + continue + extracted_path = wheel.extract(library_path, output_dir) + extracted_libraries[library_path] = extracted_path + return extracted_libraries + + +def compare_library_symbols( + library: str, + base_symbols: Iterable[DynamicSymbol] | None, + pr_symbols: Iterable[DynamicSymbol] | None, +) -> list[RemovedSymbol | MissingLibrary]: + if base_symbols is None: + return [] + if pr_symbols is None: + return [MissingLibrary(library=library)] + + base_protected_symbols = protected_symbols_by_name(base_symbols) + pr_protected_symbols = protected_symbols_by_name(pr_symbols) + removed_names = sorted( + set(base_protected_symbols) - set(pr_protected_symbols) + ) + return [ + RemovedSymbol( + library=library, + name=name, + demangled_name=base_protected_symbols[name].demangled_name, + ) + for name in removed_names + ] + + +def resolve_wheel_path(pattern: str, label: str) -> str: + matches = sorted(glob.glob(pattern)) + if len(matches) != 1: + raise RuntimeError( + f"Expected exactly one {label} wheel matching {pattern}, " + f"but found {len(matches)}: {matches}" + ) + return matches[0] + + +def compare_wheel_abi( + base_wheel: str, pr_wheel: str, library_paths: Iterable[str] +) -> list[RemovedSymbol | MissingLibrary]: + with tempfile.TemporaryDirectory(prefix="paddle_abi_check_") as temp_dir: + base_dir = os.path.join(temp_dir, "base") + pr_dir = os.path.join(temp_dir, "pr") + base_libraries = extract_wheel_libraries( + base_wheel, library_paths, base_dir + ) + pr_libraries = extract_wheel_libraries(pr_wheel, library_paths, pr_dir) + + issues: list[RemovedSymbol | MissingLibrary] = [] + for library in library_paths: + base_path = base_libraries.get(library) + pr_path = pr_libraries.get(library) + base_symbols = ( + read_dynamic_symbols(base_path) + if base_path is not None + else None + ) + pr_symbols = ( + read_dynamic_symbols(pr_path) if pr_path is not None else None + ) + issues.extend( + compare_library_symbols(library, base_symbols, pr_symbols) + ) + return issues + + +def format_issues( + issues: Iterable[RemovedSymbol | MissingLibrary], max_report: int +) -> str: + issue_list = list(issues) + lines = [ + "ABI compatibility check failed.", + "The PR wheel removed protected dynamic symbols that exist in the base " + "wheel. Removing these symbols can break downstream wheels or shared " + "libraries compiled against the base branch.", + "", + ] + for issue in issue_list[:max_report]: + if isinstance(issue, MissingLibrary): + lines.extend( + [ + f"Library: {issue.library}", + " PR wheel is missing this library, but the base wheel " + "contains it.", + "", + ] + ) + else: + lines.extend( + [ + f"Library: {issue.library}", + f" Raw symbol: {issue.name}", + f" Demangled: {issue.demangled_name}", + "", + ] + ) + + omitted_count = len(issue_list) - max_report + if omitted_count > 0: + lines.append(f"... omitted {omitted_count} additional removed symbols.") + return "\n".join(lines) + + +def parse_args(argv: list[str]) -> argparse.Namespace: + paddle_root = os.environ.get("PADDLE_ROOT", os.getcwd()) + parser = argparse.ArgumentParser( + description="Check Linux wheel ABI compatibility for Paddle symbols." + ) + parser.add_argument( + "--base-wheel", + default=os.path.join(paddle_root, "build/dev_whl/*.whl"), + help="Base branch wheel path or glob pattern.", + ) + parser.add_argument( + "--pr-wheel", + default=os.path.join(paddle_root, "build/pr_whl/*.whl"), + help="PR wheel path or glob pattern.", + ) + parser.add_argument( + "--max-report", + type=int, + default=200, + help="Maximum number of ABI issues to print.", + ) + return parser.parse_args(argv) + + +def main(argv: list[str] | None = None) -> int: + args = parse_args(sys.argv[1:] if argv is None else argv) + try: + base_wheel = resolve_wheel_path(args.base_wheel, "base") + pr_wheel = resolve_wheel_path(args.pr_wheel, "PR") + issues = compare_wheel_abi(base_wheel, pr_wheel, WHEEL_LIBRARY_PATHS) + except RuntimeError as exc: + print(f"ABI compatibility check failed: {exc}", file=sys.stderr) + return 1 + + if issues: + print(format_issues(issues, args.max_report), file=sys.stderr) + return 1 + + print("ABI compatibility check passed.") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tools/test_check_abi_compatibility.py b/tools/test_check_abi_compatibility.py new file mode 100644 index 00000000000000..5053fd65f17057 --- /dev/null +++ b/tools/test_check_abi_compatibility.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python + +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +try: + from check_abi_compatibility import ( + DynamicSymbol, + MissingLibrary, + RemovedSymbol, + compare_library_symbols, + is_protected_paddle_abi_symbol, + parse_readelf_dynamic_symbols, + ) +except ModuleNotFoundError: + from tools.check_abi_compatibility import ( + DynamicSymbol, + MissingLibrary, + RemovedSymbol, + compare_library_symbols, + is_protected_paddle_abi_symbol, + parse_readelf_dynamic_symbols, + ) + + +def make_symbol(name, demangled_name=None, bind="GLOBAL", section="12"): + return DynamicSymbol( + name=name, + symbol_type="FUNC", + bind=bind, + section=section, + demangled_name=demangled_name or name, + ) + + +class TestParseReadelfDynamicSymbols(unittest.TestCase): + def test_ignores_weak_undefined_and_local_symbols(self): + readelf_output = """ +Symbol table '.dynsym' contains 5 entries: + Num: Value Size Type Bind Vis Ndx Name + 1: 0000000000001000 42 FUNC GLOBAL DEFAULT 12 _ZN3c1017get_default_dtypeEv + 2: 0000000000001010 42 FUNC WEAK DEFAULT 12 _ZN3c104weakEv + 3: 0000000000000000 0 FUNC GLOBAL DEFAULT UND _ZN3c107missingEv + 4: 0000000000001020 42 FUNC LOCAL DEFAULT 12 _ZN3c105localEv + 5: 0000000000001030 8 OBJECT GLOBAL DEFAULT 13 _ZN3phi3barE +""" + symbols = parse_readelf_dynamic_symbols(readelf_output) + self.assertEqual( + [symbol.name for symbol in symbols], + ["_ZN3c1017get_default_dtypeEv", "_ZN3phi3barE"], + ) + + +class TestProtectedSymbols(unittest.TestCase): + def test_detects_protected_cxx_namespaces(self): + self.assertTrue( + is_protected_paddle_abi_symbol( + make_symbol( + "_ZN3c1017get_default_dtypeEv", + "c10::get_default_dtype()", + ) + ) + ) + self.assertTrue( + is_protected_paddle_abi_symbol( + make_symbol("_ZN3phi3barEv", "phi::bar()") + ) + ) + self.assertTrue( + is_protected_paddle_abi_symbol( + make_symbol("_ZN5torch4cuda11synchronizeEv") + ) + ) + + def test_detects_relevant_c_and_python_entrypoints(self): + self.assertTrue( + is_protected_paddle_abi_symbol(make_symbol("PyInit_libpaddle")) + ) + self.assertTrue( + is_protected_paddle_abi_symbol(make_symbol("PD_ConfigCreate")) + ) + + def test_ignores_third_party_symbols(self): + self.assertFalse( + is_protected_paddle_abi_symbol(make_symbol("XXH32", "XXH32")) + ) + self.assertFalse( + is_protected_paddle_abi_symbol( + make_symbol("_ZN4YAML7EmitterC1Ev", "YAML::Emitter::Emitter()") + ) + ) + + +class TestCompareLibrarySymbols(unittest.TestCase): + def test_added_symbols_do_not_fail(self): + base_symbols = [ + make_symbol( + "_ZN3c1017get_default_dtypeEv", "c10::get_default_dtype()" + ) + ] + pr_symbols = [ + *base_symbols, + make_symbol( + "_ZN3c1017set_default_dtypeEv", "c10::set_default_dtype()" + ), + ] + + issues = compare_library_symbols( + "paddle/libs/libphi_core.so", base_symbols, pr_symbols + ) + + self.assertEqual(issues, []) + + def test_removed_protected_symbol_fails(self): + base_symbols = [ + make_symbol( + "_ZN3c1017get_default_dtypeEv", "c10::get_default_dtype()" + ) + ] + + issues = compare_library_symbols( + "paddle/libs/libphi_core.so", base_symbols, [] + ) + + self.assertEqual( + issues, + [ + RemovedSymbol( + library="paddle/libs/libphi_core.so", + name="_ZN3c1017get_default_dtypeEv", + demangled_name="c10::get_default_dtype()", + ) + ], + ) + + def test_removed_third_party_symbol_does_not_fail(self): + base_symbols = [make_symbol("XXH32", "XXH32")] + + issues = compare_library_symbols( + "paddle/base/libpaddle.so", base_symbols, [] + ) + + self.assertEqual(issues, []) + + def test_missing_pr_library_fails_when_base_has_library(self): + base_symbols = [ + make_symbol( + "_ZN3c1017get_default_dtypeEv", "c10::get_default_dtype()" + ) + ] + + issues = compare_library_symbols( + "paddle/libs/libphi_core.so", base_symbols, None + ) + + self.assertEqual( + issues, [MissingLibrary(library="paddle/libs/libphi_core.so")] + ) + + def test_missing_base_library_does_not_fail(self): + pr_symbols = [ + make_symbol( + "_ZN3c1017get_default_dtypeEv", "c10::get_default_dtype()" + ) + ] + + issues = compare_library_symbols( + "paddle/libs/libphi_core.so", None, pr_symbols + ) + + self.assertEqual(issues, []) + + +if __name__ == "__main__": + unittest.main() From aa1d65ee9dc3037c740fe532054877a5101fac62 Mon Sep 17 00:00:00 2001 From: youge325 Date: Tue, 28 Apr 2026 16:16:01 +0800 Subject: [PATCH 3/3] Revert "Add ABI symbol compatibility check" This reverts commit 586c99ebd628b7a3c1f134b19f79474e2ad0c1ea. --- ci/static_check.sh | 17 -- tools/check_abi_compatibility.py | 374 -------------------------- tools/test_check_abi_compatibility.py | 188 ------------- 3 files changed, 579 deletions(-) delete mode 100644 tools/check_abi_compatibility.py delete mode 100644 tools/test_check_abi_compatibility.py diff --git a/ci/static_check.sh b/ci/static_check.sh index 98cb60220995c5..9682a6ae48da47 100644 --- a/ci/static_check.sh +++ b/ci/static_check.sh @@ -149,21 +149,6 @@ function exec_samplecode_checking() { fi } -function exec_abi_compatibility_check() { - if [ "$(uname -s)" != "Linux" ]; then - echo "Skip ABI compatibility check on non-Linux platform." - return - fi - - python ${PADDLE_ROOT}/tools/check_abi_compatibility.py \ - --base-wheel "${PADDLE_ROOT}/build/dev_whl/*.whl" \ - --pr-wheel "${PADDLE_ROOT}/build/pr_whl/*.whl" - abi_check_error=$? - if [ "$abi_check_error" != "0" ]; then - exit $abi_check_error - fi -} - export PATH=/usr/local/python3.10.0/bin:/usr/local/python3.10.0/include:/usr/local/bin:${PATH} echo "export PATH=${PATH}" >> ~/.bashrc export LD_LIBRARY_PATH=/usr/local/cuda-11.8/compat:$LD_LIBRARY_PATH @@ -173,8 +158,6 @@ ln -sf $(which python${PY_VERSION}) /usr/bin/python ln -sf $(which pip${PY_VERSION}) /usr/local/bin/pip mkdir -p /home/data/cfs/.ccache/static-check -exec_abi_compatibility_check - pip config set global.cache-dir "/home/data/cfs/.cache/pip" pip install --upgrade pip 1>nul pip install -r "${work_dir}/python/requirements.txt" 1>nul diff --git a/tools/check_abi_compatibility.py b/tools/check_abi_compatibility.py deleted file mode 100644 index d7bc3497595590..00000000000000 --- a/tools/check_abi_compatibility.py +++ /dev/null @@ -1,374 +0,0 @@ -#!/usr/bin/env python - -# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Check Linux wheel ABI compatibility by comparing protected ELF symbols. - -The check is intentionally one-way: symbols added by a PR are allowed, while -protected symbols present in the base wheel must still exist in the PR wheel. -""" - -from __future__ import annotations - -import argparse -import glob -import os -import shutil -import subprocess -import sys -import tempfile -import zipfile -from dataclasses import dataclass -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from collections.abc import Iterable - -WHEEL_LIBRARY_PATHS = ( - "paddle/base/libpaddle.so", - "paddle/libs/libphi.so", - "paddle/libs/libphi_core.so", - "paddle/libs/libphi_gpu.so", -) - -DEFINED_DYNAMIC_SYMBOL_TYPES = {"FUNC", "OBJECT"} - -PROTECTED_CXX_PREFIXES = ( - "phi::", - "paddle::", - "c10::", - "at::", - "torch::", -) - -PROTECTED_C_SYMBOL_PREFIXES = ( - "PD_", - "Paddle", - "PyInit_", - "paddle_", -) - -PROTECTED_MANGLED_CXX_PREFIXES = ( - "_ZN2at", - "_ZN3c10", - "_ZN3phi", - "_ZN5torch", - "_ZN6paddle", -) - - -@dataclass(frozen=True) -class DynamicSymbol: - name: str - symbol_type: str - bind: str - section: str - demangled_name: str - - -@dataclass(frozen=True) -class RemovedSymbol: - library: str - name: str - demangled_name: str - - -@dataclass(frozen=True) -class MissingLibrary: - library: str - - -def strip_elf_symbol_version(symbol_name: str) -> str: - if "@@" in symbol_name: - return symbol_name.split("@@", 1)[0] - if "@" in symbol_name: - return symbol_name.split("@", 1)[0] - return symbol_name - - -def parse_readelf_dynamic_symbols(readelf_output: str) -> list[DynamicSymbol]: - symbols = [] - for line in readelf_output.splitlines(): - fields = line.split() - if len(fields) < 8 or not fields[0].endswith(":"): - continue - symbol_type = fields[3] - bind = fields[4] - section = fields[6] - name = fields[7] - if ( - bind != "GLOBAL" - or section == "UND" - or symbol_type not in DEFINED_DYNAMIC_SYMBOL_TYPES - ): - continue - symbols.append( - DynamicSymbol( - name=name, - symbol_type=symbol_type, - bind=bind, - section=section, - demangled_name=strip_elf_symbol_version(name), - ) - ) - return symbols - - -def demangle_symbol_names(symbol_names: Iterable[str]) -> dict[str, str]: - unique_names = sorted( - {strip_elf_symbol_version(name) for name in symbol_names} - ) - if not unique_names: - return {} - cxxfilt = shutil.which("c++filt") - if cxxfilt is None: - return {name: name for name in unique_names} - - try: - result = subprocess.run( - [cxxfilt], - input="\n".join(unique_names), - text=True, - capture_output=True, - check=True, - ) - except (OSError, subprocess.CalledProcessError): - return {name: name for name in unique_names} - - demangled = result.stdout.splitlines() - if len(demangled) != len(unique_names): - return {name: name for name in unique_names} - return dict(zip(unique_names, demangled)) - - -def attach_demangled_names( - symbols: Iterable[DynamicSymbol], -) -> list[DynamicSymbol]: - symbol_list = list(symbols) - demangled_names = demangle_symbol_names( - symbol.name for symbol in symbol_list - ) - return [ - DynamicSymbol( - name=symbol.name, - symbol_type=symbol.symbol_type, - bind=symbol.bind, - section=symbol.section, - demangled_name=demangled_names.get( - strip_elf_symbol_version(symbol.name), symbol.demangled_name - ), - ) - for symbol in symbol_list - ] - - -def is_protected_paddle_abi_symbol(symbol: DynamicSymbol) -> bool: - demangled = symbol.demangled_name - if demangled.startswith(PROTECTED_CXX_PREFIXES): - return True - - raw_name = strip_elf_symbol_version(symbol.name) - return raw_name.startswith( - PROTECTED_C_SYMBOL_PREFIXES + PROTECTED_MANGLED_CXX_PREFIXES - ) - - -def protected_symbols_by_name( - symbols: Iterable[DynamicSymbol], -) -> dict[str, DynamicSymbol]: - return { - symbol.name: symbol - for symbol in symbols - if is_protected_paddle_abi_symbol(symbol) - } - - -def read_dynamic_symbols(library_path: str) -> list[DynamicSymbol]: - try: - result = subprocess.run( - ["readelf", "--dyn-syms", "-W", library_path], - text=True, - capture_output=True, - check=True, - ) - except FileNotFoundError as exc: - raise RuntimeError("readelf is required to check ABI symbols") from exc - except subprocess.CalledProcessError as exc: - raise RuntimeError( - f"Failed to read dynamic symbols from {library_path}:\n{exc.stderr}" - ) from exc - - return attach_demangled_names(parse_readelf_dynamic_symbols(result.stdout)) - - -def extract_wheel_libraries( - wheel_path: str, library_paths: Iterable[str], output_dir: str -) -> dict[str, str]: - extracted_libraries = {} - with zipfile.ZipFile(wheel_path) as wheel: - wheel_entries = set(wheel.namelist()) - for library_path in library_paths: - if library_path not in wheel_entries: - continue - extracted_path = wheel.extract(library_path, output_dir) - extracted_libraries[library_path] = extracted_path - return extracted_libraries - - -def compare_library_symbols( - library: str, - base_symbols: Iterable[DynamicSymbol] | None, - pr_symbols: Iterable[DynamicSymbol] | None, -) -> list[RemovedSymbol | MissingLibrary]: - if base_symbols is None: - return [] - if pr_symbols is None: - return [MissingLibrary(library=library)] - - base_protected_symbols = protected_symbols_by_name(base_symbols) - pr_protected_symbols = protected_symbols_by_name(pr_symbols) - removed_names = sorted( - set(base_protected_symbols) - set(pr_protected_symbols) - ) - return [ - RemovedSymbol( - library=library, - name=name, - demangled_name=base_protected_symbols[name].demangled_name, - ) - for name in removed_names - ] - - -def resolve_wheel_path(pattern: str, label: str) -> str: - matches = sorted(glob.glob(pattern)) - if len(matches) != 1: - raise RuntimeError( - f"Expected exactly one {label} wheel matching {pattern}, " - f"but found {len(matches)}: {matches}" - ) - return matches[0] - - -def compare_wheel_abi( - base_wheel: str, pr_wheel: str, library_paths: Iterable[str] -) -> list[RemovedSymbol | MissingLibrary]: - with tempfile.TemporaryDirectory(prefix="paddle_abi_check_") as temp_dir: - base_dir = os.path.join(temp_dir, "base") - pr_dir = os.path.join(temp_dir, "pr") - base_libraries = extract_wheel_libraries( - base_wheel, library_paths, base_dir - ) - pr_libraries = extract_wheel_libraries(pr_wheel, library_paths, pr_dir) - - issues: list[RemovedSymbol | MissingLibrary] = [] - for library in library_paths: - base_path = base_libraries.get(library) - pr_path = pr_libraries.get(library) - base_symbols = ( - read_dynamic_symbols(base_path) - if base_path is not None - else None - ) - pr_symbols = ( - read_dynamic_symbols(pr_path) if pr_path is not None else None - ) - issues.extend( - compare_library_symbols(library, base_symbols, pr_symbols) - ) - return issues - - -def format_issues( - issues: Iterable[RemovedSymbol | MissingLibrary], max_report: int -) -> str: - issue_list = list(issues) - lines = [ - "ABI compatibility check failed.", - "The PR wheel removed protected dynamic symbols that exist in the base " - "wheel. Removing these symbols can break downstream wheels or shared " - "libraries compiled against the base branch.", - "", - ] - for issue in issue_list[:max_report]: - if isinstance(issue, MissingLibrary): - lines.extend( - [ - f"Library: {issue.library}", - " PR wheel is missing this library, but the base wheel " - "contains it.", - "", - ] - ) - else: - lines.extend( - [ - f"Library: {issue.library}", - f" Raw symbol: {issue.name}", - f" Demangled: {issue.demangled_name}", - "", - ] - ) - - omitted_count = len(issue_list) - max_report - if omitted_count > 0: - lines.append(f"... omitted {omitted_count} additional removed symbols.") - return "\n".join(lines) - - -def parse_args(argv: list[str]) -> argparse.Namespace: - paddle_root = os.environ.get("PADDLE_ROOT", os.getcwd()) - parser = argparse.ArgumentParser( - description="Check Linux wheel ABI compatibility for Paddle symbols." - ) - parser.add_argument( - "--base-wheel", - default=os.path.join(paddle_root, "build/dev_whl/*.whl"), - help="Base branch wheel path or glob pattern.", - ) - parser.add_argument( - "--pr-wheel", - default=os.path.join(paddle_root, "build/pr_whl/*.whl"), - help="PR wheel path or glob pattern.", - ) - parser.add_argument( - "--max-report", - type=int, - default=200, - help="Maximum number of ABI issues to print.", - ) - return parser.parse_args(argv) - - -def main(argv: list[str] | None = None) -> int: - args = parse_args(sys.argv[1:] if argv is None else argv) - try: - base_wheel = resolve_wheel_path(args.base_wheel, "base") - pr_wheel = resolve_wheel_path(args.pr_wheel, "PR") - issues = compare_wheel_abi(base_wheel, pr_wheel, WHEEL_LIBRARY_PATHS) - except RuntimeError as exc: - print(f"ABI compatibility check failed: {exc}", file=sys.stderr) - return 1 - - if issues: - print(format_issues(issues, args.max_report), file=sys.stderr) - return 1 - - print("ABI compatibility check passed.") - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/tools/test_check_abi_compatibility.py b/tools/test_check_abi_compatibility.py deleted file mode 100644 index 5053fd65f17057..00000000000000 --- a/tools/test_check_abi_compatibility.py +++ /dev/null @@ -1,188 +0,0 @@ -#!/usr/bin/env python - -# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -try: - from check_abi_compatibility import ( - DynamicSymbol, - MissingLibrary, - RemovedSymbol, - compare_library_symbols, - is_protected_paddle_abi_symbol, - parse_readelf_dynamic_symbols, - ) -except ModuleNotFoundError: - from tools.check_abi_compatibility import ( - DynamicSymbol, - MissingLibrary, - RemovedSymbol, - compare_library_symbols, - is_protected_paddle_abi_symbol, - parse_readelf_dynamic_symbols, - ) - - -def make_symbol(name, demangled_name=None, bind="GLOBAL", section="12"): - return DynamicSymbol( - name=name, - symbol_type="FUNC", - bind=bind, - section=section, - demangled_name=demangled_name or name, - ) - - -class TestParseReadelfDynamicSymbols(unittest.TestCase): - def test_ignores_weak_undefined_and_local_symbols(self): - readelf_output = """ -Symbol table '.dynsym' contains 5 entries: - Num: Value Size Type Bind Vis Ndx Name - 1: 0000000000001000 42 FUNC GLOBAL DEFAULT 12 _ZN3c1017get_default_dtypeEv - 2: 0000000000001010 42 FUNC WEAK DEFAULT 12 _ZN3c104weakEv - 3: 0000000000000000 0 FUNC GLOBAL DEFAULT UND _ZN3c107missingEv - 4: 0000000000001020 42 FUNC LOCAL DEFAULT 12 _ZN3c105localEv - 5: 0000000000001030 8 OBJECT GLOBAL DEFAULT 13 _ZN3phi3barE -""" - symbols = parse_readelf_dynamic_symbols(readelf_output) - self.assertEqual( - [symbol.name for symbol in symbols], - ["_ZN3c1017get_default_dtypeEv", "_ZN3phi3barE"], - ) - - -class TestProtectedSymbols(unittest.TestCase): - def test_detects_protected_cxx_namespaces(self): - self.assertTrue( - is_protected_paddle_abi_symbol( - make_symbol( - "_ZN3c1017get_default_dtypeEv", - "c10::get_default_dtype()", - ) - ) - ) - self.assertTrue( - is_protected_paddle_abi_symbol( - make_symbol("_ZN3phi3barEv", "phi::bar()") - ) - ) - self.assertTrue( - is_protected_paddle_abi_symbol( - make_symbol("_ZN5torch4cuda11synchronizeEv") - ) - ) - - def test_detects_relevant_c_and_python_entrypoints(self): - self.assertTrue( - is_protected_paddle_abi_symbol(make_symbol("PyInit_libpaddle")) - ) - self.assertTrue( - is_protected_paddle_abi_symbol(make_symbol("PD_ConfigCreate")) - ) - - def test_ignores_third_party_symbols(self): - self.assertFalse( - is_protected_paddle_abi_symbol(make_symbol("XXH32", "XXH32")) - ) - self.assertFalse( - is_protected_paddle_abi_symbol( - make_symbol("_ZN4YAML7EmitterC1Ev", "YAML::Emitter::Emitter()") - ) - ) - - -class TestCompareLibrarySymbols(unittest.TestCase): - def test_added_symbols_do_not_fail(self): - base_symbols = [ - make_symbol( - "_ZN3c1017get_default_dtypeEv", "c10::get_default_dtype()" - ) - ] - pr_symbols = [ - *base_symbols, - make_symbol( - "_ZN3c1017set_default_dtypeEv", "c10::set_default_dtype()" - ), - ] - - issues = compare_library_symbols( - "paddle/libs/libphi_core.so", base_symbols, pr_symbols - ) - - self.assertEqual(issues, []) - - def test_removed_protected_symbol_fails(self): - base_symbols = [ - make_symbol( - "_ZN3c1017get_default_dtypeEv", "c10::get_default_dtype()" - ) - ] - - issues = compare_library_symbols( - "paddle/libs/libphi_core.so", base_symbols, [] - ) - - self.assertEqual( - issues, - [ - RemovedSymbol( - library="paddle/libs/libphi_core.so", - name="_ZN3c1017get_default_dtypeEv", - demangled_name="c10::get_default_dtype()", - ) - ], - ) - - def test_removed_third_party_symbol_does_not_fail(self): - base_symbols = [make_symbol("XXH32", "XXH32")] - - issues = compare_library_symbols( - "paddle/base/libpaddle.so", base_symbols, [] - ) - - self.assertEqual(issues, []) - - def test_missing_pr_library_fails_when_base_has_library(self): - base_symbols = [ - make_symbol( - "_ZN3c1017get_default_dtypeEv", "c10::get_default_dtype()" - ) - ] - - issues = compare_library_symbols( - "paddle/libs/libphi_core.so", base_symbols, None - ) - - self.assertEqual( - issues, [MissingLibrary(library="paddle/libs/libphi_core.so")] - ) - - def test_missing_base_library_does_not_fail(self): - pr_symbols = [ - make_symbol( - "_ZN3c1017get_default_dtypeEv", "c10::get_default_dtype()" - ) - ] - - issues = compare_library_symbols( - "paddle/libs/libphi_core.so", None, pr_symbols - ) - - self.assertEqual(issues, []) - - -if __name__ == "__main__": - unittest.main()