Skip to content

Commit 009177d

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 9920fd1 commit 009177d

12 files changed

Lines changed: 316 additions & 330 deletions

python/egglog/bindings.pyi

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,6 @@ class Value:
176176
def __ge__(self, other: object) -> bool: ...
177177

178178
@final
179-
180179
@final
181180
class EggSmolError(Exception):
182181
context: str

python/egglog/egraph.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,11 +1315,7 @@ def all_function_sizes(self) -> list[tuple[ExprCallable, int]]:
13151315
"""
13161316
(output,) = self._run_program(bindings.PrintSize(span(1), None))
13171317
assert isinstance(output, bindings.PrintAllFunctionsSize)
1318-
return [
1319-
(callables[0], size)
1320-
for (name, size) in output.sizes
1321-
if (callables := self._egg_fn_to_callables(name))
1322-
]
1318+
return [(callables[0], size) for (name, size) in output.sizes if (callables := self._egg_fn_to_callables(name))]
13231319

13241320
def _egg_fn_to_callables(self, egg_fn: str) -> list[ExprCallable]:
13251321
return [

python/egglog/exp/param_eq/__main__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,5 @@
44

55
from .pipeline import _cli
66

7-
87
if __name__ == "__main__":
98
_cli()

python/egglog/exp/param_eq/normalize_archives.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,11 @@ def _normalize_runtime_rows(source_dir: Path) -> list[dict[str, str]]:
124124
node_count = int(benchmark_name.split("/")[-1])
125125
_, numeric, unit, *_ = time_line.split()
126126
runtime_ms = float(numeric) * _to_runtime_multiplier(unit)
127-
rows.append(
128-
{
129-
"benchmark_name": benchmark_name,
130-
"node_count": str(node_count),
131-
"runtime_ms": f"{runtime_ms:.9f}",
132-
}
133-
)
127+
rows.append({
128+
"benchmark_name": benchmark_name,
129+
"node_count": str(node_count),
130+
"runtime_ms": f"{runtime_ms:.9f}",
131+
})
134132
return rows
135133

136134

python/egglog/exp/param_eq/pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,7 @@ def _const_propagation(
457457
GuardCases: TypeAlias = tuple[GuardConditions, ...]
458458
_CONST_GUARD_COUNTER = count()
459459

460+
460461
def _fresh_const_guard_value(prefix: str = "_const_value") -> f64:
461462
return var(f"{prefix}_{next(_CONST_GUARD_COUNTER)}", f64)
462463

@@ -1269,6 +1270,7 @@ def _serialized_counts(egraph: egglog.EGraph) -> tuple[int, int]:
12691270
payload = json.loads(egraph._serialize().to_json())
12701271
return len(payload.get("nodes", {})), len(payload.get("class_data", {}))
12711272

1273+
12721274
analysis_schedule = const_merge_rules | const_seed_rules | const_propagation_rules | const_prune_rules
12731275
basic_rules = (
12741276
basic_add_comm_rules

python/egglog/exp/param_eq/replication.ipynb

Lines changed: 60 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@
5959
"\n",
6060
"from egglog.exp.param_eq.paths import ARTIFACT_DIR, PARAM_EQ_DIR\n",
6161
"\n",
62-
"\n",
6362
"alt.data_transformers.disable_max_rows()\n",
6463
"alt.renderers.enable(\"default\")\n",
6564
"\n",
@@ -273,16 +272,14 @@
273272
" percent = float(\"nan\")\n",
274273
" if not eligible.empty:\n",
275274
" percent = 100.0 * (eligible[\"simpl_rank\"] <= delta).sum() / len(eligible)\n",
276-
" rows.append(\n",
277-
" {\n",
278-
" \"implementation\": implementation,\n",
279-
" \"dataset\": dataset,\n",
280-
" \"dataset_label\": DATASET_LABELS[dataset],\n",
281-
" \"algorithm\": algorithm,\n",
282-
" \"delta\": f\"Δ {'==' if delta == 0 else '<='} {delta}\",\n",
283-
" \"percent\": percent,\n",
284-
" }\n",
285-
" )\n",
275+
" rows.append({\n",
276+
" \"implementation\": implementation,\n",
277+
" \"dataset\": dataset,\n",
278+
" \"dataset_label\": DATASET_LABELS[dataset],\n",
279+
" \"algorithm\": algorithm,\n",
280+
" \"delta\": f\"Δ {'==' if delta == 0 else '<='} {delta}\",\n",
281+
" \"percent\": percent,\n",
282+
" })\n",
286283
" result = pd.DataFrame(rows)\n",
287284
" result[\"percent_label\"] = result[\"percent\"].map(lambda value: \"n/a\" if pd.isna(value) else f\"{value:.2f}%\")\n",
288285
" return result\n",
@@ -309,7 +306,9 @@
309306
" display(SVG(buffer.getvalue()))\n",
310307
"\n",
311308
"\n",
312-
"archived_haskell = with_implementation(add_paper_metrics(_paper_haskell_frame(ARCHIVED_HASKELL_PATH)), \"Archived Haskell\")\n",
309+
"archived_haskell = with_implementation(\n",
310+
" add_paper_metrics(_paper_haskell_frame(ARCHIVED_HASKELL_PATH)), \"Archived Haskell\"\n",
311+
")\n",
313312
"live_haskell = with_implementation(add_paper_metrics(_paper_haskell_frame(LIVE_HASKELL_PATH)), \"Live Haskell\")\n",
314313
"egglog = with_implementation(add_paper_metrics(_paper_egglog_frame(EGGLOG_PATH)), \"Egglog\")\n",
315314
"runtime_rows = _paper_runtime_frame()\n",
@@ -375,7 +374,7 @@
375374
" comparison_table(egglog, implementation=\"Egglog\"),\n",
376375
" ],\n",
377376
" ignore_index=True,\n",
378-
")\n"
377+
")"
379378
]
380379
},
381380
{
@@ -535,58 +534,54 @@
535534
}
536535
],
537536
"source": [
538-
"artifact_summary = pd.DataFrame(\n",
539-
" [\n",
537+
"artifact_summary = pd.DataFrame([\n",
538+
" {\n",
539+
" \"implementation\": \"Archived Haskell\",\n",
540+
" \"rows\": len(archived_haskell),\n",
541+
" \"original_median_simpl_rank\": float(archived_haskell[\"simpl_rank\"].median()),\n",
542+
" \"sympy_median_simpl_rank\": float(archived_haskell[\"sympy_rank\"].median()),\n",
543+
" },\n",
544+
" {\n",
545+
" \"implementation\": \"Live Haskell\",\n",
546+
" \"rows\": len(live_haskell),\n",
547+
" \"original_median_simpl_rank\": float(live_haskell[\"simpl_rank\"].median()),\n",
548+
" \"sympy_median_simpl_rank\": float(live_haskell[\"sympy_rank\"].median()),\n",
549+
" },\n",
550+
" {\n",
551+
" \"implementation\": \"Egglog\",\n",
552+
" \"rows\": len(egglog),\n",
553+
" \"original_median_simpl_rank\": float(egglog[\"simpl_rank\"].median()),\n",
554+
" \"sympy_median_simpl_rank\": float(egglog[\"sympy_rank\"].median()),\n",
555+
" },\n",
556+
"])\n",
557+
"display(artifact_summary)\n",
558+
"display(\n",
559+
" pd.DataFrame([\n",
540560
" {\n",
541-
" \"implementation\": \"Archived Haskell\",\n",
542-
" \"rows\": len(archived_haskell),\n",
543-
" \"original_median_simpl_rank\": float(archived_haskell[\"simpl_rank\"].median()),\n",
544-
" \"sympy_median_simpl_rank\": float(archived_haskell[\"sympy_rank\"].median()),\n",
561+
" \"comparison\": \"Egglog vs live Haskell (original)\",\n",
562+
" \"exact_matches\": int(egglog_vs_live[\"orig_exact\"].sum()),\n",
563+
" \"total_rows\": len(egglog_vs_live),\n",
564+
" \"max_gap\": int(egglog_vs_live[\"orig_gap\"].max()),\n",
545565
" },\n",
546566
" {\n",
547-
" \"implementation\": \"Live Haskell\",\n",
548-
" \"rows\": len(live_haskell),\n",
549-
" \"original_median_simpl_rank\": float(live_haskell[\"simpl_rank\"].median()),\n",
550-
" \"sympy_median_simpl_rank\": float(live_haskell[\"sympy_rank\"].median()),\n",
567+
" \"comparison\": \"Egglog vs live Haskell (sympy)\",\n",
568+
" \"exact_matches\": int(egglog_vs_live[\"sympy_exact\"].sum()),\n",
569+
" \"total_rows\": len(egglog_vs_live),\n",
570+
" \"max_gap\": int(egglog_vs_live[\"sympy_gap\"].max()),\n",
551571
" },\n",
552572
" {\n",
553-
" \"implementation\": \"Egglog\",\n",
554-
" \"rows\": len(egglog),\n",
555-
" \"original_median_simpl_rank\": float(egglog[\"simpl_rank\"].median()),\n",
556-
" \"sympy_median_simpl_rank\": float(egglog[\"sympy_rank\"].median()),\n",
573+
" \"comparison\": \"Live vs archived Haskell (original)\",\n",
574+
" \"exact_matches\": int((archive_drift[\"orig_drift\"] == 0).sum()),\n",
575+
" \"total_rows\": len(archive_drift),\n",
576+
" \"max_gap\": int(archive_drift[\"orig_drift\"].abs().max()),\n",
557577
" },\n",
558-
" ]\n",
559-
")\n",
560-
"display(artifact_summary)\n",
561-
"display(\n",
562-
" pd.DataFrame(\n",
563-
" [\n",
564-
" {\n",
565-
" \"comparison\": \"Egglog vs live Haskell (original)\",\n",
566-
" \"exact_matches\": int(egglog_vs_live[\"orig_exact\"].sum()),\n",
567-
" \"total_rows\": len(egglog_vs_live),\n",
568-
" \"max_gap\": int(egglog_vs_live[\"orig_gap\"].max()),\n",
569-
" },\n",
570-
" {\n",
571-
" \"comparison\": \"Egglog vs live Haskell (sympy)\",\n",
572-
" \"exact_matches\": int(egglog_vs_live[\"sympy_exact\"].sum()),\n",
573-
" \"total_rows\": len(egglog_vs_live),\n",
574-
" \"max_gap\": int(egglog_vs_live[\"sympy_gap\"].max()),\n",
575-
" },\n",
576-
" {\n",
577-
" \"comparison\": \"Live vs archived Haskell (original)\",\n",
578-
" \"exact_matches\": int((archive_drift[\"orig_drift\"] == 0).sum()),\n",
579-
" \"total_rows\": len(archive_drift),\n",
580-
" \"max_gap\": int(archive_drift[\"orig_drift\"].abs().max()),\n",
581-
" },\n",
582-
" {\n",
583-
" \"comparison\": \"Live vs archived Haskell (sympy)\",\n",
584-
" \"exact_matches\": int((archive_drift[\"sympy_drift\"] == 0).sum()),\n",
585-
" \"total_rows\": len(archive_drift),\n",
586-
" \"max_gap\": int(archive_drift[\"sympy_drift\"].abs().max()),\n",
587-
" },\n",
588-
" ]\n",
589-
" )\n",
578+
" {\n",
579+
" \"comparison\": \"Live vs archived Haskell (sympy)\",\n",
580+
" \"exact_matches\": int((archive_drift[\"sympy_drift\"] == 0).sum()),\n",
581+
" \"total_rows\": len(archive_drift),\n",
582+
" \"max_gap\": int(archive_drift[\"sympy_drift\"].abs().max()),\n",
583+
" },\n",
584+
" ])\n",
590585
")\n",
591586
"# -"
592587
]
@@ -896,7 +891,12 @@
896891
}
897892
],
898893
"source": [
899-
"show_chart(runtime_chart(runtime_compare, title=\"Pagie runtime versus expression size (Figure 9 analog) across archived Haskell, live Haskell, and Egglog\"))"
894+
"show_chart(\n",
895+
" runtime_chart(\n",
896+
" runtime_compare,\n",
897+
" title=\"Pagie runtime versus expression size (Figure 9 analog) across archived Haskell, live Haskell, and Egglog\",\n",
898+
" )\n",
899+
")"
900900
]
901901
},
902902
{

python/egglog/exp/param_eq/replication.py

Lines changed: 59 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838

3939
from egglog.exp.param_eq.paths import ARTIFACT_DIR, PARAM_EQ_DIR
4040

41-
4241
alt.data_transformers.disable_max_rows()
4342
alt.renderers.enable("default")
4443

@@ -252,16 +251,14 @@ def comparison_table(frame: pd.DataFrame, *, implementation: str) -> pd.DataFram
252251
percent = float("nan")
253252
if not eligible.empty:
254253
percent = 100.0 * (eligible["simpl_rank"] <= delta).sum() / len(eligible)
255-
rows.append(
256-
{
257-
"implementation": implementation,
258-
"dataset": dataset,
259-
"dataset_label": DATASET_LABELS[dataset],
260-
"algorithm": algorithm,
261-
"delta": f"Δ {'==' if delta == 0 else '<='} {delta}",
262-
"percent": percent,
263-
}
264-
)
254+
rows.append({
255+
"implementation": implementation,
256+
"dataset": dataset,
257+
"dataset_label": DATASET_LABELS[dataset],
258+
"algorithm": algorithm,
259+
"delta": f"Δ {'==' if delta == 0 else '<='} {delta}",
260+
"percent": percent,
261+
})
265262
result = pd.DataFrame(rows)
266263
result["percent_label"] = result["percent"].map(lambda value: "n/a" if pd.isna(value) else f"{value:.2f}%")
267264
return result
@@ -288,7 +285,9 @@ def show_chart(chart: Any) -> None:
288285
display(SVG(buffer.getvalue()))
289286

290287

291-
archived_haskell = with_implementation(add_paper_metrics(_paper_haskell_frame(ARCHIVED_HASKELL_PATH)), "Archived Haskell")
288+
archived_haskell = with_implementation(
289+
add_paper_metrics(_paper_haskell_frame(ARCHIVED_HASKELL_PATH)), "Archived Haskell"
290+
)
292291
live_haskell = with_implementation(add_paper_metrics(_paper_haskell_frame(LIVE_HASKELL_PATH)), "Live Haskell")
293292
egglog = with_implementation(add_paper_metrics(_paper_egglog_frame(EGGLOG_PATH)), "Egglog")
294293
runtime_rows = _paper_runtime_frame()
@@ -360,58 +359,54 @@ def show_chart(chart: Any) -> None:
360359

361360
# ## 1. Artifact Overview
362361

363-
artifact_summary = pd.DataFrame(
364-
[
362+
artifact_summary = pd.DataFrame([
363+
{
364+
"implementation": "Archived Haskell",
365+
"rows": len(archived_haskell),
366+
"original_median_simpl_rank": float(archived_haskell["simpl_rank"].median()),
367+
"sympy_median_simpl_rank": float(archived_haskell["sympy_rank"].median()),
368+
},
369+
{
370+
"implementation": "Live Haskell",
371+
"rows": len(live_haskell),
372+
"original_median_simpl_rank": float(live_haskell["simpl_rank"].median()),
373+
"sympy_median_simpl_rank": float(live_haskell["sympy_rank"].median()),
374+
},
375+
{
376+
"implementation": "Egglog",
377+
"rows": len(egglog),
378+
"original_median_simpl_rank": float(egglog["simpl_rank"].median()),
379+
"sympy_median_simpl_rank": float(egglog["sympy_rank"].median()),
380+
},
381+
])
382+
display(artifact_summary)
383+
display(
384+
pd.DataFrame([
365385
{
366-
"implementation": "Archived Haskell",
367-
"rows": len(archived_haskell),
368-
"original_median_simpl_rank": float(archived_haskell["simpl_rank"].median()),
369-
"sympy_median_simpl_rank": float(archived_haskell["sympy_rank"].median()),
386+
"comparison": "Egglog vs live Haskell (original)",
387+
"exact_matches": int(egglog_vs_live["orig_exact"].sum()),
388+
"total_rows": len(egglog_vs_live),
389+
"max_gap": int(egglog_vs_live["orig_gap"].max()),
370390
},
371391
{
372-
"implementation": "Live Haskell",
373-
"rows": len(live_haskell),
374-
"original_median_simpl_rank": float(live_haskell["simpl_rank"].median()),
375-
"sympy_median_simpl_rank": float(live_haskell["sympy_rank"].median()),
392+
"comparison": "Egglog vs live Haskell (sympy)",
393+
"exact_matches": int(egglog_vs_live["sympy_exact"].sum()),
394+
"total_rows": len(egglog_vs_live),
395+
"max_gap": int(egglog_vs_live["sympy_gap"].max()),
376396
},
377397
{
378-
"implementation": "Egglog",
379-
"rows": len(egglog),
380-
"original_median_simpl_rank": float(egglog["simpl_rank"].median()),
381-
"sympy_median_simpl_rank": float(egglog["sympy_rank"].median()),
398+
"comparison": "Live vs archived Haskell (original)",
399+
"exact_matches": int((archive_drift["orig_drift"] == 0).sum()),
400+
"total_rows": len(archive_drift),
401+
"max_gap": int(archive_drift["orig_drift"].abs().max()),
382402
},
383-
]
384-
)
385-
display(artifact_summary)
386-
display(
387-
pd.DataFrame(
388-
[
389-
{
390-
"comparison": "Egglog vs live Haskell (original)",
391-
"exact_matches": int(egglog_vs_live["orig_exact"].sum()),
392-
"total_rows": len(egglog_vs_live),
393-
"max_gap": int(egglog_vs_live["orig_gap"].max()),
394-
},
395-
{
396-
"comparison": "Egglog vs live Haskell (sympy)",
397-
"exact_matches": int(egglog_vs_live["sympy_exact"].sum()),
398-
"total_rows": len(egglog_vs_live),
399-
"max_gap": int(egglog_vs_live["sympy_gap"].max()),
400-
},
401-
{
402-
"comparison": "Live vs archived Haskell (original)",
403-
"exact_matches": int((archive_drift["orig_drift"] == 0).sum()),
404-
"total_rows": len(archive_drift),
405-
"max_gap": int(archive_drift["orig_drift"].abs().max()),
406-
},
407-
{
408-
"comparison": "Live vs archived Haskell (sympy)",
409-
"exact_matches": int((archive_drift["sympy_drift"] == 0).sum()),
410-
"total_rows": len(archive_drift),
411-
"max_gap": int(archive_drift["sympy_drift"].abs().max()),
412-
},
413-
]
414-
)
403+
{
404+
"comparison": "Live vs archived Haskell (sympy)",
405+
"exact_matches": int((archive_drift["sympy_drift"] == 0).sum()),
406+
"total_rows": len(archive_drift),
407+
"max_gap": int(archive_drift["sympy_drift"].abs().max()),
408+
},
409+
])
415410
)
416411
# -
417412

@@ -510,7 +505,12 @@ def show_chart(chart: Any) -> None:
510505
# met.
511506

512507
# +
513-
show_chart(runtime_chart(runtime_compare, title="Pagie runtime versus expression size (Figure 9 analog) across archived Haskell, live Haskell, and Egglog"))
508+
show_chart(
509+
runtime_chart(
510+
runtime_compare,
511+
title="Pagie runtime versus expression size (Figure 9 analog) across archived Haskell, live Haskell, and Egglog",
512+
)
513+
)
514514
# -
515515

516516
# This faceted Figure 9 analog is now closer to the archived benchmark setup:

0 commit comments

Comments
 (0)