Skip to content

Commit e7b87fe

Browse files
authored
[ARITH] Add optional Z3-backed proving to Analyzer (#19667)
## Summary This PR adds a Z3 SMT solver backend to `tvm::arith::Analyzer` for stronger integer arithmetic proving. The integration is guarded by `USE_Z3`, which defaults to `AUTO`. In the default mode, TVM enables Z3 when the static Z3 development artifacts are available and otherwise builds the conservative stub implementation. When Z3 is enabled, `Analyzer::CanProve` runs the existing TVM arithmetic analysis path first, then falls back to Z3 only when the existing analyzers cannot prove the predicate and the requested strength is `kSymbolicBound`. Z3 is linked statically from the PyPI `z3-static` package, so `libtvm` does not need a runtime `libz3` dependency. ## Features - Z3 build support through `USE_Z3`, defaulting to `AUTO`. - A new `arith::Z3Prover` sub-analyzer owned by `arith::Analyzer`. - SMT-LIB2 export for debugging and external solver reproduction. - Python debug/config APIs: `Analyzer.get_smtlib2`, `Analyzer.set_z3_timeout_ms`, `Analyzer.set_z3_rlimit`, and `Analyzer.get_z3_stats`. - C++ APIs for proving, binding, constraints, stats, model inspection, and satisfying-value counting. - Scalar integer, unsigned integer, and boolean expression translation to Z3. - Support for arithmetic, comparisons, boolean operators, `min`, `max`, `select`, `if_then_else`, `let`, casts, truncated division/modulo, floor division/modulo, and selected bitwise/shift operations. - Deterministic solver control using Z3 `rlimit`, with `random_seed` fixed to `42`. - Thread-local Z3 context sharing to reduce initialization overhead while keeping thread safety. - A disabled-mode stub implementation that returns conservative results when Z3 is not built. ## Implementation Notes - The real and stub implementations live in `src/arith/z3_prover.cc`, selected by the `TVM_USE_Z3` macro from `cmake/modules/contrib/Z3.cmake`. - `cmake/modules/contrib/Z3.cmake` first resolves the PIC static `libz3` layout provided by `z3-static` using its `z3_static.get_cmake_dir()` helper, then falls back to a custom `Z3_DIR` or `CMAKE_PREFIX_PATH` installation. - `USE_Z3=ON` requires Z3 to be found, while `USE_Z3=AUTO` allows source builds and CI jobs without Z3 artifacts to continue with the stub. - The Z3 fallback is exception-safe and gated behind `kSymbolicBound`, so the common `kDefault` path does not pay solver cost. - TVM `Div` and `Mod` are translated with truncating helpers rather than Z3's Euclidean operators to stay sound for negative dividends. - Shift handling relies on Z3's native bit-vector semantics and does not add hard assertions to the shared solver. ## References The implementation is based on the Z3 analyzer integration used in TileLang's TVM fork, with the upstream port kept scoped to TVM's arithmetic analyzer. - [tile-ai/tilelang#1367](tile-ai/tilelang#1367) - [tile-ai/tilelang#1458](tile-ai/tilelang#1458) - [tile-ai/tilelang#2216](tile-ai/tilelang#2216) - [tile-ai#22](tile-ai#22) - [tile-ai#24](tile-ai#24) - [Original TileLang TVM commit](tile-ai@e633295) --------- Signed-off-by: Ubospica <ubospica@gmail.com>
1 parent d321035 commit e7b87fe

9 files changed

Lines changed: 1982 additions & 5 deletions

File tree

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ tvm_option(COMPILER_RT_PATH "Path to COMPILER-RT" "3rdparty/compiler-rt")
8989
# Contrib library options
9090
tvm_option(USE_BLAS "The blas library to be linked" none)
9191
tvm_option(USE_AMX "Enable Intel AMX" OFF)
92+
tvm_option(USE_Z3 "Build with Z3 SMT solver support" AUTO)
9293
tvm_option(USE_MKL "MKL root path when use MKL blas" OFF)
9394
tvm_option(USE_DNNL "Enable DNNL codegen" OFF)
9495
tvm_option(USE_CUDNN "Build with cuDNN" OFF)
@@ -459,6 +460,7 @@ include(cmake/modules/contrib/AMX.cmake)
459460
include(cmake/modules/contrib/CUTLASS.cmake)
460461
include(cmake/modules/contrib/Random.cmake)
461462
include(cmake/modules/contrib/Sort.cmake)
463+
include(cmake/modules/contrib/Z3.cmake)
462464
include(cmake/modules/contrib/CoreML.cmake)
463465
include(cmake/modules/contrib/TensorRT.cmake)
464466
include(cmake/modules/contrib/NNAPI.cmake)
@@ -545,6 +547,9 @@ add_library(tvm_objs OBJECT ${COMPILER_SRCS})
545547
add_library(tvm_runtime_objs OBJECT ${RUNTIME_SRCS})
546548
target_link_libraries(tvm_objs PUBLIC tvm_ffi_header)
547549
target_link_libraries(tvm_runtime_objs PUBLIC tvm_ffi_header)
550+
if(TARGET tvm_llvm_header)
551+
target_link_libraries(tvm_objs PUBLIC tvm_llvm_header)
552+
endif()
548553

549554
include(GNUInstallDirs)
550555

cmake/modules/LLVM.cmake

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,19 @@ if(NOT ${USE_LLVM} MATCHES ${IS_FALSE_PATTERN})
3434
endif()
3535
include_directories(SYSTEM ${LLVM_INCLUDE_DIRS})
3636
add_definitions(${LLVM_DEFINITIONS})
37+
add_library(tvm_llvm_header INTERFACE)
38+
if(MSVC)
39+
# MSVC treats GCC-style -isystem operands as source files.
40+
target_include_directories(tvm_llvm_header SYSTEM INTERFACE ${LLVM_INCLUDE_DIRS})
41+
target_compile_options(tvm_llvm_header INTERFACE ${LLVM_DEFINITIONS})
42+
else()
43+
set(TVM_LLVM_INCLUDE_FLAGS "")
44+
foreach(__llvm_include_dir IN LISTS LLVM_INCLUDE_DIRS)
45+
string(STRIP "${__llvm_include_dir}" __llvm_include_dir)
46+
list(APPEND TVM_LLVM_INCLUDE_FLAGS "-isystem" "${__llvm_include_dir}")
47+
endforeach()
48+
target_compile_options(tvm_llvm_header INTERFACE ${TVM_LLVM_INCLUDE_FLAGS} ${LLVM_DEFINITIONS})
49+
endif()
3750
message(STATUS "Build with LLVM " ${LLVM_PACKAGE_VERSION})
3851
message(STATUS "Set TVM_LLVM_VERSION=" ${TVM_LLVM_VERSION})
3952
# Set flags that are only needed for LLVM target

cmake/modules/contrib/Z3.cmake

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
# src/arith/z3_prover.cc is always part of COMPILER_SRCS (picked up by the
19+
# src/arith/*.cc glob). It compiles a conservative stub by default and switches
20+
# to the real Z3 implementation only when the TVM_USE_Z3 macro is defined below.
21+
if(${USE_Z3} MATCHES ${IS_FALSE_PATTERN})
22+
return()
23+
endif()
24+
25+
set(TVM_Z3_REQUIRED TRUE)
26+
if("${USE_Z3}" MATCHES "^[Aa][Uu][Tt][Oo]$")
27+
set(TVM_Z3_REQUIRED FALSE)
28+
endif()
29+
30+
# Default lookup: the PIC static Z3 library shipped by the PyPI `z3-static`
31+
# package (headers + libz3.a + Z3 CMake package files). Linking it statically
32+
# keeps libtvm free of a runtime libz3 dependency. Users can override the
33+
# lookup by setting Z3_DIR/CMAKE_PREFIX_PATH to any Z3 installation (e.g. a
34+
# shared system Z3).
35+
if(NOT Z3_DIR)
36+
find_package(Python3 COMPONENTS Interpreter QUIET)
37+
if(Python3_EXECUTABLE)
38+
execute_process(
39+
COMMAND
40+
"${Python3_EXECUTABLE}" -m z3_static.config --cmake-dir
41+
OUTPUT_VARIABLE Z3_STATIC_CMAKE_DIR
42+
OUTPUT_STRIP_TRAILING_WHITESPACE
43+
ERROR_QUIET
44+
RESULT_VARIABLE Z3_STATIC_RESULT
45+
)
46+
if(Z3_STATIC_RESULT EQUAL 0 AND EXISTS "${Z3_STATIC_CMAKE_DIR}")
47+
set(Z3_DIR "${Z3_STATIC_CMAKE_DIR}")
48+
endif()
49+
endif()
50+
endif()
51+
52+
find_package(Z3 CONFIG QUIET)
53+
if(NOT Z3_FOUND AND NOT TARGET z3::libz3 AND NOT TARGET Z3::libz3)
54+
find_package(Z3 QUIET)
55+
endif()
56+
57+
if(TARGET z3::libz3 OR TARGET Z3::libz3)
58+
if(TARGET z3::libz3)
59+
set(Z3_TARGET z3::libz3)
60+
else()
61+
set(Z3_TARGET Z3::libz3)
62+
endif()
63+
get_target_property(Z3_TARGET_INCLUDE_DIRS ${Z3_TARGET} INTERFACE_INCLUDE_DIRECTORIES)
64+
if(Z3_TARGET_INCLUDE_DIRS)
65+
include_directories(SYSTEM ${Z3_TARGET_INCLUDE_DIRS})
66+
endif()
67+
list(APPEND TVM_LINKER_LIBS ${Z3_TARGET})
68+
elseif(Z3_FOUND OR (Z3_INCLUDE_DIR AND Z3_LIBRARY))
69+
if(NOT Z3_INCLUDE_DIR AND Z3_CXX_INCLUDE_DIRS)
70+
set(Z3_INCLUDE_DIR ${Z3_CXX_INCLUDE_DIRS})
71+
endif()
72+
if(NOT Z3_LIBRARY AND Z3_LIBRARIES)
73+
set(Z3_LIBRARY ${Z3_LIBRARIES})
74+
endif()
75+
if(NOT Z3_INCLUDE_DIR OR NOT Z3_LIBRARY)
76+
message(FATAL_ERROR "USE_Z3 is ON, but Z3 include directory or library was not found.")
77+
endif()
78+
include_directories(SYSTEM ${Z3_INCLUDE_DIR})
79+
list(APPEND TVM_LINKER_LIBS ${Z3_LIBRARY})
80+
else()
81+
if(TVM_Z3_REQUIRED)
82+
message(FATAL_ERROR
83+
"USE_Z3 is ON, but Z3 was not found. Install the static Z3 development "
84+
"package with `pip install 'z3-static>=4.16.0.post1'`, or point "
85+
"Z3_DIR/CMAKE_PREFIX_PATH at a Z3 installation.")
86+
endif()
87+
message(STATUS "Build without Z3 SMT solver support")
88+
return()
89+
endif()
90+
91+
# Enable the real Z3 implementation inside the single src/arith/z3_prover.cc file.
92+
add_compile_definitions(TVM_USE_Z3)
93+
message(STATUS "Build with Z3 SMT solver support")

include/tvm/arith/analyzer.h

Lines changed: 128 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <tvm/arith/int_set.h>
2828
#include <tvm/ffi/cast.h>
2929
#include <tvm/ffi/reflection/registry.h>
30+
#include <tvm/ffi/string.h>
3031
#include <tvm/ir/expr.h>
3132
#include <tvm/ir/with_context.h>
3233

@@ -588,6 +589,110 @@ class IntSetAnalyzer {
588589
Impl* impl_;
589590
};
590591

592+
class Z3Prover {
593+
public:
594+
/*!
595+
* \brief Update binding of var to a new expression.
596+
*
597+
* \param var The variable of interest.
598+
* \param new_range The range of allowed values for this var.
599+
* \param allow_override whether we allow override of existing information.
600+
*/
601+
TVM_DLL void Bind(const Var& var, const Range& new_range, bool allow_override = false);
602+
603+
/*!
604+
* \brief Update binding of var to a new expression.
605+
*
606+
* \param var The variable of interest.
607+
* \param expr The bound expression.
608+
* \param allow_override whether we allow override of existing information.
609+
*/
610+
TVM_DLL void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
611+
612+
/*!
613+
* \brief Whether the Z3 backend is compiled into this build (USE_Z3=ON).
614+
*
615+
* \return true if the real Z3 prover is available, false for the stub.
616+
*/
617+
TVM_DLL bool IsEnabled() const;
618+
619+
/*!
620+
* \brief Whether can we prove expr is always true.
621+
*
622+
* \param expr The expression.
623+
* \return Whether we can prove it.
624+
*/
625+
TVM_DLL bool CanProve(const PrimExpr& expr);
626+
627+
/*!
628+
* \brief Update the internal state to enter constraint.
629+
*
630+
* \param constraint A constraint expression.
631+
* \return an exit function that must be called to cleanup the constraint can be nullptr.
632+
*/
633+
std::function<void()> EnterConstraint(const PrimExpr& constraint);
634+
635+
/*!
636+
* \brief Get the SMTLIB2 representation of the current context.
637+
*
638+
* \param expr The optional expression to check.
639+
* \return The SMTLIB2 string.
640+
*/
641+
ffi::String GetSMTLIB2(const ffi::Optional<PrimExpr> expr);
642+
643+
/*!
644+
* \brief Get statistics about Z3 prover.
645+
*
646+
* \return The statistics string.
647+
*/
648+
ffi::String GetStats();
649+
650+
/*!
651+
* \brief Set timeout in milliseconds for Z3 prover.
652+
*
653+
* \param timeout_ms The timeout in milliseconds.
654+
*/
655+
void SetTimeoutMs(unsigned timeout_ms);
656+
657+
/*!
658+
* \brief Set resource limitation for Z3 prover.
659+
*
660+
* \param rlimit the resource limitation.
661+
*/
662+
void SetRLimit(unsigned rlimit);
663+
664+
/*!
665+
* \brief Get the Z3 model for the given expression if satisfiable.
666+
*
667+
* \param expr The expression to get the model for.
668+
* \return The model as a string.
669+
*/
670+
ffi::String GetModel(const PrimExpr& expr);
671+
672+
/*!
673+
* \brief Count the number of integer values that satisfy the current constraints.
674+
*
675+
* This method uses Z3's model enumeration to count how many distinct values of
676+
* the given variable satisfy all current constraints.
677+
*
678+
* \param var The variable to count satisfying values for.
679+
* \param max_count Maximum number of solutions to enumerate.
680+
* \param min_consecutive Minimum consecutive count requirement.
681+
* \return The number of distinct values that satisfy the constraints, or a negative error code.
682+
*/
683+
TVM_DLL int64_t CountSatisfyingValues(const Var& var, int64_t max_count = 2048,
684+
int64_t min_consecutive = 1);
685+
686+
private:
687+
friend class AnalyzerObj;
688+
friend class Analyzer;
689+
explicit Z3Prover(AnalyzerObj* parent);
690+
TVM_DLL ~Z3Prover();
691+
void CopyFrom(const Z3Prover& other);
692+
class Impl;
693+
Impl* impl_;
694+
};
695+
591696
/*!
592697
* \brief Analyzer that contains bunch of sub-analyzers.
593698
*
@@ -612,6 +717,8 @@ class TVM_DLL AnalyzerObj : public ffi::Object {
612717
IntSetAnalyzer int_set;
613718
/*! \brief sub-analyzer transitive comparisons */
614719
TransitiveComparisonAnalyzer transitive_comparisons;
720+
/*! \brief sub-analyzer using Z3 */
721+
Z3Prover z3_prover;
615722
/*! \brief constructor */
616723
AnalyzerObj();
617724
/*!
@@ -810,7 +917,16 @@ class ConstraintContext {
810917
* \param constraint The constraint to be applied.
811918
*/
812919
ConstraintContext(const Analyzer& analyzer, PrimExpr constraint)
813-
: analyzer_(analyzer), constraint_(constraint) {}
920+
: ConstraintContext(analyzer, std::move(constraint), false) {}
921+
/*!
922+
* \brief Construct a constraint context.
923+
* \param analyzer The analyzer whose context is updated. The context
924+
* keeps a reference to the analyzer while the scope is active.
925+
* \param constraint The constraint to be applied.
926+
* \param is_assume Whether the constraint comes from an assumption.
927+
*/
928+
ConstraintContext(const Analyzer& analyzer, PrimExpr constraint, bool is_assume)
929+
: analyzer_(analyzer), constraint_(std::move(constraint)), is_assume_(is_assume) {}
814930
/*!
815931
* \brief Construct a constraint context from a borrowed analyzer object.
816932
* \param analyzer The borrowed analyzer object.
@@ -819,7 +935,15 @@ class ConstraintContext {
819935
* This overload is for internal callers that already operate on AnalyzerObj*.
820936
*/
821937
ConstraintContext(AnalyzerObj* analyzer, PrimExpr constraint)
822-
: ConstraintContext(ffi::GetRef<Analyzer>(analyzer), std::move(constraint)) {}
938+
: ConstraintContext(ffi::GetRef<Analyzer>(analyzer), std::move(constraint), false) {}
939+
/*!
940+
* \brief Construct a constraint context from a borrowed analyzer object.
941+
* \param analyzer The borrowed analyzer object.
942+
* \param constraint The constraint to be applied.
943+
* \param is_assume Whether the constraint comes from an assumption.
944+
*/
945+
ConstraintContext(AnalyzerObj* analyzer, PrimExpr constraint, bool is_assume)
946+
: ConstraintContext(ffi::GetRef<Analyzer>(analyzer), std::move(constraint), is_assume) {}
823947
// enter the scope.
824948
void EnterWithScope();
825949
// exit the scope.
@@ -830,6 +954,8 @@ class ConstraintContext {
830954
PrimExpr constraint_;
831955
/*! \brief functions to be called in recovery */
832956
std::vector<std::function<void()>> recovery_functions_;
957+
/*! \brief Whether the constraint comes from an assumption. */
958+
bool is_assume_;
833959
};
834960

835961
} // namespace arith

pyproject.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616
# under the License.
1717

1818
[build-system]
19-
requires = ["scikit-build-core>=0.11", "setuptools-scm>=8"]
19+
# z3-static ships the PIC static libz3 + headers consumed by USE_Z3=ON.
20+
requires = [
21+
"scikit-build-core>=0.11",
22+
"setuptools-scm>=8",
23+
"z3-static>=4.16.0.post1",
24+
]
2025
build-backend = "scikit_build_core.build"
2126

2227
[project]
@@ -141,6 +146,8 @@ logging.level = "INFO"
141146
[tool.scikit-build.cmake.define]
142147
TVM_BUILD_PYTHON_MODULE = "ON"
143148
USE_CUDA = "OFF"
149+
# Statically link Z3 from the z3-static build dependency by default.
150+
USE_Z3 = "ON"
144151
BUILD_TESTING = "OFF"
145152

146153
[tool.setuptools_scm]

0 commit comments

Comments
 (0)