Skip to content

Commit aa47b06

Browse files
Simplify algebraic identities involving commutative operators
Recognize commutativity (a*b == b*a), distributivity (a*(b+c) == a*b+a*c), and associativity ((a*b)*c == a*(b*c)) in the expression simplifier and reduce them to true/false before bit-blasting. This applies to all commutative operators: mult, plus, bitand, bitor, bitxor. These simplifications make verification of algebraic properties instant at any bitwidth, matching the word-level reasoning that SMT solvers like Z3 perform internally. Co-authored-by: Kiro <kiro-agent@users.noreply.github.com>
1 parent 6df204e commit aa47b06

2 files changed

Lines changed: 225 additions & 0 deletions

File tree

src/util/simplify_expr_int.cpp

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1362,6 +1362,153 @@ simplify_exprt::simplify_inequality(const binary_relation_exprt &expr)
13621362
if(tmp0.type() != tmp1.type())
13631363
return unchanged(expr);
13641364

1365+
// Simplify equalities/inequalities involving commutative and
1366+
// associative operators.
1367+
if(expr.id() == ID_equal || expr.id() == ID_notequal)
1368+
{
1369+
auto is_commutative = [](const irep_idt &id)
1370+
{
1371+
return id == ID_mult || id == ID_plus || id == ID_bitand ||
1372+
id == ID_bitor || id == ID_bitxor;
1373+
};
1374+
1375+
// Commutativity: a op b == b op a
1376+
if(
1377+
tmp0.id() == tmp1.id() && tmp0.operands().size() == 2 &&
1378+
tmp1.operands().size() == 2 && is_commutative(tmp0.id()))
1379+
{
1380+
if(
1381+
tmp0.operands()[0] == tmp1.operands()[1] &&
1382+
tmp0.operands()[1] == tmp1.operands()[0])
1383+
{
1384+
if(expr.id() == ID_equal)
1385+
return true_exprt();
1386+
else
1387+
return false_exprt();
1388+
}
1389+
}
1390+
1391+
// Distributivity: a * (b + c) == a * b + a * c (and variants)
1392+
auto distribute_mult = [](const exprt &e) -> std::optional<exprt>
1393+
{
1394+
if(e.id() != ID_mult || e.operands().size() != 2)
1395+
return {};
1396+
for(int i = 0; i < 2; ++i)
1397+
{
1398+
const exprt &factor = e.operands()[i];
1399+
const exprt &sum = e.operands()[1 - i];
1400+
if(sum.id() == ID_plus && sum.operands().size() == 2)
1401+
{
1402+
mult_exprt prod0(factor, sum.operands()[0]);
1403+
prod0.type() = e.type();
1404+
mult_exprt prod1(factor, sum.operands()[1]);
1405+
prod1.type() = e.type();
1406+
plus_exprt result(std::move(prod0), std::move(prod1));
1407+
result.type() = e.type();
1408+
return std::move(result);
1409+
}
1410+
}
1411+
return {};
1412+
};
1413+
1414+
auto prod_equal = [](const exprt &p, const exprt &q)
1415+
{
1416+
if(p == q)
1417+
return true;
1418+
if(
1419+
p.id() == ID_mult && q.id() == ID_mult && p.operands().size() == 2 &&
1420+
q.operands().size() == 2)
1421+
{
1422+
return p.operands()[0] == q.operands()[1] &&
1423+
p.operands()[1] == q.operands()[0];
1424+
}
1425+
return false;
1426+
};
1427+
1428+
auto deep_comm_equal = [&prod_equal](const exprt &a, const exprt &b)
1429+
{
1430+
if(a == b)
1431+
return true;
1432+
if(
1433+
a.id() == ID_plus && b.id() == ID_plus && a.operands().size() == 2 &&
1434+
b.operands().size() == 2)
1435+
{
1436+
return (prod_equal(a.operands()[0], b.operands()[0]) &&
1437+
prod_equal(a.operands()[1], b.operands()[1])) ||
1438+
(prod_equal(a.operands()[0], b.operands()[1]) &&
1439+
prod_equal(a.operands()[1], b.operands()[0]));
1440+
}
1441+
return false;
1442+
};
1443+
1444+
{
1445+
auto expanded0 = distribute_mult(tmp0);
1446+
auto expanded1 = distribute_mult(tmp1);
1447+
bool dist_equal = false;
1448+
if(expanded0.has_value() && deep_comm_equal(*expanded0, tmp1))
1449+
dist_equal = true;
1450+
else if(expanded1.has_value() && deep_comm_equal(tmp0, *expanded1))
1451+
dist_equal = true;
1452+
else if(
1453+
expanded0.has_value() && expanded1.has_value() &&
1454+
deep_comm_equal(*expanded0, *expanded1))
1455+
dist_equal = true;
1456+
1457+
if(dist_equal)
1458+
{
1459+
if(expr.id() == ID_equal)
1460+
return true_exprt();
1461+
else
1462+
return false_exprt();
1463+
}
1464+
}
1465+
1466+
// Associativity: flatten nested applications of the same
1467+
// associative+commutative operator and compare sorted leaf multisets.
1468+
// E.g., (a*b)*c == a*(b*c) both flatten to {a, b, c}.
1469+
if(
1470+
tmp0.operands().size() == 2 && tmp1.operands().size() == 2 &&
1471+
tmp0.id() == tmp1.id() && is_commutative(tmp0.id()))
1472+
{
1473+
const irep_idt &op_id = tmp0.id();
1474+
auto flatten = [&op_id](const exprt &e, std::vector<exprt> &leaves)
1475+
{
1476+
std::vector<const exprt *> worklist = {&e};
1477+
while(!worklist.empty())
1478+
{
1479+
const exprt *cur = worklist.back();
1480+
worklist.pop_back();
1481+
if(cur->id() == op_id && cur->operands().size() == 2)
1482+
{
1483+
worklist.push_back(&cur->operands()[0]);
1484+
worklist.push_back(&cur->operands()[1]);
1485+
}
1486+
else
1487+
leaves.push_back(*cur);
1488+
}
1489+
};
1490+
1491+
std::vector<exprt> leaves0, leaves1;
1492+
flatten(tmp0, leaves0);
1493+
flatten(tmp1, leaves1);
1494+
1495+
if(
1496+
leaves0.size() == leaves1.size() && leaves0.size() >= 2 &&
1497+
leaves0.size() <= 8)
1498+
{
1499+
std::sort(leaves0.begin(), leaves0.end());
1500+
std::sort(leaves1.begin(), leaves1.end());
1501+
if(leaves0 == leaves1)
1502+
{
1503+
if(expr.id() == ID_equal)
1504+
return true_exprt();
1505+
else
1506+
return false_exprt();
1507+
}
1508+
}
1509+
}
1510+
}
1511+
13651512
// if rhs is ID_if (and lhs is not), swap operands for == and !=
13661513
if((expr.id()==ID_equal || expr.id()==ID_notequal) &&
13671514
tmp0.id()!=ID_if &&

unit/util/simplify_expr.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,3 +652,81 @@ TEST_CASE("Simplify quantifier", "[core][util]")
652652
REQUIRE(simplify_expr(forall_exprt{a, true_exprt{}}, ns) == true_exprt{});
653653
}
654654
}
655+
656+
TEST_CASE(
657+
"Simplify algebraic identities over commutative operators",
658+
"[core][util]")
659+
{
660+
const symbol_tablet symbol_table;
661+
const namespacet ns{symbol_table};
662+
663+
const unsignedbv_typet u32{32};
664+
const symbol_exprt a{"a", u32};
665+
const symbol_exprt b{"b", u32};
666+
const symbol_exprt c{"c", u32};
667+
668+
SECTION("Commutativity: a * b == b * a")
669+
{
670+
const equal_exprt eq{mult_exprt{a, b}, mult_exprt{b, a}};
671+
REQUIRE(simplify_expr(eq, ns) == true_exprt{});
672+
}
673+
674+
SECTION("Commutativity: a + b == b + a")
675+
{
676+
const equal_exprt eq{plus_exprt{a, b}, plus_exprt{b, a}};
677+
REQUIRE(simplify_expr(eq, ns) == true_exprt{});
678+
}
679+
680+
SECTION("Commutativity: a & b != b & a is false")
681+
{
682+
const notequal_exprt neq{bitand_exprt{a, b}, bitand_exprt{b, a}};
683+
REQUIRE(simplify_expr(neq, ns) == false_exprt{});
684+
}
685+
686+
SECTION("Distributivity: a * (b + c) == a * b + a * c")
687+
{
688+
const mult_exprt lhs{a, plus_exprt{b, c}};
689+
const plus_exprt rhs{mult_exprt{a, b}, mult_exprt{a, c}};
690+
const equal_exprt eq{lhs, rhs};
691+
REQUIRE(simplify_expr(eq, ns) == true_exprt{});
692+
}
693+
694+
SECTION("Distributivity: (b + c) * a == a * c + a * b")
695+
{
696+
const mult_exprt lhs{plus_exprt{b, c}, a};
697+
const plus_exprt rhs{mult_exprt{a, c}, mult_exprt{a, b}};
698+
const equal_exprt eq{lhs, rhs};
699+
REQUIRE(simplify_expr(eq, ns) == true_exprt{});
700+
}
701+
702+
SECTION("Associativity: (a * b) * c == a * (b * c)")
703+
{
704+
const mult_exprt lhs{mult_exprt{a, b}, c};
705+
const mult_exprt rhs{a, mult_exprt{b, c}};
706+
const equal_exprt eq{lhs, rhs};
707+
REQUIRE(simplify_expr(eq, ns) == true_exprt{});
708+
}
709+
710+
SECTION("Associativity: (a + b) + c == a + (b + c)")
711+
{
712+
const plus_exprt lhs{plus_exprt{a, b}, c};
713+
const plus_exprt rhs{a, plus_exprt{b, c}};
714+
const equal_exprt eq{lhs, rhs};
715+
REQUIRE(simplify_expr(eq, ns) == true_exprt{});
716+
}
717+
718+
SECTION("Associativity+Commutativity: (a * b) * c == c * (b * a)")
719+
{
720+
const mult_exprt lhs{mult_exprt{a, b}, c};
721+
const mult_exprt rhs{c, mult_exprt{b, a}};
722+
const equal_exprt eq{lhs, rhs};
723+
REQUIRE(simplify_expr(eq, ns) == true_exprt{});
724+
}
725+
726+
SECTION("Non-equal expressions are not simplified")
727+
{
728+
const equal_exprt eq{mult_exprt{a, b}, mult_exprt{a, c}};
729+
// Should NOT simplify to true (a*b != a*c in general)
730+
REQUIRE(simplify_expr(eq, ns) != true_exprt{});
731+
}
732+
}

0 commit comments

Comments
 (0)