Skip to content

Commit 441ae2e

Browse files
committed
Add expression replacement via dictionary
1 parent 7a4687b commit 441ae2e

3 files changed

Lines changed: 48 additions & 5 deletions

File tree

dedalus/core/field.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,10 @@ def replace(self, old, new):
212212
"""Replace specified operand/operator."""
213213
raise NotImplementedError()
214214

215+
def replace_dict(self, subs):
216+
"""Replace specified operands/operators according to a dictionary."""
217+
raise NotImplementedError()
218+
215219
def sym_diff(self, var):
216220
"""Symbolically differentiate with respect to specified operand."""
217221
raise NotImplementedError()
@@ -284,9 +288,8 @@ def frechet_differential(self, variables, perturbations, backgrounds=None):
284288
# Compute differential
285289
epsilon = Field(dist=dist, dtype=dtype)
286290
# d/dε F(X0 + ε*X1)
287-
diff = self
288-
for var, pert in zip(variables, perturbations):
289-
diff = diff.replace(var, var + epsilon*pert)
291+
subs = {var: var + epsilon*pert for var, pert in zip(variables, perturbations)}
292+
diff = self.replace_dict(subs)
290293
diff = diff.sym_diff(epsilon)
291294
# ε -> 0
292295
if diff:
@@ -295,8 +298,8 @@ def frechet_differential(self, variables, perturbations, backgrounds=None):
295298
# Replace variables with backgrounds, if specified
296299
if diff:
297300
if backgrounds:
298-
for var, bg in zip(variables, backgrounds):
299-
diff = diff.replace(var, bg)
301+
subs = {var: bg for var, bg in zip(variables, backgrounds)}
302+
diff = diff.replace_dict(subs)
300303
return diff
301304

302305
@property
@@ -378,6 +381,13 @@ def replace(self, old, new):
378381
else:
379382
return self
380383

384+
def replace_dict(self, subs):
385+
"""Replace specified operands/operators according to a dictionary."""
386+
if self in subs:
387+
return subs[self]
388+
else:
389+
return self
390+
381391
def sym_diff(self, var):
382392
"""Symbolically differentiate with respect to specified operand."""
383393
if self == var:

dedalus/core/future.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,20 @@ def replace(self, old, new):
128128
args = [arg.replace(old, new) if isinstance(arg, Operand) else arg for arg in self.args]
129129
return self.new_operands(*args)
130130

131+
def replace_dict(self, subs):
132+
"""Replace specified operands/operators according to a dictionary."""
133+
# Check for entire expression match
134+
if self in subs:
135+
return subs[self]
136+
# Check base and call with replaced arguments
137+
elif type(self) in subs:
138+
args = [arg.replace_dict(subs) if isinstance(arg, Operand) else arg for arg in self.args]
139+
return subs[type(self)](*args)
140+
# Call with replaced arguments
141+
else:
142+
args = [arg.replace_dict(subs) if isinstance(arg, Operand) else arg for arg in self.args]
143+
return self.new_operands(*args)
144+
131145
# def simplify(self, *vars):
132146
# """Simplify expression, except subtrees containing specified variables."""
133147
# # Simplify arguments if variables are present

dedalus/core/operators.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,20 @@ def replace(self, old, new):
701701
new_operand = self.operand.replace(old, new)
702702
return self.new_operand(new_operand)
703703

704+
def replace_dict(self, subs):
705+
"""Replace specified operands/operators according to a dictionary."""
706+
# Check for entire expression match
707+
if self in subs:
708+
return subs[self]
709+
# Check base and call with replaced arguments
710+
elif type(self) in subs:
711+
new_operand = self.operand.replace_dict(subs)
712+
return subs[type(self)](new_operand)
713+
# Call with replaced arguments
714+
else:
715+
new_operand = self.operand.replace_dict(subs)
716+
return self.new_operand(new_operand)
717+
704718
def expand(self, *vars):
705719
"""Expand expression over specified variables."""
706720
from .arithmetic import Add, Multiply
@@ -1611,6 +1625,11 @@ def replace(self, old, new):
16111625
# Replace operand, skipping conversion
16121626
return self.operand.replace(old, new)
16131627

1628+
def replace_dict(self, subs):
1629+
"""Replace specified operands/operators according to a dictionary."""
1630+
# Replace operand, skipping conversion
1631+
return self.operand.replace_dict(subs)
1632+
16141633
# def expand(self, *vars):
16151634
# """Expand expression over specified variables."""
16161635
# # Expand operand, skipping conversion

0 commit comments

Comments
 (0)