Skip to content

Commit 6563e48

Browse files
committed
Implement Distribute
1 parent 23f114d commit 6563e48

5 files changed

Lines changed: 199 additions & 17 deletions

File tree

SYMBOLS_MANIFEST.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,7 @@ System`Disk
348348
System`DiskBox
349349
System`DiskMatrix
350350
System`Dispatch
351+
System`Distribute
351352
System`Divide
352353
System`DivideBy
353354
System`Divisible

mathics/builtin/numbers/algebra.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,15 @@
4646
from mathics.core.systemsymbols import (
4747
SymbolAssumptions,
4848
SymbolEqual,
49+
SymbolIdentity,
4950
SymbolIndeterminate,
5051
SymbolLess,
5152
SymbolRule,
5253
SymbolRuleDelayed,
5354
SymbolTable,
5455
)
5556
from mathics.eval.list.eol import eval_Part
57+
from mathics.eval.numbers.algebra.distribute import eval_Distribute
5658
from mathics.eval.numbers.algebra.fraction import eval_Denominator, eval_Numerator
5759
from mathics.eval.numbers.algebra.options import AlgebraicOptions
5860
from mathics.eval.numbers.algebra.polynomial import (
@@ -550,7 +552,7 @@ class Collect(Builtin):
550552

551553
def eval_var_filter(self, expr, varlist, filt, evaluation):
552554
"""Collect[expr_, varlist_, filt_]"""
553-
if filt is Symbol("Identity"):
555+
if filt is SymbolIdentity:
554556
filt = None
555557
if isinstance(varlist, Symbol):
556558
var_exprs = [varlist]
@@ -664,6 +666,72 @@ def convert_options(self, options: dict, evaluation: Evaluation):
664666
return {"modulus": py_modulus, "trig": py_trig}
665667

666668

669+
class Distribute(Builtin):
670+
"""
671+
<url>:WMA link:https://reference.wolfram.com/language/ref/Distribute.html</url>
672+
673+
<dl>
674+
<dt>'Distribute'[$expr$]
675+
<dd>distributes $expr$ over 'Plus' (addition).
676+
<dt>'Distribute'[$expr$, $targetHead$]
677+
<dd>distributes $expr$ over the specified $targetHead$.
678+
<dt>'Distribute'[$expr$, $targetHead$, $f$]
679+
<dd>applies $f$ to each component of the result.
680+
</dl>
681+
682+
Distribute multiplication over addition:
683+
>> Distribute[a(b + c)]
684+
= a b + a c
685+
>> Distribute[(a + b)(c + d)]
686+
= a c + a d + b c + b d
687+
688+
Using a custom target head:
689+
>> Distribute[f[a + b, c], Plus]
690+
= f[a, c] + f[b, c]
691+
692+
# Distribute can also work with lists:
693+
# >> Distribute[{a(b + c), d(e + f)}]
694+
# = {a b + a c, d e + d f}
695+
696+
# Applying a function to results:
697+
# >> Distribute[a(b + c), Plus, Square]
698+
# = Square[a b] + Square[a c]
699+
700+
Special forms:
701+
>> Distribute[f[g[a + b]]]
702+
= f[g[a]] + f[g[b]]
703+
"""
704+
705+
attributes = A_PROTECTED
706+
707+
eval_error = Builtin.generic_argument_error
708+
expected_args = range(1, 6)
709+
710+
rules = {
711+
"Distribute[expr_]": "Distribute[expr, Plus]",
712+
"Distribute[expr_, operator_]": "Distribute[expr, operator, Identity]",
713+
}
714+
715+
summary_text = "distribute functions over a head"
716+
717+
def eval(self, expr, operator, filter, evaluation: Evaluation):
718+
"Distribute[expr_, operator_, filter_]"
719+
720+
# Handle Identity filter
721+
if filter is SymbolIdentity:
722+
filter = None
723+
724+
result = eval_Distribute(expr, operator, evaluation)
725+
726+
if result is None:
727+
return expr
728+
729+
if filter:
730+
return Expression(filter, result)
731+
732+
return result
733+
734+
667735
class Expand(_Expand):
668736
"""
669737
<url>

mathics/core/systemsymbols.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@
163163
SymbolHoldPattern = Symbol("System`HoldPattern")
164164
SymbolHue = Symbol("System`Hue")
165165
SymbolI = Symbol("System`I")
166+
SymbolIdentity = Symbol("System`Identity")
166167
SymbolIf = Symbol("System`If")
167168
SymbolIm = Symbol("System`Im")
168169
SymbolImage = Symbol("System`Image")
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""
2+
Evaluation routines for Distribute[]
3+
"""
4+
5+
from mathics.core.expression import Expression
6+
from mathics.core.symbols import Symbol
7+
8+
9+
def eval_Distribute(expr, operator_symbol, evaluation):
10+
"""
11+
Recursively distribute operator_symbol over the expression.
12+
Returns None if no distribution was performed.
13+
"""
14+
if not isinstance(expr, Expression):
15+
return None
16+
17+
head = expr.get_head()
18+
elements = expr.elements
19+
20+
# Find the first element containing the operator_symbol.
21+
operator_position = None
22+
for i, elem in enumerate(elements):
23+
if contains_operator_symbol(elem, operator_symbol):
24+
operator_position = i
25+
break
26+
27+
if operator_position is None:
28+
# No element contains operator_symbol, but check if head itself needs distribution.
29+
return None
30+
31+
# Get the element at the target position
32+
target_elem = elements[operator_position]
33+
34+
# If the element is the operator symbol (e.g., Plus), distribute over it.
35+
if is_operator_symbol(target_elem, operator_symbol):
36+
# Get all components of the operator symbol
37+
target_components = target_elem.elements
38+
39+
# Create new expressions by replacing the operator position with each component.
40+
result_parts = []
41+
for component in target_components:
42+
# Replace the operator position with this component.
43+
new_elements = list(elements)
44+
new_elements[operator_position] = component
45+
new_expr = Expression(head, *new_elements)
46+
47+
# Recursively distribute in the new Expression.
48+
recursive_result = eval_Distribute(new_expr, operator_symbol, evaluation)
49+
if recursive_result is not None:
50+
result_parts.append(recursive_result)
51+
else:
52+
result_parts.append(new_expr)
53+
54+
# Return the combination using the operator symbol.
55+
return Expression(operator_symbol, *result_parts)
56+
57+
# If the element contains but is not the operator symbol, recurse into it.
58+
else:
59+
recursive_result = eval_Distribute(target_elem, operator_symbol, evaluation)
60+
if recursive_result is not None:
61+
new_elements = list(elements)
62+
new_elements[operator_position] = recursive_result
63+
new_expr = Expression(head, *new_elements)
64+
65+
# Try to distribute the modified Expression again
66+
second_result = eval_Distribute(new_expr, operator_symbol, evaluation)
67+
if second_result is not None:
68+
return second_result
69+
return new_expr
70+
71+
return None
72+
73+
74+
def is_operator_symbol(expr, operator_symbol):
75+
"""
76+
Check if expr's head is exactly the operator_symbol.
77+
"""
78+
if not isinstance(expr, Expression):
79+
return False
80+
81+
expr_head = expr.get_head()
82+
83+
if isinstance(operator_symbol, Symbol):
84+
return (
85+
isinstance(expr_head, Symbol)
86+
and expr_head.get_name() == operator_symbol.get_name()
87+
)
88+
89+
return expr_head == operator_symbol
90+
91+
92+
def contains_operator_symbol(expr, operator_symbol):
93+
"""
94+
Check if expr contains operator_symbol anywhere.
95+
"""
96+
if not isinstance(expr, Expression):
97+
return False
98+
99+
# Check if this expression's head is the target
100+
if is_operator_symbol(expr, operator_symbol):
101+
return True
102+
103+
# Recursively check sub-expressions
104+
for elem in expr.elements:
105+
if contains_operator_symbol(elem, operator_symbol):
106+
return True
107+
108+
return False

test/builtin/numbers/test_algebra.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ def test_fullsimplify():
405405
[
406406
("Attributes[f] = {HoldAll}; Apart[f[x + x]]", None, "f[x + x]", None),
407407
("Attributes[f] = {}; Apart[f[x + x]]", None, "f[2 x]", None),
408-
## Errors:
408+
# Errors:
409409
(
410410
"Coefficient[x + y + 3]",
411411
("Coefficient called with 1 argument; 2 or 3 arguments are expected.",),
@@ -469,7 +469,7 @@ def test_fullsimplify():
469469
"24 x / (5 + 3 x + x ^ 2) ^ 3 + 8 x ^ 2 / (5 + 3 x + x ^ 2) ^ 3 + 18 / (5 + 3 x + x ^ 2) ^ 3",
470470
None,
471471
),
472-
## Modulus option
472+
# Modulus option
473473
(
474474
"ExpandDenominator[1 / (x + y)^3, Modulus -> 3]",
475475
None,
@@ -542,21 +542,21 @@ def test_fullsimplify():
542542
"True",
543543
None,
544544
),
545-
## TODO: MMA and Sympy handle these cases differently
546-
## #> PolynomialQ[x^(1/2) + 6xyz]
547-
## : No variable is not supported in PolynomialQ.
548-
## = True
549-
## #> PolynomialQ[x^(1/2) + 6xyz, {}]
550-
## : No variable is not supported in PolynomialQ.
551-
## = True
552-
## #> PolynomialQ[x^3 - 2 x/y + 3xz]
553-
## : No variable is not supported in PolynomialQ.
554-
## = False
555-
## #> PolynomialQ[x^3 - 2 x/y + 3xz, {}]
556-
## : No variable is not supported in PolynomialQ.
557-
## = False
545+
# TODO: MMA and Sympy handle these cases differently
546+
# #> PolynomialQ[x^(1/2) + 6xyz]
547+
# : No variable is not supported in PolynomialQ.
548+
# = True
549+
# #> PolynomialQ[x^(1/2) + 6xyz, {}]
550+
# : No variable is not supported in PolynomialQ.
551+
# = True
552+
# #> PolynomialQ[x^3 - 2 x/y + 3xz]
553+
# : No variable is not supported in PolynomialQ.
554+
# = False
555+
# #> PolynomialQ[x^3 - 2 x/y + 3xz, {}]
556+
# : No variable is not supported in PolynomialQ.
557+
# = False
558558
("f[x]/x+f[x]/x^2//Together", None, "f[x] (1 + x) / x ^ 2", None),
559-
## failing test case from MMA docs
559+
# failing test case from MMA docs
560560
("Variables[E^x]", None, "{}", None),
561561
],
562562
)
@@ -616,6 +616,10 @@ def test_integer(str_expr, msgs, str_expected, fail_msg):
616616
"Exponent",
617617
"2 or 3 arguments are",
618618
),
619+
(
620+
"Distribute",
621+
"between 1 and 5 arguments are",
622+
),
619623
],
620624
)
621625
def test_arg_count_errors(function_name, msg_fragment):

0 commit comments

Comments
 (0)