Skip to content

Commit 22d9066

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 784b9ab commit 22d9066

38 files changed

Lines changed: 1457 additions & 1526 deletions

python/egglog/bindings.pyi

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,6 @@ class Value:
181181
def __ge__(self, other: object) -> bool: ...
182182

183183
@final
184-
185184
@final
186185
class EggSmolError(Exception):
187186
context: str

python/egglog/builtins.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,6 @@ def match(self, f: Callable[[T], V], n: V) -> V: ...
499499
# def flat_map(self, f: Callable[[T], Maybe[V]]) -> Maybe[V]: ...
500500

501501

502-
503502
converter(type(None), Maybe, lambda _: Maybe[get_type_args()[0]].none())
504503
# converter(object, Maybe, lambda x: Maybe[get_type_args()[0]].some(convert(x, get_type_args()[0])))
505504

@@ -719,7 +718,9 @@ def map_best_common_float_scale(xs: Map[T, f64]) -> f64: ...
719718

720719

721720
@function(egg_fn="map-integer-residual-split-candidate", builtin=True)
722-
def map_integer_residual_split_candidate(xs: Map[Map[T, BigRat], f64]) -> Pair[Map[T, BigRat], Map[Map[T, BigRat], f64]]: ...
721+
def map_integer_residual_split_candidate(
722+
xs: Map[Map[T, BigRat], f64],
723+
) -> Pair[Map[T, BigRat], Map[Map[T, BigRat], f64]]: ...
723724

724725

725726
@function(egg_fn="map-factor-coef-for-integer-residual-split", builtin=True)

python/egglog/deconstruct.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ def _deconstruct_call_decl(
196196

197197
return RuntimeFunction(decls_thunk, Thunk.value(call.callable), egg_bound), arg_exprs
198198

199+
199200
def is_expr_instance(x: BaseExpr, cls: type[T]) -> TypeIs[T]:
200201
"""
201202
Checks if the expression is an instance of the given class. Can normally use isinstance for this, but this also works for

python/egglog/egraph.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,9 @@
142142
}
143143

144144

145-
def check_eq(x: BASE_EXPR, y: BASE_EXPR, schedule: Schedule | None = None, *actions: ActionLike, display=False) -> EGraph:
145+
def check_eq(
146+
x: BASE_EXPR, y: BASE_EXPR, schedule: Schedule | None = None, *actions: ActionLike, display=False
147+
) -> EGraph:
146148
"""
147149
Verifies that two expressions are equal after running the schedule.
148150
@@ -478,9 +480,7 @@ def _generate_class_decls( # noqa: C901,PLR0912
478480
subsume=False,
479481
)
480482
resolved_default = (
481-
resolve_literal(type_ref, default_value, Thunk.value(decls))
482-
if default_mode == "eager"
483-
else None
483+
resolve_literal(type_ref, default_value, Thunk.value(decls)) if default_mode == "eager" else None
484484
)
485485
if resolved_default is not None:
486486
decls |= resolved_default
@@ -911,11 +911,7 @@ def _constant_thunk(
911911
unextractable=False,
912912
subsume=False,
913913
)
914-
resolved_default = (
915-
resolve_literal(type_ref, default_replacement, Thunk.value(decls))
916-
if mode == "eager"
917-
else None
918-
)
914+
resolved_default = resolve_literal(type_ref, default_replacement, Thunk.value(decls)) if mode == "eager" else None
919915
if resolved_default is not None:
920916
decls |= resolved_default
921917
decls._constants[ident] = ConstantDecl(
@@ -1596,11 +1592,7 @@ def all_function_sizes(self) -> list[tuple[ExprCallable, int]]:
15961592
"""
15971593
(output,) = self._state.run_program(bindings.PrintSize(span(1), None))
15981594
assert isinstance(output, bindings.PrintAllFunctionsSize)
1599-
return [
1600-
(callables[0], size)
1601-
for (name, size) in output.sizes
1602-
if (callables := self._egg_fn_to_callables(name))
1603-
]
1595+
return [(callables[0], size) for (name, size) in output.sizes if (callables := self._egg_fn_to_callables(name))]
16041596

16051597
def _egg_fn_to_callables(self, egg_fn: str) -> list[ExprCallable]:
16061598
return [

python/egglog/egraph_state.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def ruleset_to_egg(self, ident: Ident) -> None:
410410
def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command | None:
411411
match cmd:
412412
case ActionCommandDecl(action):
413-
action_egg = self.action_to_egg(action)# , expr_to_let=False)
413+
action_egg = self.action_to_egg(action) # , expr_to_let=False)
414414
if not action_egg:
415415
return None
416416
return bindings.ActionCommand(action_egg)
@@ -641,9 +641,9 @@ def _primitive_command_to_egg(
641641
signature: FunctionSignature,
642642
body: TypedExprDecl,
643643
) -> bindings.UserDefined:
644-
input_sort_expr = self._primitive_input_sorts_to_egg(
645-
[self.type_ref_to_egg(arg_type.to_just()) for arg_type in signature.arg_types]
646-
)
644+
input_sort_expr = self._primitive_input_sorts_to_egg([
645+
self.type_ref_to_egg(arg_type.to_just()) for arg_type in signature.arg_types
646+
])
647647
output_sort_expr = bindings.Var(span(), self.type_ref_to_egg(signature.semantic_return_type.to_just()))
648648
return bindings.UserDefined(
649649
span(),
Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,47 @@
1-
_expr_1 = Value.var("q2") * Value.var("bp1") + Value.var("q5") * Value.var("bp2") + Value.var("q8") * Value.var("bp3") + Value.var("bp4") * Value.var("q11")
2-
_expr_2 = Value.var("bpp2") * Value.var("q6") + Value.var("q3") * Value.var("bpp1") + Value.var("bpp3") * Value.var("q9") + Value.var("bpp4") * Value.var("q12")
3-
_expr_3 = Value.var("bp1") * Value.var("q3") + Value.var("bp2") * Value.var("q6") + Value.var("q12") * Value.var("bp4") + Value.var("bp3") * Value.var("q9")
4-
_expr_4 = Value.var("bpp2") * Value.var("q5") + Value.var("q2") * Value.var("bpp1") + Value.var("bpp3") * Value.var("q8") + Value.var("bpp4") * Value.var("q11")
5-
_expr_5 = Value.var("q4") * Value.var("bpp2") + Value.var("q7") * Value.var("bpp3") + Value.var("q10") * Value.var("bpp4") + Value.var("bpp1") * Value.var("q1")
6-
_expr_6 = Value.var("q4") * Value.var("bp2") + Value.var("q7") * Value.var("bp3") + Value.var("bp1") * Value.var("q1") + Value.var("q10") * Value.var("bp4")
1+
_expr_1 = (
2+
Value.var("q2") * Value.var("bp1")
3+
+ Value.var("q5") * Value.var("bp2")
4+
+ Value.var("q8") * Value.var("bp3")
5+
+ Value.var("bp4") * Value.var("q11")
6+
)
7+
_expr_2 = (
8+
Value.var("bpp2") * Value.var("q6")
9+
+ Value.var("q3") * Value.var("bpp1")
10+
+ Value.var("bpp3") * Value.var("q9")
11+
+ Value.var("bpp4") * Value.var("q12")
12+
)
13+
_expr_3 = (
14+
Value.var("bp1") * Value.var("q3")
15+
+ Value.var("bp2") * Value.var("q6")
16+
+ Value.var("q12") * Value.var("bp4")
17+
+ Value.var("bp3") * Value.var("q9")
18+
)
19+
_expr_4 = (
20+
Value.var("bpp2") * Value.var("q5")
21+
+ Value.var("q2") * Value.var("bpp1")
22+
+ Value.var("bpp3") * Value.var("q8")
23+
+ Value.var("bpp4") * Value.var("q11")
24+
)
25+
_expr_5 = (
26+
Value.var("q4") * Value.var("bpp2")
27+
+ Value.var("q7") * Value.var("bpp3")
28+
+ Value.var("q10") * Value.var("bpp4")
29+
+ Value.var("bpp1") * Value.var("q1")
30+
)
31+
_expr_6 = (
32+
Value.var("q4") * Value.var("bp2")
33+
+ Value.var("q7") * Value.var("bp3")
34+
+ Value.var("bp1") * Value.var("q1")
35+
+ Value.var("q10") * Value.var("bp4")
36+
)
737
NDArray(
838
RecursiveValue(
939
(
1040
(_expr_1 * _expr_2 + Value.from_int(Int(-1)) * (_expr_3 * _expr_4)) ** Value.from_int(Int(2))
1141
+ (_expr_3 * _expr_5 + Value.from_int(Int(-1)) * (_expr_6 * _expr_2)) ** Value.from_int(Int(2))
1242
+ (_expr_6 * _expr_4 + Value.from_int(Int(-1)) * (_expr_1 * _expr_5)) ** Value.from_int(Int(2))
1343
)
14-
/ (_expr_6 ** Value.from_int(Int(2)) + _expr_1 ** Value.from_int(Int(2)) + _expr_3 ** Value.from_int(Int(2))) ** Value.from_int(Int(3))
44+
/ (_expr_6 ** Value.from_int(Int(2)) + _expr_1 ** Value.from_int(Int(2)) + _expr_3 ** Value.from_int(Int(2)))
45+
** Value.from_int(Int(3))
1546
)
16-
)
47+
)

python/egglog/exp/param_eq/__main__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,5 @@
44

55
from .pipeline import _cli
66

7-
87
if __name__ == "__main__":
98
_cli()

0 commit comments

Comments
 (0)