Skip to content

Commit 39f808e

Browse files
committed
update
Signed-off-by: Ubospica <ubospica@gmail.com>
1 parent 6fca18c commit 39f808e

14 files changed

Lines changed: 903 additions & 167 deletions

File tree

3rdparty/cnpy

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit 4e8810b1a8637695171ed346ce68f6984e585ef4

3rdparty/dmlc-core

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit 3031e4a61a98f49f07a42cfdec6242340fb2fd8c

3rdparty/rang

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit cabe04d6d6b05356fa8f9741704924788f0dd762

3rdparty/zlib

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit ef24c4c7502169f016dcd2a26923dbaf3216748c

cmake/modules/contrib/Z3.cmake

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
# src/arith/z3_prover.cc is always part of COMPILER_SRCS (picked up by the
19+
# src/arith/*.cc glob). It compiles a conservative stub by default and switches
20+
# to the real Z3 implementation only when the TVM_USE_Z3 macro is defined below.
1821
if(NOT USE_Z3)
19-
list(APPEND COMPILER_SRCS src/target/z3/z3_prover_off.cc)
2022
return()
2123
endif()
2224

@@ -73,4 +75,6 @@ else()
7375
message(FATAL_ERROR "USE_Z3 is ON, but Z3 was not found. Install Z3 or PyPI z3-solver.")
7476
endif()
7577

76-
list(APPEND COMPILER_SRCS src/target/z3/z3_prover_on.cc)
78+
# Enable the real Z3 implementation inside the single src/arith/z3_prover.cc file.
79+
add_compile_definitions(TVM_USE_Z3)
80+
message(STATUS "Build with Z3 SMT solver support")

include/tvm/arith/analyzer.h

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ class RewriteSimplifier {
296296
*
297297
* \return an exit function that must be called to cleanup the constraint can be nullptr.
298298
*/
299-
TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint, bool is_assume = false);
299+
TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint);
300300

301301
/*! \brief Flags to enable more computationally-intensive simplifications
302302
*
@@ -555,8 +555,8 @@ class ConstraintContext {
555555
* \param analyzer The analyzer.
556556
* \param constraint The constraint to be applied.
557557
*/
558-
ConstraintContext(Analyzer* analyzer, PrimExpr constraint, bool is_assume = false)
559-
: analyzer_(analyzer), constraint_(constraint), is_assume_(is_assume) {}
558+
ConstraintContext(Analyzer* analyzer, PrimExpr constraint)
559+
: analyzer_(analyzer), constraint_(constraint) {}
560560
// enter the scope.
561561
void EnterWithScope();
562562
// exit the scope.
@@ -567,7 +567,6 @@ class ConstraintContext {
567567
PrimExpr constraint_;
568568
/*! \brief functions to be called in recovery */
569569
std::vector<std::function<void()>> recovery_functions_;
570-
bool is_assume_;
571570
};
572571

573572
/*!
@@ -644,6 +643,13 @@ class Z3Prover {
644643
*/
645644
TVM_DLL void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
646645

646+
/*!
647+
* \brief Whether the Z3 backend is compiled into this build (USE_Z3=ON).
648+
*
649+
* \return true if the real Z3 prover is available, false for the stub.
650+
*/
651+
TVM_DLL bool IsEnabled() const;
652+
647653
/*!
648654
* \brief Whether can we prove expr is always true.
649655
*
@@ -656,10 +662,9 @@ class Z3Prover {
656662
* \brief Update the internal state to enter constraint.
657663
*
658664
* \param constraint A constraint expression.
659-
* \param is_assume Whether the constraint comes from an assumption.
660665
* \return an exit function that must be called to cleanup the constraint can be nullptr.
661666
*/
662-
std::function<void()> EnterConstraint(const PrimExpr& constraint, bool is_assume = false);
667+
std::function<void()> EnterConstraint(const PrimExpr& constraint);
663668

664669
/*!
665670
* \brief Get the SMTLIB2 representation of the current context.
@@ -886,8 +891,6 @@ class TVM_DLL Analyzer {
886891
* \note Analyzer will call into sub-analyzers to get the result.
887892
*/
888893
PrimExpr Simplify(const PrimExpr& expr, int steps = 2);
889-
890-
std::function<void()> EnterConstraint(const PrimExpr& constraint, bool is_assume = false);
891894
};
892895

893896
} // namespace arith

python/tvm/_version.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# file generated by vcs-versioning
2+
# don't change, don't track in version control
3+
from __future__ import annotations
4+
5+
__all__ = [
6+
"__version__",
7+
"__version_tuple__",
8+
"version",
9+
"version_tuple",
10+
"__commit_id__",
11+
"commit_id",
12+
]
13+
14+
version: str
15+
__version__: str
16+
__version_tuple__: tuple[int | str, ...]
17+
version_tuple: tuple[int | str, ...]
18+
commit_id: str | None
19+
__commit_id__: str | None
20+
21+
__version__ = version = '0.25.dev100'
22+
__version_tuple__ = version_tuple = (0, 25, 'dev100')
23+
24+
__commit_id__ = commit_id = 'g35152f312'

python/tvm/arith/analyzer.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,51 +123,97 @@ def __init__(self):
123123
self._enter_constraint_context = _mod("enter_constraint_context")
124124
self._can_prove_equal = _mod("can_prove_equal")
125125
self._can_prove = _mod("can_prove")
126+
self._is_z3_enabled = _mod("is_z3_enabled")
126127
self._get_smtlib2 = _mod("get_smtlib2")
127128
self._set_z3_timeout_ms = _mod("set_z3_timeout_ms")
128129
self._set_z3_rlimit = _mod("set_z3_rlimit")
129130
self._get_z3_stats = _mod("get_z3_stats")
130131
self._get_enabled_extensions = _mod("get_enabled_extensions")
131132
self._set_enabled_extensions = _mod("set_enabled_extensions")
132133

133-
def get_smtlib2(self, expr: tirx.PrimExpr = None) -> str:
134+
@property
135+
def is_z3_enabled(self) -> bool:
136+
"""Whether this build includes the Z3 backend (``USE_Z3=ON``).
137+
138+
The Z3-specific methods (:py:meth:`get_smtlib2`, :py:meth:`get_z3_stats`,
139+
:py:meth:`set_z3_timeout_ms`, :py:meth:`set_z3_rlimit`) only work when
140+
this is ``True``.
141+
"""
142+
return bool(self._is_z3_enabled())
143+
144+
def _check_z3_enabled(self) -> None:
145+
if not self.is_z3_enabled:
146+
raise RuntimeError(
147+
"The Z3 backend is not available in this build. "
148+
"Rebuild TVM with USE_Z3=ON to use Z3-specific Analyzer APIs."
149+
)
150+
151+
def get_smtlib2(self, expr: tirx.PrimExpr | None = None) -> str:
134152
"""Get the current Z3 problem in SMT-LIB2 format.
135153
154+
Raises
155+
------
156+
RuntimeError
157+
If TVM was built without Z3 (``USE_Z3=OFF``), since there is no
158+
solver state to export. Use :py:attr:`is_z3_enabled` to check first.
159+
136160
Parameters
137161
----------
138162
expr : Optional[PrimExpr]
139163
The expression to prove. If provided, its negation is added to the problem.
140164
"""
165+
self._check_z3_enabled()
141166
return self._get_smtlib2(expr)
142167

143168
def set_z3_timeout_ms(self, timeout_ms: int) -> None:
144169
"""Set Z3 timeout in milliseconds.
145170
171+
Raises
172+
------
173+
RuntimeError
174+
If TVM was built without Z3 (``USE_Z3=OFF``).
175+
146176
Parameters
147177
----------
148178
timeout_ms : int
149179
The timeout in milliseconds.
150180
"""
181+
self._check_z3_enabled()
151182
self._set_z3_timeout_ms(timeout_ms)
152183

153184
def set_z3_rlimit(self, rlimit: int) -> None:
154185
"""Set Z3 resource limit.
155186
187+
The resource limit gives deterministic solver budgeting (unlike a wall
188+
clock timeout). A value of ``0`` disables the limit.
189+
190+
Raises
191+
------
192+
RuntimeError
193+
If TVM was built without Z3 (``USE_Z3=OFF``).
194+
156195
Parameters
157196
----------
158197
rlimit : int
159198
The resource limit.
160199
"""
200+
self._check_z3_enabled()
161201
self._set_z3_rlimit(rlimit)
162202

163203
def get_z3_stats(self) -> str:
164204
"""Get Z3 solver statistics.
165205
206+
Raises
207+
------
208+
RuntimeError
209+
If TVM was built without Z3 (``USE_Z3=OFF``).
210+
166211
Returns
167212
-------
168213
stats : str
169214
The Z3 statistics.
170215
"""
216+
self._check_z3_enabled()
171217
return self._get_z3_stats()
172218

173219
def const_int_bound(self, expr: tirx.PrimExpr) -> ConstIntBound:
@@ -301,7 +347,9 @@ def can_prove(
301347
The expression.
302348
303349
strength: ProofStrength
304-
The proof strength
350+
The proof strength. When TVM is built with Z3 (``USE_Z3=ON``), the
351+
optional Z3 fallback is only consulted at ``SYMBOLIC_BOUND`` or
352+
higher, after the native analyzers fail to prove the predicate.
305353
306354
Returns
307355
-------

src/arith/analyzer.cc

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,10 @@ void ConstraintContext::EnterWithScope() {
131131
// entering the scope.
132132
recovery_functions_.push_back(analyzer_->const_int_bound.EnterConstraint(constraint_));
133133
recovery_functions_.push_back(analyzer_->modular_set.EnterConstraint(constraint_));
134-
recovery_functions_.push_back(
135-
analyzer_->rewrite_simplify.EnterConstraint(constraint_, is_assume_));
134+
recovery_functions_.push_back(analyzer_->rewrite_simplify.EnterConstraint(constraint_));
136135
recovery_functions_.push_back(analyzer_->int_set.EnterConstraint(constraint_));
137136
recovery_functions_.push_back(analyzer_->transitive_comparisons.EnterConstraint(constraint_));
138-
recovery_functions_.push_back(analyzer_->z3_prover.EnterConstraint(constraint_, is_assume_));
137+
recovery_functions_.push_back(analyzer_->z3_prover.EnterConstraint(constraint_));
139138
}
140139

141140
void ConstraintContext::ExitWithScope() {
@@ -235,30 +234,15 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) {
235234
}
236235
}
237236

238-
if (z3_prover.CanProve(simplified)) {
237+
// Z3 is an expensive best-effort fallback. Gate it behind the higher
238+
// kSymbolicBound strength so the common kDefault path (including deeply
239+
// recursive internal CanProve calls) never pays the prover cost.
240+
if (strength >= ProofStrength::kSymbolicBound && z3_prover.CanProve(simplified)) {
239241
return true;
240242
}
241243
return false;
242244
}
243245

244-
std::function<void()> Analyzer::EnterConstraint(const PrimExpr& constraint, bool is_assume) {
245-
std::vector<std::function<void()>> recovery_functions;
246-
recovery_functions.push_back(this->const_int_bound.EnterConstraint(constraint));
247-
recovery_functions.push_back(this->modular_set.EnterConstraint(constraint));
248-
recovery_functions.push_back(this->rewrite_simplify.EnterConstraint(constraint, is_assume));
249-
recovery_functions.push_back(this->int_set.EnterConstraint(constraint));
250-
recovery_functions.push_back(this->transitive_comparisons.EnterConstraint(constraint));
251-
recovery_functions.push_back(this->z3_prover.EnterConstraint(constraint, is_assume));
252-
return [recovery_functions]() {
253-
for (auto it = recovery_functions.rbegin(); it != recovery_functions.rend(); ++it) {
254-
auto& func = *it;
255-
if (func) {
256-
func();
257-
}
258-
}
259-
};
260-
}
261-
262246
PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) {
263247
PrimExpr res = expr;
264248

@@ -371,6 +355,9 @@ TVM_FFI_STATIC_INIT_BLOCK() {
371355
self->rewrite_simplify.SetEnabledExtensions(
372356
static_cast<RewriteSimplifier::Extension>(flags));
373357
});
358+
} else if (name == "is_z3_enabled") {
359+
return ffi::Function(
360+
[self](ffi::PackedArgs args, ffi::Any* ret) { *ret = self->z3_prover.IsEnabled(); });
374361
} else if (name == "get_smtlib2") {
375362
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
376363
auto expr = args[0].cast<ffi::Optional<PrimExpr>>();

src/arith/rewrite_simplify.cc

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -526,14 +526,13 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
526526
return ret;
527527
}
528528

529-
std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint,
530-
bool is_assume) {
529+
std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint) {
531530
size_t old_literal_size = literal_constraints_.size();
532531
// we will compare the already simplified result with the constraint,
533532
// so simplify the constraint as well
534533
PrimExpr new_constraint = operator()(constraint);
535534
for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint, false)) {
536-
if (is_assume || SideEffect(subconstraint) <= CallEffectKind::kPure) {
535+
if (SideEffect(subconstraint) <= CallEffectKind::kPure) {
537536
literal_constraints_.push_back(subconstraint);
538537
PrimExpr negation;
539538
if (subconstraint.dtype().is_bool()) {
@@ -2441,9 +2440,8 @@ void RewriteSimplifier::Update(const Var& var, const PrimExpr& info, bool allow_
24412440
impl_->Update(var, info, allow_override);
24422441
}
24432442

2444-
std::function<void()> RewriteSimplifier::EnterConstraint(const PrimExpr& constraint,
2445-
bool is_assume) {
2446-
return impl_->EnterConstraint(constraint, is_assume);
2443+
std::function<void()> RewriteSimplifier::EnterConstraint(const PrimExpr& constraint) {
2444+
return impl_->EnterConstraint(constraint);
24472445
}
24482446

24492447
void RewriteSimplifier::SetEnabledExtensions(Extension flags) {

0 commit comments

Comments
 (0)