-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathfalsify_prml.py
More file actions
321 lines (268 loc) · 11.5 KB
/
falsify_prml.py
File metadata and controls
321 lines (268 loc) · 11.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
#!/usr/bin/env python3
"""falsify — PRML v0.1 / v0.2 reference CLI (Python).
Commits an ML evaluation claim — metric, threshold, dataset hash, seed — to a
SHA-256 over the canonical manifest bytes *before* the run. Re-derivable by
anyone; edit the manifest after locking and the hash no longer matches.
Canonicalisation (PRML v0.1 §4): keys recursively sorted, block style, LF,
trailing whitespace stripped, exactly one trailing newline, UTF-8. This is the
same rule the Go / JS / Rust reference implementations use; all four produce
byte-identical canonical bytes on the 20 published conformance vectors.
Commands:
falsify lock <spec.yaml|spec.json> canonicalize, hash, write sidecar
falsify verify <spec> [--observed <v>] verify hash; if --observed, evaluate
falsify hash <spec> print the canonical SHA-256 only
falsify init <name> write a skeleton manifest
falsify test-vectors <vectors.json> run the conformance suite
falsify --version
Exit codes: 0 PASS · 3 TAMPERED (hash mismatch) · 10 FAIL (threshold) ·
2 bad input/spec · 11 guard (missing sidecar / lib).
Spec: https://spec.falsify.dev/v0.1
"""
from __future__ import annotations
import argparse
import hashlib
import json
import os
import re
import sys
__version__ = "0.3.0"
EXIT_PASS = 0
EXIT_BAD = 2
EXIT_TAMPERED = 3
EXIT_FAIL = 10
EXIT_GUARD = 11
REQUIRED_FIELDS = [
"version", "claim_id", "created_at", "metric",
"comparator", "threshold", "dataset", "seed", "producer",
]
REQUIRED_DATASET = ["id", "hash"]
REQUIRED_PRODUCER = ["id"]
VALID_COMPARATORS = {">=", "<=", ">", "<", "=="}
_HEX64 = re.compile(r"^[0-9a-f]{64}$")
def _require_yaml():
try:
import yaml # noqa: F401
return yaml
except ImportError:
sys.stderr.write(
"YAML support requires PyYAML: pip install pyyaml. "
"Or pass a .json manifest.\n"
)
raise SystemExit(EXIT_GUARD)
# ─────────────────────────────────────────────────────────────────────────
# Canonicalisation — PRML v0.1 §4 (matches spec/test-vectors reference-target.py)
# ─────────────────────────────────────────────────────────────────────────
def canonicalize(manifest: dict) -> str:
yaml = _require_yaml()
canonical = yaml.safe_dump(
manifest,
default_flow_style=False,
sort_keys=True,
width=float("inf"),
allow_unicode=True,
)
return canonical.replace("\r\n", "\n").rstrip() + "\n"
def manifest_hash(manifest: dict) -> str:
return hashlib.sha256(canonicalize(manifest).encode("utf-8")).hexdigest()
# ─────────────────────────────────────────────────────────────────────────
# Loading + validation
# ─────────────────────────────────────────────────────────────────────────
def load_manifest(path: str) -> dict:
with open(path, "r", encoding="utf-8") as fh:
text = fh.read()
if path.endswith(".json"):
return json.loads(text)
yaml = _require_yaml()
return yaml.safe_load(text)
def validate_manifest(m: dict) -> list[str]:
errors: list[str] = []
if not isinstance(m, dict):
return ["manifest must be a mapping"]
for f in REQUIRED_FIELDS:
if f not in m:
errors.append(f"missing required field: {f}")
if m.get("version") not in ("prml/0.1", "prml/0.2"):
errors.append(f'version must be "prml/0.1" or "prml/0.2", got "{m.get("version")}"')
if not isinstance(m.get("threshold"), (int, float)) or isinstance(m.get("threshold"), bool):
errors.append("threshold must be a finite number")
if m.get("comparator") and m["comparator"] not in VALID_COMPARATORS:
errors.append("comparator must be one of " + ", ".join(sorted(VALID_COMPARATORS)))
ds = m.get("dataset")
if isinstance(ds, dict):
for f in REQUIRED_DATASET:
if f not in ds:
errors.append(f"missing required field: dataset.{f}")
if ds.get("hash") and not _HEX64.match(str(ds["hash"])):
errors.append("dataset.hash must be 64 lowercase hex chars")
prod = m.get("producer")
if isinstance(prod, dict):
for f in REQUIRED_PRODUCER:
if f not in prod:
errors.append(f"missing required field: producer.{f}")
return errors
def evaluate_predicate(observed: float, comparator: str, threshold: float) -> bool:
if comparator == ">=":
return observed >= threshold
if comparator == "<=":
return observed <= threshold
if comparator == ">":
return observed > threshold
if comparator == "<":
return observed < threshold
if comparator == "==":
return observed == threshold
raise ValueError(f"invalid comparator: {comparator}")
def _sidecar_path(spec_path: str) -> str:
return re.sub(r"\.[^.]+$", "", spec_path) + ".prml.sha256"
# ─────────────────────────────────────────────────────────────────────────
# Commands
# ─────────────────────────────────────────────────────────────────────────
def cmd_lock(args) -> int:
try:
m = load_manifest(args.spec)
except (OSError, ValueError) as e:
sys.stderr.write(f"lock: cannot read {args.spec}: {e}\n")
return EXIT_BAD
errors = validate_manifest(m)
if errors:
sys.stderr.write("lock: invalid manifest:\n - " + "\n - ".join(errors) + "\n")
return EXIT_BAD
h = manifest_hash(m)
sidecar = _sidecar_path(args.spec)
with open(sidecar, "w", encoding="utf-8") as fh:
fh.write(h + "\n")
print(f"locked: {args.spec}")
print(f" canonical bytes: {len(canonicalize(m).encode('utf-8'))}")
print(f" sha256: {h}")
print(f" sidecar: {sidecar}")
return EXIT_PASS
def cmd_hash(args) -> int:
try:
m = load_manifest(args.spec)
except (OSError, ValueError) as e:
sys.stderr.write(f"hash: cannot read {args.spec}: {e}\n")
return EXIT_BAD
print(manifest_hash(m))
return EXIT_PASS
def cmd_verify(args) -> int:
try:
m = load_manifest(args.spec)
except (OSError, ValueError) as e:
sys.stderr.write(f"verify: cannot read {args.spec}: {e}\n")
return EXIT_BAD
errors = validate_manifest(m)
if errors:
sys.stderr.write("verify: invalid manifest:\n - " + "\n - ".join(errors) + "\n")
return EXIT_BAD
recomputed = manifest_hash(m)
expected = args.expected_hash
if not expected:
sidecar = _sidecar_path(args.spec)
if not os.path.exists(sidecar):
sys.stderr.write(
f"verify: no --expected-hash and sidecar not found: {sidecar}\n"
f" run `falsify lock {args.spec}` first.\n"
)
return EXIT_GUARD
with open(sidecar, "r", encoding="utf-8") as fh:
expected = fh.read().strip()
if recomputed != expected:
print("TAMPERED")
print(f" recorded: {expected}")
print(f" recomputed: {recomputed}")
return EXIT_TAMPERED
if args.observed is None:
print(f"OK hash verified sha256:{recomputed}")
print("(no --observed value given; predicate not evaluated)")
return EXIT_PASS
try:
observed = float(args.observed)
except ValueError:
sys.stderr.write("verify: --observed must be a finite number\n")
return EXIT_BAD
if evaluate_predicate(observed, m["comparator"], m["threshold"]):
print(f"PASS metric={m['metric']} observed={observed} {m['comparator']} threshold={m['threshold']}")
return EXIT_PASS
print(f"FAIL metric={m['metric']} observed={observed} NOT {m['comparator']} threshold={m['threshold']}")
return EXIT_FAIL
_SKELETON = """\
version: prml/0.1
claim_id: REPLACE_WITH_UUIDv7
created_at: "2026-01-01T00:00:00Z"
metric: accuracy
comparator: ">="
threshold: 0.90
dataset:
id: your-dataset-id
hash: REPLACE_WITH_64_LOWERCASE_HEX
seed: 42
producer:
id: your-org-or-domain
"""
def cmd_init(args) -> int:
out = args.name if args.name.endswith((".yaml", ".yml", ".json")) else args.name + ".prml.yaml"
if os.path.exists(out):
sys.stderr.write(f"init: {out} already exists\n")
return EXIT_BAD
with open(out, "w", encoding="utf-8") as fh:
fh.write(_SKELETON)
print(f"wrote {out} — fill in the placeholders, then `falsify lock {out}`")
return EXIT_PASS
def cmd_test_vectors(args) -> int:
try:
with open(args.vectors, "r", encoding="utf-8") as fh:
data = json.load(fh)
except (OSError, ValueError) as e:
sys.stderr.write(f"test-vectors: cannot read {args.vectors}: {e}\n")
return EXIT_BAD
if isinstance(data, list):
vectors = data
elif isinstance(data, dict):
vectors = data.get("vectors", [])
else:
vectors = []
passed = 0
failed = 0
for v in vectors:
vid = v.get("id", "?")
manifest = v.get("input") or v.get("manifest")
exp_hash = v.get("hash")
if manifest is None or exp_hash is None:
continue
got = manifest_hash(manifest)
if got == exp_hash:
passed += 1
else:
failed += 1
print(f" FAIL {vid} expected {exp_hash[:12]} got {got[:12]}")
total = passed + failed
print(f"{'PASS' if failed == 0 else 'FAIL'} — {passed}/{total} passed")
return EXIT_PASS if failed == 0 else EXIT_FAIL
def main(argv=None) -> int:
p = argparse.ArgumentParser(prog="falsify", description="PRML reference CLI — pre-register ML eval claims.")
p.add_argument("--version", action="version", version=f"falsify {__version__} (PRML v0.1/v0.2)")
sub = p.add_subparsers(dest="cmd")
sp = sub.add_parser("lock", help="canonicalize, hash, write sidecar")
sp.add_argument("spec")
sp.set_defaults(func=cmd_lock)
sp = sub.add_parser("verify", help="verify hash; if --observed, evaluate the predicate")
sp.add_argument("spec")
sp.add_argument("--observed", default=None)
sp.add_argument("--expected-hash", dest="expected_hash", default=None)
sp.set_defaults(func=cmd_verify)
sp = sub.add_parser("hash", help="print the canonical SHA-256 only")
sp.add_argument("spec")
sp.set_defaults(func=cmd_hash)
sp = sub.add_parser("init", help="write a skeleton manifest")
sp.add_argument("name")
sp.set_defaults(func=cmd_init)
sp = sub.add_parser("test-vectors", help="run the conformance suite against a vectors.json")
sp.add_argument("vectors")
sp.set_defaults(func=cmd_test_vectors)
args = p.parse_args(argv)
if not getattr(args, "func", None):
p.print_help()
return EXIT_BAD
return args.func(args)
if __name__ == "__main__":
raise SystemExit(main())