Skip to content

Commit 7032a3e

Browse files
committed
feat: define new types automatically for pure tasks
1 parent 7a52e9d commit 7032a3e

7 files changed

Lines changed: 235 additions & 22 deletions

File tree

src/tfbench/ghc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def ghc_prove_equiv(code: str) -> Result[None, str]:
5050
type Char_ = Char
5151
type Float_ = Float
5252
type Double_ = Double
53+
data Natural = Natural
5354
5455
$new_types
5556

src/tfbench/hs_parser/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
get_type_vars,
77
get_type_constraints,
88
)
9+
from .extractor import TypeExtractor
910

1011
__all__ = [
1112
"AST",
@@ -16,4 +17,5 @@
1617
"to_type_node",
1718
"get_type_vars",
1819
"get_type_constraints",
20+
"TypeExtractor",
1921
]

src/tfbench/hs_parser/ast_util.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -99,19 +99,6 @@ def get_fn_name(self, node: Node) -> Maybe[str]:
9999
return Nothing
100100
return Some(fn_name.strip())
101101

102-
def get_fn_docstring(self, node: Node) -> Maybe[str]:
103-
"""
104-
Retrieves the docstring associated with a function node.
105-
106-
Args:
107-
node (Node): The AST node representing a function.
108-
109-
Returns:
110-
Maybe[str]: A Maybe containing the docstring if found, or Nothing otherwise.
111-
"""
112-
# todo: implement docstring finder
113-
raise NotImplementedError
114-
115102
def func2src(self, func: HaskellFunction) -> tuple[str, str]:
116103
"""
117104
Converts a `HaskellFunction` object into its corresponding type signature and code source.

src/tfbench/hs_parser/extractor.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
from collections import defaultdict, Counter
2+
from dataclasses import dataclass
3+
from tree_sitter import Node
4+
from .ast_util import AST
5+
6+
7+
class TypeExtractor(AST):
    """Static analyzer for Haskell type signatures.

    NOTE: this analyzer works on the body of a type signature only,
    i.e. the part after the `=>` symbol if it has constraints,
    or otherwise after the `::` symbol.
    The constraints (if any) are handled in other modules.

    Attributes:
        constructors: for every applied type constructor name, a counter
            mapping an observed arity to the number of times it was observed.
        names: source text of every `name` node visited whose parent is not
            a `constructor` node.
    """

    def __init__(self, code: str):
        """Parse `code` and immediately run the analysis over it.

        Args:
            code: source text containing the type signature to analyze.
        """
        super().__init__(code)
        self.constructors: dict[str, Counter] = defaultdict(Counter)
        self.names: set[str] = set()

        self._analysis_types()

    @property
    def type_constructors(self) -> dict[str, int]:
        """Map each type constructor name to its maximum observed arity.

        The arity is the number of parameters the constructor was applied to.
        Entries without any recorded arity are left out, so that merely
        touching the `constructors` defaultdict from outside cannot make
        `max()` fail on an empty counter.
        """
        return {
            name: max(arities)
            for name, arities in self.constructors.items()
            if arities
        }

    def _analysis_types(self):
        """Analyze the signature to fill `self.constructors` and `self.names`.

        Only the first `function` node below the first `signature` node is
        visited. When the parsed code has no signature, or the signature has
        no function node, nothing is recorded.
        """
        sigs = self.get_all_nodes_of_type(self.root, "signature")
        if not sigs:
            # Nothing to analyze; leave both collections empty instead of
            # failing with an IndexError on `sigs[0]`.
            return
        functions = self.get_all_nodes_of_type(sigs[0], "function")
        if functions:
            self._visit(functions[0])

    @staticmethod
    def _is_apply_head(n: Node) -> bool:
        """Tell whether `n` sits in the `constructor` field of a parent `apply`."""
        parent = n.parent
        if parent is None or parent.type != "apply":
            return False
        # Compare by equality rather than identity: the binding is not
        # guaranteed to return the same Python wrapper object for the same
        # syntax node on every access, which would make `is` always False.
        # NOTE(review): confirm against the pinned tree_sitter version.
        return parent.child_by_field_name("constructor") == n

    def _collect_from_tuple(self, node: Node):
        """Visit every element of a tuple type; the tuple arity is not recorded."""
        for ch in node.named_children:
            self._visit(ch)

    def _visit(self, n: Node):
        """Recursively walk `n`, recording constructors and names on the way."""
        t = n.type

        if t == "apply":
            # Count an application chain once, at its top-most 'apply' only.
            if not self._is_apply_head(n):
                apply_chain = _peel_apply_chain(n)
                ctor_name = self.get_src_from_node(apply_chain.constructor)
                self.constructors[ctor_name][apply_chain.arity] += 1
            # Recurse into children so nested names/applications are caught too.
            for ch in n.named_children:
                self._visit(ch)
            return

        if t == "constructor":
            # Zero-arity constructor occurrence (e.g. `Int`) that is not the
            # head of an application.
            if not self._is_apply_head(n):
                name_node = n.child_by_field_name("name") or (
                    n.named_children[0] if n.named_children else None
                )
                if name_node is not None:
                    constructor_name = self.get_src_from_node(name_node)
                    self.constructors[constructor_name][0] += 1
            # Still walk inside the constructor node.
            for ch in n.named_children:
                self._visit(ch)
            return

        if t == "tuple":
            self._collect_from_tuple(n)
            return

        if t == "name":
            # A name directly under a 'constructor' node belongs to that
            # constructor and is handled there; every other name is recorded.
            p = n.parent
            if p is None or p.type != "constructor":
                self.names.add(self.get_src_from_node(n))
            return

        # Default: keep descending.
        for ch in n.named_children:
            self._visit(ch)
93+
94+
95+
@dataclass
class TypeApplyChain:
    """Flattened view of a (possibly nested) type application.

    Attributes:
        constructor: the node found at the bottom of the chain of
            `constructor` fields, i.e. the thing being applied.
        arity: number of `apply` nodes that were peeled off.
        arguments: the argument nodes in left-to-right (source) order.
    """

    constructor: Node
    arity: int
    arguments: list[Node]


def _peel_apply_chain(node: Node) -> TypeApplyChain:
    """Flatten an `apply` subtree into its head and its arguments.

    Starting at `node`, follow the `constructor` field through nested
    `apply` nodes until a non-`apply` node is reached, counting one unit of
    arity per `apply` node and collecting each `argument` field on the way.

    Args:
        node: the top-most `apply` node of the chain. A node of any other
            type yields a chain of arity 0 whose constructor is `node`.

    Returns:
        A `TypeApplyChain` holding the head node, the number of `apply`
        nodes walked, and the arguments in source order. If an `apply`
        node has no `constructor` field the walk stops there and that
        `apply` node itself is reported as the constructor.
    """
    outermost_first: list[Node] = []
    arity = 0
    cur = node
    while cur.type == "apply":
        arity += 1
        arg = cur.child_by_field_name("argument")
        if arg is not None:
            outermost_first.append(arg)
        # The head could be the constructor itself or another 'apply'.
        head = cur.child_by_field_name("constructor")
        if head is None:
            break
        cur = head

    # The walk goes from the outermost application inwards, so the last
    # argument in the source is seen first; flip to restore source order.
    return TypeApplyChain(
        constructor=cur, arity=arity, arguments=outermost_first[::-1]
    )

src/tfbench/type_def.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from .common import BenchmarkTask
44
from .hs_parser import AST, get_type_constraints
5+
from .hs_parser.extractor import TypeExtractor
56

67

78
def _is_type(code: str, type_name: str) -> bool:
@@ -49,8 +50,6 @@ def is_type_defined(type_name: str, type_defs: list[str]) -> bool:
4950
def get_type_defs(task: BenchmarkTask) -> list[str]:
5051
"""Get Haskell type definitions from a BenchmarkTask"""
5152
existing_defs = lfilter(is_type_def, task.dependencies)
52-
ast = AST(task.signature)
53-
sig = ast.get_all_nodes_of_type(ast.root, "signature")[0]
5453

5554
if "=>" in task.signature:
5655
constrains = get_type_constraints(task.signature)
@@ -60,14 +59,16 @@ def get_type_defs(task: BenchmarkTask) -> list[str]:
6059
continue
6160
existing_defs.append(def_new_type_class(ty_class, ty_vars))
6261

63-
for node in ast.get_all_nodes_of_type(sig, "name"):
64-
ty = ast.get_src_from_node(node)
65-
if is_type_defined(ty, existing_defs):
62+
extractor = TypeExtractor(task.signature)
63+
for ctor_name, arity in extractor.type_constructors.items():
64+
if is_type_defined(ctor_name, existing_defs):
6665
continue
66+
type_vars = [f"t{i}" for i in range(arity)]
67+
existing_defs.append(def_new_type_constructor(ctor_name, type_vars))
6768

68-
np = node.parent
69-
assert np is not None
70-
if np.type == "function": # data type
71-
existing_defs.append(def_new_type(ty))
69+
for type_name in extractor.names:
70+
if is_type_defined(type_name, existing_defs):
71+
continue
72+
existing_defs.append(def_new_type(type_name))
7273

7374
return list(existing_defs)

tests/test_eval_diff.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from os.path import abspath, dirname, basename, join as pjoin
2+
import os
3+
from itertools import starmap
4+
from multiprocessing import Pool
5+
6+
import pytest
7+
import fire
8+
from orjsonl import orjsonl
9+
from tqdm import tqdm
10+
from tfbench import (
11+
analysis_multi_runs,
12+
load_tfb_from_hf,
13+
load_gen_results_jsonl,
14+
prover_evaluate,
15+
)
16+
from tfbench.ghc import get_prover
17+
from tfbench.evaluation import evaluate_one_task, prove_one_task
18+
from tfbench.common import task2md
19+
from tfbench.type_def import get_type_defs
20+
from tfbench.postprocessing import postprocess, TASK_STRATEGIES, RESPONSE_STRATEGIES
21+
22+
23+
def diff_one_file(file_path: str, split: str):
    """Check one recorded result file for regressions of the new evaluation.

    Every (task, answer) pair is scored by both `evaluate_one_task` (old)
    and `prove_one_task` (new). The first pair accepted by the old
    evaluation but rejected by the new one is printed together with the
    prover program built for it, and the check fails.

    Args:
        file_path: path to a `.jsonl` file of recorded generation results.
        split: benchmark split to load the tasks from; also decides the
            boolean passed to `prove_one_task` (`split == "pure"`).

    Raises:
        AssertionError: if an answer accepted by the old evaluation is
            rejected by the new evaluation.
    """
    tasks = load_tfb_from_hf(split)
    answers = load_gen_results_jsonl(abspath(file_path))

    # The old evaluation stays lazy; the new one is fanned out over a pool.
    old_eval = starmap(evaluate_one_task, zip(tasks, answers))
    with Pool() as pool:
        new_eval = pool.starmap(
            prove_one_task, zip(tasks, answers, [split == "pure"] * len(tasks))
        )

    for t, a, o, n in zip(tasks, answers, old_eval, new_eval):
        if a is None:
            continue
        if o and not n:
            # Dump everything needed to reproduce the disagreement by hand.
            print(task2md(t))
            defs = get_type_defs(t)

            predicted_body = postprocess(a.answer, RESPONSE_STRATEGIES).strip()
            predicted = f"f :: {predicted_body}"
            print(get_prover(t.signature, predicted, defs).unwrap())
            # Raise explicitly: a bare `assert False` disappears under `-O`
            # and carries no context.
            raise AssertionError(
                f"answer accepted by the old evaluation but rejected by the "
                f"new evaluation in {file_path}: {predicted}"
            )
46+
47+
48+
def test_diff_recorded():
    """Compare the old and new evaluation on every recorded result file.

    The new prover evaluation is meant to fix false negatives of the old
    one, so any answer the old evaluation judged correct is expected to be
    judged correct by the new evaluation as well.

    The `results` directory is walked recursively; each `.jsonl` file is
    diffed with the name of its containing directory used as the split.
    The test is skipped when the directory does not exist.
    """
    results_root = abspath("results")
    if not os.path.exists(results_root):
        # Nothing recorded on this machine, so there is nothing to compare.
        pytest.skip("No recorded results found, skip the test.")

    for dir_path, _subdirs, file_names in os.walk(results_root):
        for file_name in file_names:
            if not file_name.endswith(".jsonl"):
                continue
            jsonl_path = pjoin(dir_path, file_name)
            print(f"Diffing {jsonl_path} ...")
            diff_one_file(jsonl_path, basename(dir_path))
68+
69+
70+
if __name__ == "__main__":
71+
fire.Fire(test_diff_recorded)

tests/test_extractor.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from tfbench.hs_parser import TypeExtractor
2+
3+
4+
def test_real_cases():
    """Run `TypeExtractor` over a handful of representative signatures."""
    # Each row: (signature, expected type_constructors, expected names).
    # `None` for the names means that case does not check them.
    cases = [
        ("f:: T1 t1 => t1", {}, None),
        ("f:: T1 t1 => T2 -> t1", {}, {"T2"}),
        ("f:: T1 t1 => T2 T3 -> t1", {"T2": 1}, {"T2", "T3"}),
        (
            "f:: T1 -> T2 T3 -> Either T1 T3 -> (T1, T3, T2 T3)",
            {"T2": 1, "Either": 2},
            {"T1", "T2", "T3", "Either"},
        ),
        (
            "g:: Ord a => Int -> Either String a -> T3 T1 T2 T4",
            {"Either": 2, "T3": 3},
            {"Int", "String", "T1", "T2", "T3", "T4", "Either"},
        ),
    ]

    for code, expected_ctors, expected_names in cases:
        et = TypeExtractor(code)
        assert et.type_constructors == expected_ctors
        if expected_names is not None:
            assert et.names == expected_names

0 commit comments

Comments
 (0)