Skip to content

Commit 57b2153

Browse files
authored
Merge pull request #408 from posit-dev/feat-ai-validation-copilot
feat: additional AI validation functionality
2 parents e45f0c9 + 7a6c1fd commit 57b2153

15 files changed

Lines changed: 3289 additions & 141 deletions

great-docs.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ reference:
131131
members: false
132132
- name: DraftValidation
133133
members: false
134+
- name: EditValidation
135+
members: true
134136
- name: MissingSpec
135137
members: true
136138

@@ -260,6 +262,10 @@ reference:
260262
- Validate.get_step_report
261263
- Validate.get_json_report
262264
- Validate.get_dataframe_report
265+
- Validate.to_code
266+
- Validate.to_yaml
267+
- Validate.suggest_improvements
268+
- Validate.from_prompt
263269
- Validate.get_sundered_data
264270
- Validate.get_data_extracts
265271
- Validate.all_passed

pointblank/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from pointblank.contract import Contract, Step
3535
from pointblank.datascan import DataScan, col_summary_tbl
3636
from pointblank.draft import DraftValidation
37+
from pointblank.edit import EditValidation
3738
from pointblank.field import (
3839
BoolField,
3940
DateField,
@@ -122,6 +123,7 @@
122123
"PipelineResult",
123124
"DataScan",
124125
"DraftValidation",
126+
"EditValidation",
125127
"MissingSpec",
126128
"col",
127129
"ref",

pointblank/_utils_ai.py

Lines changed: 267 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,32 @@ class _LLMConfig:
4545
verify_ssl: bool = True
4646

4747

48+
_ROW_VALIDATION_SYSTEM_PROMPT = """You are a data validation assistant. Your task is to analyze rows of data and determine if they meet the specified validation criteria.
49+
50+
INSTRUCTIONS:
51+
- Analyze each row in the provided data
52+
- For each row, determine if it meets the validation criteria (True) or not (False)
53+
- Return ONLY a JSON array with validation results
54+
- Each result should have: {"index": <row_index>, "result": <true_or_false>}
55+
- Do not include any explanatory text, only the JSON array
56+
- The row_index should match the "_pb_row_index" field from the input data
57+
58+
EXAMPLE OUTPUT FORMAT:
59+
[
60+
{"index": 0, "result": true},
61+
{"index": 1, "result": false},
62+
{"index": 2, "result": true}
63+
]
64+
65+
If reference attachments (images or PDFs) are provided alongside the data, use them as context when evaluating each row."""
66+
67+
4868
def _create_chat_instance(
49-
provider: str, model_name: str, api_key: Optional[str] = None, verify_ssl: bool = True
69+
provider: str,
70+
model_name: str,
71+
api_key: Optional[str] = None,
72+
verify_ssl: bool = True,
73+
system_prompt: Optional[str] = None,
5074
):
5175
"""
5276
Create a chatlas chat instance for the specified provider.
@@ -61,6 +85,10 @@ def _create_chat_instance(
6185
Optional API key. If None, will be read from environment.
6286
verify_ssl
6387
Whether to verify SSL certificates when making requests. Defaults to True.
88+
system_prompt
89+
Optional system prompt to steer the model. If None, a row-validation system prompt is
90+
used (the default for the AI validation flow). Other flows (e.g., drafting/editing
91+
validation plans) supply their own prompt.
6492
6593
Returns
6694
-------
@@ -82,25 +110,9 @@ def _create_chat_instance(
82110
f"Supported providers: {', '.join(MODEL_PROVIDERS)}"
83111
)
84112

85-
# System prompt with role definition and instructions
86-
system_prompt = """You are a data validation assistant. Your task is to analyze rows of data and determine if they meet the specified validation criteria.
87-
88-
INSTRUCTIONS:
89-
- Analyze each row in the provided data
90-
- For each row, determine if it meets the validation criteria (True) or not (False)
91-
- Return ONLY a JSON array with validation results
92-
- Each result should have: {"index": <row_index>, "result": <true_or_false>}
93-
- Do not include any explanatory text, only the JSON array
94-
- The row_index should match the "_pb_row_index" field from the input data
95-
96-
EXAMPLE OUTPUT FORMAT:
97-
[
98-
{"index": 0, "result": true},
99-
{"index": 1, "result": false},
100-
{"index": 2, "result": true}
101-
]
102-
103-
If reference attachments (images or PDFs) are provided alongside the data, use them as context when evaluating each row."""
113+
# Use the row-validation system prompt unless a caller supplies its own
114+
if system_prompt is None:
115+
system_prompt = _ROW_VALIDATION_SYSTEM_PROMPT
104116

105117
# Create httpx client with SSL verification settings
106118
try:
@@ -989,3 +1001,238 @@ def validate_single_batch(
9891001
for i in range(batch["start_row"], batch["end_row"]):
9901002
default_results.append({"index": i, "result": False})
9911003
return default_results
1004+
1005+
1006+
# ============================================================================
1007+
# Validation-plan code helpers (shared by DraftValidation and EditValidation)
1008+
# ============================================================================
1009+
1010+
import ast as _ast
1011+
import difflib as _difflib
1012+
1013+
# Validation method names a generated or edited plan may legitimately call. Used by the
1014+
# syntax guardrail to catch hallucinated methods before a plan is handed back to the user.
1015+
_KNOWN_PLAN_METHODS = frozenset(
1016+
{
1017+
"col_vals_gt",
1018+
"col_vals_lt",
1019+
"col_vals_eq",
1020+
"col_vals_ne",
1021+
"col_vals_ge",
1022+
"col_vals_le",
1023+
"col_vals_between",
1024+
"col_vals_outside",
1025+
"col_vals_in_set",
1026+
"col_vals_not_in_set",
1027+
"col_vals_increasing",
1028+
"col_vals_decreasing",
1029+
"col_vals_null",
1030+
"col_vals_not_null",
1031+
"col_vals_regex",
1032+
"col_vals_within_spec",
1033+
"col_vals_expr",
1034+
"col_exists",
1035+
"col_pct_null",
1036+
"col_schema_match",
1037+
"col_count_match",
1038+
"row_count_match",
1039+
"rows_distinct",
1040+
"rows_complete",
1041+
"data_freshness",
1042+
"tbl_match",
1043+
"conjointly",
1044+
"specially",
1045+
"interrogate",
1046+
}
1047+
)
1048+
1049+
1050+
def _extract_code(text: str) -> str:
1051+
"""Extract Python source from a model response or fenced code block.
1052+
1053+
Strips a single ```python ... ``` (or bare ``` ... ```) fence if present; otherwise returns
1054+
the text unchanged (trimmed).
1055+
"""
1056+
stripped = text.strip()
1057+
if "```" not in stripped:
1058+
return stripped
1059+
1060+
fence_start = stripped.find("```")
1061+
after = stripped[fence_start + 3 :]
1062+
newline = after.find("\n")
1063+
if newline != -1:
1064+
first_line = after[:newline].strip()
1065+
if first_line.lower() in ("python", "py", ""):
1066+
after = after[newline + 1 :]
1067+
fence_end = after.find("```")
1068+
if fence_end != -1:
1069+
after = after[:fence_end]
1070+
return after.strip()
1071+
1072+
1073+
def _check_syntax(code: str) -> Tuple[bool, str]:
1074+
"""Parse `code` and flag unknown validation methods.
1075+
1076+
Returns `(ok, message)`. `ok` is False if the code fails to parse or calls a method that
1077+
looks like a validation method but is not in the known set; `message` describes the first
1078+
problem found.
1079+
"""
1080+
try:
1081+
tree = _ast.parse(code)
1082+
except SyntaxError as e:
1083+
return False, f"SyntaxError: {e.msg} (line {e.lineno})"
1084+
1085+
unknown: List[str] = []
1086+
for node in _ast.walk(tree):
1087+
if isinstance(node, _ast.Call) and isinstance(node.func, _ast.Attribute):
1088+
name = node.func.attr
1089+
if name.startswith(("col_", "row_", "rows_")) and name not in _KNOWN_PLAN_METHODS:
1090+
unknown.append(name)
1091+
if unknown:
1092+
return False, f"Unknown validation method(s): {', '.join(sorted(set(unknown)))}"
1093+
1094+
return True, ""
1095+
1096+
1097+
def _chain_base_is_validate(node: Any) -> bool:
1098+
"""Whether an attribute-call chain ultimately starts from a `Validate(...)` call."""
1099+
cur = node
1100+
# Descend through chained method calls (whose receiver is itself a call) to the base call
1101+
while (
1102+
isinstance(cur, _ast.Call)
1103+
and isinstance(cur.func, _ast.Attribute)
1104+
and isinstance(cur.func.value, _ast.Call)
1105+
):
1106+
cur = cur.func.value
1107+
if isinstance(cur, _ast.Call):
1108+
func = cur.func
1109+
if isinstance(func, _ast.Attribute):
1110+
return func.attr == "Validate"
1111+
if isinstance(func, _ast.Name):
1112+
return func.id == "Validate"
1113+
return False
1114+
1115+
1116+
def _chain_depth(node: Any) -> int:
1117+
depth = 0
1118+
cur = node
1119+
while (
1120+
isinstance(cur, _ast.Call)
1121+
and isinstance(cur.func, _ast.Attribute)
1122+
and isinstance(cur.func.value, _ast.Call)
1123+
):
1124+
depth += 1
1125+
cur = cur.func.value
1126+
return depth
1127+
1128+
1129+
def _extract_chain_steps(code: str) -> List[Dict[str, Any]]:
1130+
"""Parse a plan's code into an ordered list of step descriptors.
1131+
1132+
Each descriptor is `{"method": str, "text": str, "kwargs": {name: source}}`. The terminal
1133+
`.interrogate()` call, if present, is omitted. Returns an empty list if the code cannot be
1134+
parsed or contains no `Validate(...)` chain.
1135+
"""
1136+
try:
1137+
tree = _ast.parse(code)
1138+
except SyntaxError:
1139+
return []
1140+
1141+
# Locate the deepest (outermost) Validate(...) method chain in the module
1142+
best: Optional[Tuple[int, Any]] = None
1143+
for node in _ast.walk(tree):
1144+
if isinstance(node, _ast.Call) and _chain_base_is_validate(node):
1145+
depth = _chain_depth(node)
1146+
if best is None or depth > best[0]:
1147+
best = (depth, node)
1148+
if best is None:
1149+
return []
1150+
1151+
calls: List[Any] = []
1152+
cur = best[1]
1153+
# Collect chained method calls, stopping before the base `Validate(...)` constructor
1154+
while (
1155+
isinstance(cur, _ast.Call)
1156+
and isinstance(cur.func, _ast.Attribute)
1157+
and isinstance(cur.func.value, _ast.Call)
1158+
):
1159+
calls.append(cur)
1160+
cur = cur.func.value
1161+
calls.reverse()
1162+
1163+
steps: List[Dict[str, Any]] = []
1164+
for call in calls:
1165+
method = call.func.attr
1166+
if method == "interrogate":
1167+
continue
1168+
args = [_ast.unparse(a) for a in call.args]
1169+
kwargs: Dict[str, Any] = {}
1170+
for kw in call.keywords:
1171+
if kw.arg is not None:
1172+
kwargs[kw.arg] = _ast.unparse(kw.value)
1173+
rendered = ", ".join(args + [f"{k}={v}" for k, v in kwargs.items()])
1174+
steps.append({"method": method, "text": f"{method}({rendered})", "kwargs": kwargs})
1175+
return steps
1176+
1177+
1178+
def _diff_plan_steps(
1179+
old_steps: List[Dict[str, Any]], new_steps: List[Dict[str, Any]]
1180+
) -> List[Dict[str, Any]]:
1181+
"""Compute a structured, step-level diff between two parsed plans.
1182+
1183+
Returns a list of change records, each `{"action": "add"|"remove"|"modify", "method": str,
1184+
...}`. For "add"/"modify" the record includes `"new"` (the new step text); for
1185+
"remove"/"modify" it includes `"old"` (the old step text).
1186+
"""
1187+
old_texts = [s["text"] for s in old_steps]
1188+
new_texts = [s["text"] for s in new_steps]
1189+
matcher = _difflib.SequenceMatcher(a=old_texts, b=new_texts, autojunk=False)
1190+
1191+
changes: List[Dict[str, Any]] = []
1192+
for tag, i1, i2, j1, j2 in matcher.get_opcodes():
1193+
if tag == "equal":
1194+
continue
1195+
if tag == "delete":
1196+
for k in range(i1, i2):
1197+
changes.append(
1198+
{"action": "remove", "method": old_steps[k]["method"], "old": old_texts[k]}
1199+
)
1200+
elif tag == "insert":
1201+
for k in range(j1, j2):
1202+
changes.append(
1203+
{"action": "add", "method": new_steps[k]["method"], "new": new_texts[k]}
1204+
)
1205+
else: # "replace": pair same-method steps as modifications, leftovers as add/remove
1206+
old_block = list(range(i1, i2))
1207+
new_block = list(range(j1, j2))
1208+
used_new: set[int] = set()
1209+
for oi in list(old_block):
1210+
match = next(
1211+
(
1212+
nj
1213+
for nj in new_block
1214+
if nj not in used_new and new_steps[nj]["method"] == old_steps[oi]["method"]
1215+
),
1216+
None,
1217+
)
1218+
if match is not None:
1219+
used_new.add(match)
1220+
old_block.remove(oi)
1221+
changes.append(
1222+
{
1223+
"action": "modify",
1224+
"method": old_steps[oi]["method"],
1225+
"old": old_texts[oi],
1226+
"new": new_texts[match],
1227+
}
1228+
)
1229+
for oi in old_block:
1230+
changes.append(
1231+
{"action": "remove", "method": old_steps[oi]["method"], "old": old_texts[oi]}
1232+
)
1233+
for nj in new_block:
1234+
if nj not in used_new:
1235+
changes.append(
1236+
{"action": "add", "method": new_steps[nj]["method"], "new": new_texts[nj]}
1237+
)
1238+
return changes

0 commit comments

Comments
 (0)