Skip to content

Commit 0d93a4d

Browse files
Support self-recursive records
Closes #5
1 parent 782c1eb commit 0d93a4d

7 files changed

Lines changed: 197 additions & 46 deletions

File tree

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,16 @@
1515
construct an AST that is not a DAG. (Note: you can of course still
1616
re-use nodes, as long as there are no loop in the DAG.)
1717

18+
* Add support for self-recursive records. This is a pre-cursor to full
19+
mutual recursion, but it is enough for TRLC.
20+
21+
* The `add_component` method of `Record` can take the record sort
22+
itself.
23+
24+
* A new expression `Record_Null_Check` can be used to check if a
25+
given term is equal to the null constructor. This can only be used
26+
on recursive records.
27+
1828
* Move to CVC 1.3.1.
1929

2030
* Add support for Python up to 3.14.

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ features:
2424
[String](https://cvc5.github.io/docs-ci/docs-main/theories/strings.html)
2525
* Parametric sorts:
2626
[Sequences](https://cvc5.github.io/docs-ci/docs-main/theories/sequences.html)
27-
* Datatype sorts: Enumerations and Records
27+
* Datatype sorts: Enumerations and Records (including self-recursive records)
2828
* Uninterpreted functions
2929
* Quantifiers
3030
* Boolean expressions: not, and, or, xor, implication

pyvcg/driver/cvc5_api.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
## ##
44
## Verification Condition Generator ##
55
## ##
6-
## Copyright (C) 2023, Florian Schanda ##
6+
## Copyright (C) 2023-2025, Florian Schanda ##
77
## ##
88
## This file is part of PyVCG. ##
99
## ##
@@ -55,11 +55,18 @@ def term_to_python(self, sort, term):
5555
return rv
5656
elif isinstance(sort, smt.Record):
5757
assert term.getKind() == cvc5.Kind.APPLY_CONSTRUCTOR
58-
rv = {}
59-
for idx, name in enumerate(sort.components, 1):
60-
rv[name] = self.term_to_python(sort.components[name],
61-
term[idx])
62-
return rv
58+
cons_name = term[0].getSymbol()
59+
assert isinstance(cons_name, str)
60+
assert cons_name.endswith("__cons") or cons_name.endswith("__null")
61+
if cons_name.endswith("__null"):
62+
assert sort.is_recursive
63+
return None
64+
else:
65+
rv = {}
66+
for idx, name in enumerate(sort.components, 1):
67+
rv[name] = self.term_to_python(sort.components[name],
68+
term[idx])
69+
return rv
6370
elif sort.name == "Bool":
6471
return term.getBooleanValue()
6572
elif sort.name == "Int":
@@ -258,9 +265,19 @@ def visit_record_declaration(self, node):
258265
ctor = self.solver.mkDatatypeConstructorDecl(node.sort.name +
259266
"__cons")
260267
for name, sort in node.sort.components.items():
261-
ctor.addSelector(name, sort.walk(self))
262-
263-
sort = self.solver.declareDatatype(node.sort.name, ctor)
268+
if sort is node.sort:
269+
ctor.addSelectorSelf(name)
270+
else:
271+
ctor.addSelector(name, sort.walk(self))
272+
273+
if node.sort.is_recursive:
274+
null_ctor = self.solver.mkDatatypeConstructorDecl(node.sort.name +
275+
"__null")
276+
sort = self.solver.declareDatatype(node.sort.name,
277+
null_ctor,
278+
ctor)
279+
else:
280+
sort = self.solver.declareDatatype(node.sort.name, ctor)
264281
self.record_mapping[node.sort] = sort
265282

266283
def visit_sort(self, node):
@@ -483,6 +500,17 @@ def visit_record_access(self, node, tr_record):
483500
s_selector.getTerm(),
484501
tr_record)
485502

503+
def visit_record_null_check(self, node, tr_record):
504+
assert isinstance(node, smt.Record_Null_Check)
505+
s_record_sort = self.record_mapping[node.record.sort]
506+
s_dt = s_record_sort.getDatatype()
507+
s_cons = s_dt.getConstructor(node.record.sort.name + "__null")
508+
return self.solver.mkTerm(
509+
cvc5.Kind.EQUAL,
510+
tr_record,
511+
self.solver.mkTerm(cvc5.Kind.APPLY_CONSTRUCTOR,
512+
s_cons.getTerm()))
513+
486514
def visit_function_application(self, node, tr_function, tr_args):
487515
assert isinstance(node, smt.Function_Application)
488516
assert isinstance(tr_args, list)

pyvcg/driver/cvc5_smtlib.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
## ##
44
## Verification Condition Generator ##
55
## ##
6-
## Copyright (C) 2023, Florian Schanda ##
6+
## Copyright (C) 2023-2025, Florian Schanda ##
77
## ##
88
## This file is part of PyVCG. ##
99
## ##
@@ -338,16 +338,24 @@ def parse_enum(self):
338338

339339
def parse_record(self, typ):
340340
assert isinstance(typ, smt.Record)
341-
self.match("BRA")
342-
self.match("IDENTIFIER")
343-
if typ.name + "__cons" != self.ct.value: # pragma: no cover
344-
self.error("unexpected constructor %s (expected %s)" %
345-
(self.ct.value,
346-
typ.name + "__cons"))
347-
rv = {}
348-
for name, sort in typ.components.items():
349-
rv[name] = self.parse_value(sort)
350-
self.match("KET")
341+
if self.peek("BRA") or not typ.is_recursive:
342+
self.match("BRA")
343+
self.match("IDENTIFIER")
344+
if typ.name + "__cons" != self.ct.value: # pragma: no cover
345+
self.error("unexpected constructor %s (expected %s)" %
346+
(self.ct.value,
347+
typ.name + "__cons"))
348+
rv = {}
349+
for name, sort in typ.components.items():
350+
rv[name] = self.parse_value(sort)
351+
self.match("KET")
352+
else:
353+
self.match("IDENTIFIER")
354+
if typ.name + "__null" != self.ct.value: # pragma: no cover
355+
self.error("unexpected constructor %s (expected %s)" %
356+
(self.ct.value,
357+
typ.name + "__null"))
358+
rv = None
351359
return rv
352360

353361

@@ -375,13 +383,20 @@ def visit_record_declaration(self, node):
375383
self.records[node.sort.name] = node.sort
376384

377385
def solve(self):
378-
result = subprocess.run([self.binary,
379-
"--lang=smt2",
380-
"-"],
381-
input = self.instance,
382-
capture_output = True,
383-
check = True,
384-
encoding = "UTF-8")
386+
try:
387+
result = subprocess.run([self.binary,
388+
"--lang=smt2",
389+
"-"],
390+
input = self.instance,
391+
capture_output = True,
392+
check = True,
393+
encoding = "UTF-8")
394+
except subprocess.CalledProcessError as err: # pragma: no cover
395+
print(self.instance)
396+
print(err.stdout)
397+
print(err.stderr)
398+
raise err
399+
385400
lines = result.stdout.splitlines()
386401
status, tail = lines[0].strip(), lines[1:]
387402
assert status in ("sat", "unsat", "unknown"), \

pyvcg/driver/file_smtlib.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
## ##
44
## Verification Condition Generator ##
55
## ##
6-
## Copyright (C) 2023, Florian Schanda ##
6+
## Copyright (C) 2023-2025, Florian Schanda ##
77
## ##
88
## This file is part of PyVCG. ##
99
## ##
@@ -172,12 +172,21 @@ def visit_enumeration_declaration(self, node):
172172
def visit_record_declaration(self, node):
173173
assert isinstance(node, smt.Record_Declaration)
174174
self.emit_comment(node.comment)
175-
self.lines.append("(declare-datatype %s ((%s" %
176-
(self.escape_name(node.sort.name),
177-
self.escape_name(node.sort.name + "__cons")))
175+
self.lines.append("(declare-datatype %s ("
176+
% self.escape_name(node.sort.name))
177+
if node.sort.is_recursive:
178+
self.lines.append(" (%s)" %
179+
self.escape_name(node.sort.name + "__null"))
180+
self.lines.append(" (%s" %
181+
self.escape_name(node.sort.name + "__cons"))
178182
for name, sort in node.sort.components.items():
179-
self.lines.append(" (%s %s)" % (self.escape_name(name),
180-
sort.walk(self)))
183+
if sort is node.sort:
184+
self.lines.append(" (%s %s)" %
185+
(self.escape_name(name),
186+
self.escape_name(node.sort.name)))
187+
else:
188+
self.lines.append(" (%s %s)" % (self.escape_name(name),
189+
sort.walk(self)))
181190
self.lines[-1] += ")))"
182191

183192
def visit_sort(self, node):
@@ -321,6 +330,12 @@ def visit_record_access(self, node, tr_record):
321330
assert isinstance(node, smt.Record_Access)
322331
return "(%s %s)" % (node.component, tr_record)
323332

333+
def visit_record_null_check(self, node, tr_record):
334+
assert isinstance(node, smt.Record_Null_Check)
335+
return "(= %s %s)" % (
336+
tr_record,
337+
self.escape_name(node.record.sort.name + "__null"))
338+
324339
def visit_function_application(self, node, tr_function, tr_args):
325340
assert isinstance(node, smt.Function_Application)
326341
assert isinstance(tr_args, list)

pyvcg/smt.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,10 @@ def visit_sequence_concatenation(self, node, tr_lhs, tr_rhs):
198198
def visit_record_access(self, node, tr_record):
199199
assert isinstance(node, Record_Access)
200200

201+
@abstractmethod
202+
def visit_record_null_check(self, node, tr_record):
203+
assert isinstance(node, Record_Null_Check)
204+
201205
@abstractmethod
202206
def visit_function_application(self, node, tr_function, tr_args):
203207
assert isinstance(node, Function_Application)
@@ -353,7 +357,8 @@ def visit_record(self, node):
353357
assert isinstance(node, Record)
354358
self.logics.add("datatypes")
355359
for sort in node.components.values():
356-
sort.walk(self)
360+
if sort is not node:
361+
sort.walk(self)
357362

358363
def visit_boolean_literal(self, node, tr_sort):
359364
assert isinstance(node, Boolean_Literal)
@@ -464,6 +469,10 @@ def visit_record_access(self, node, tr_record):
464469
assert isinstance(node, Record_Access)
465470
self.logics.add("datatypes")
466471

472+
def visit_record_null_check(self, node, tr_record):
473+
assert isinstance(node, Record_Null_Check)
474+
self.logics.add("datatypes")
475+
467476
def visit_function_application(self, node, tr_function, tr_args):
468477
assert isinstance(node, Function_Application)
469478
assert isinstance(tr_args, list)
@@ -763,23 +772,26 @@ def check_recursion_rec(self, visited):
763772
class Record(Sort):
764773
def __init__(self, name):
765774
super().__init__(name)
766-
self.components = {}
775+
self.components = {}
776+
self.is_recursive = False
767777

768778
def add_component(self, name, sort):
769779
assert isinstance(name, str)
770780
assert isinstance(sort, Sort)
771781
assert name not in self.components
772782

773783
self.components[name] = sort
784+
self.is_recursive |= sort is self
774785

775786
def walk(self, visitor):
776787
assert isinstance(visitor, Visitor)
777788
return visitor.visit_record(self)
778789

779790
def check_recursion_rec(self, visited):
780791
new_visited = self.check_recursion_enforce(visited)
781-
for component in self.components.values():
782-
component.check_recursion_rec(new_visited)
792+
for component_type in self.components.values():
793+
if component_type is not self:
794+
component_type.check_recursion_rec(new_visited)
783795

784796

785797
class Sequence_Sort(Parametric_Sort):
@@ -1364,6 +1376,25 @@ def check_recursion_rec(self, visited):
13641376
self.record.check_recursion_rec(new_visited)
13651377

13661378

1379+
class Record_Null_Check(Expression):
1380+
def __init__(self, record):
1381+
assert isinstance(record, Expression)
1382+
assert isinstance(record.sort, Record)
1383+
assert record.sort.is_recursive, "non-recursive records cannot be null"
1384+
super().__init__(BUILTIN_BOOLEAN)
1385+
self.record = record
1386+
1387+
def walk(self, visitor):
1388+
assert isinstance(visitor, Visitor)
1389+
return visitor.visit_record_null_check(self,
1390+
self.record.walk(visitor))
1391+
1392+
def check_recursion_rec(self, visited):
1393+
new_visited = self.check_recursion_enforce(visited)
1394+
self.sort.check_recursion_rec(new_visited)
1395+
self.record.check_recursion_rec(new_visited)
1396+
1397+
13671398
class Function_Application(Expression):
13681399
def __init__(self, function, *arguments):
13691400
assert isinstance(function, Function)

tests/testSmt.py

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -965,9 +965,10 @@ def test_Records(self):
965965
(set-logic QF_DTSLIA)
966966
(set-option :produce-models true)
967967
968-
(declare-datatype Kitten ((Kitten__cons
969-
(legs Int)
970-
(name String))))
968+
(declare-datatype Kitten (
969+
(Kitten__cons
970+
(legs Int)
971+
(name String))))
971972
(declare-const a Kitten)
972973
(assert (= (legs a) 4))
973974
(assert (= (name a) "fuzzy"))
@@ -979,12 +980,63 @@ def test_Records(self):
979980
self.assertValue("a", {"name": "fuzzy",
980981
"legs": 4})
981982

982-
def test_Recursive_Record(self):
983-
s_sort = smt.Record("MyType")
984-
s_sort.add_component("name", smt.BUILTIN_STRING)
985-
s_sort.add_component("link", s_sort)
983+
def test_Recursive_Tree(self):
984+
sort_a = smt.Record("a")
985+
sort_b = smt.Record("b")
986+
987+
sort_a.add_component("wibble", sort_b)
988+
sort_b.add_component("wobble", sort_a)
986989
with self.assertRaises(smt.Recursion):
987-
self.script.add_statement(smt.Record_Declaration(s_sort))
990+
self.script.add_statement(smt.Record_Declaration(sort_a))
991+
self.script.add_statement(smt.Record_Declaration(sort_b))
992+
993+
def test_Recursive_Record(self):
994+
s_sort = smt.Record("List")
995+
s_sort.add_component("value", smt.BUILTIN_INTEGER)
996+
s_sort.add_component("next", s_sort)
997+
self.script.add_statement(smt.Record_Declaration(s_sort))
998+
999+
sym_a = smt.Constant(s_sort, "a")
1000+
self.script.add_statement(
1001+
smt.Constant_Declaration(sym_a,
1002+
relevant=True))
1003+
sym_b = smt.Constant(s_sort, "b")
1004+
self.script.add_statement(
1005+
smt.Constant_Declaration(sym_b,
1006+
relevant=True))
1007+
1008+
self.script.add_statement(
1009+
smt.Assertion(
1010+
smt.Comparison("=",
1011+
smt.Record_Access(smt.Record_Access(sym_a,
1012+
"next"),
1013+
"value"),
1014+
smt.Integer_Literal(42))))
1015+
1016+
self.script.add_statement(
1017+
smt.Assertion(smt.Boolean_Negation(smt.Record_Null_Check(sym_b))))
1018+
1019+
self.assertResult(
1020+
"sat",
1021+
"""
1022+
(set-logic QF_DTLIA)
1023+
(set-option :produce-models true)
1024+
1025+
(declare-datatype List (
1026+
(List__null)
1027+
(List__cons
1028+
(value Int)
1029+
(next List))))
1030+
(declare-const a List)
1031+
(declare-const b List)
1032+
(assert (= (value (next a)) 42))
1033+
(assert (not (= b List__null)))
1034+
(check-sat)
1035+
(get-value (a))
1036+
(get-value (b))
1037+
(exit)
1038+
"""
1039+
)
9881040

9891041
def test_UF_No_Body(self):
9901042
s_par = smt.Bound_Variable(smt.BUILTIN_INTEGER, "x")

0 commit comments

Comments
 (0)