Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions test/test_derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
FunctionSpace,
Identity,
Index,
Interpolate,
Jacobian,
JacobianInverse,
Label,
Expand Down Expand Up @@ -64,6 +65,7 @@
from ufl.algorithms.apply_geometry_lowering import apply_geometry_lowering
from ufl.classes import Indexed, MultiIndex, ReferenceGrad
from ufl.constantvalue import Zero, as_ufl
from ufl.differentiation import BaseFormDerivative
from ufl.domain import extract_unique_domain
from ufl.operators import Variable
from ufl.pullback import identity_pullback
Expand Down Expand Up @@ -1015,3 +1017,23 @@ def test_variable_label():
F_var_4 = Variable(F, label=Label(888))
dCdF_4 = apply_derivatives(diff(C, F_var_4))
assert dCdF_4 == 0


def test_base_form_derivative_equality():
domain = Mesh(LagrangeElement(triangle, 1, (2,)))
f1 = FiniteElement("CG", triangle, 1, (), identity_pullback, H1)
V1 = FunctionSpace(domain, f1)
f2 = FiniteElement("CG", triangle, 2, (), identity_pullback, H1)
V2 = FunctionSpace(domain, f2)

u = Coefficient(V1)
Iu = Interpolate(u, V2)
uhat = TrialFunction(V1)

dIu = derivative(Iu, u, uhat)
assert isinstance(dIu, BaseFormDerivative)

assert derivative(Iu, u, uhat) == dIu

v = Coefficient(V1)
assert derivative(Iu, v, uhat) != dIu
17 changes: 16 additions & 1 deletion test/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
)
from ufl.algorithms.expand_indices import expand_indices
from ufl.core.interpolate import Interpolate
from ufl.form import Form, FormSum
from ufl.form import Form, FormSum, ZeroBaseForm
from ufl.pullback import identity_pullback
from ufl.sobolevspace import H1

Expand Down Expand Up @@ -199,6 +199,21 @@ def test_differentiation(V1, V2):
assert dJdu.arguments() == (Argument(V1, 0),)


def test_differentiation_wrt_independent_variable(V1, V2):
u = Coefficient(V1)
Iu = Interpolate(u, V2)
uhat = TrialFunction(V1)

independent = Coefficient(V1)

dIu = derivative(Iu, independent, uhat)
dIu_expanded = expand_derivatives(dIu)
assert isinstance(dIu_expanded, ZeroBaseForm)

interp_deriv_sum_expanded = expand_derivatives(dIu + dIu)
assert isinstance(interp_deriv_sum_expanded, ZeroBaseForm)


def test_extract_base_form_operators(V1, V2):
u = Coefficient(V1)
uhat = TrialFunction(V1)
Expand Down
9 changes: 8 additions & 1 deletion ufl/algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ufl.checks import is_true_ufl_scalar, is_ufl_scalar
from ufl.constantvalue import ComplexValue, IntValue, ScalarValue, Zero, as_ufl, zero
from ufl.core.expr import ufl_err_str
from ufl.core.expr import Expr, ufl_err_str
from ufl.core.operator import Operator
from ufl.core.ufl_type import ufl_type
from ufl.index_combination_utils import merge_unique_indices
Expand All @@ -34,6 +34,13 @@ class Sum(Operator):

def __new__(cls, a, b):
"""Create a new Sum."""
from ufl import BaseForm, FormSum

# Base forms (that aren't also expressions like Interpolate) should
# be cast to FormSums instead.
if any(isinstance(x, BaseForm) and not isinstance(x, Expr) for x in [a, b]):
return FormSum((a, 1), (b, 1))

# Make sure everything is an Expr
a = as_ufl(a)
b = as_ufl(b)
Expand Down
8 changes: 7 additions & 1 deletion ufl/core/base_form_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,13 @@ def __eq__(self, other):
"""Check for equality."""
if isinstance(other, Number) and other == 0:
return self.empty()
raise NotImplementedError()
return (
type(other) is type(self)
and other.ufl_operands == self.ufl_operands
and other._argument_slots == self._argument_slots
and other.derivatives == self.derivatives
and other.ufl_function_space() == self.ufl_function_space()
)

@property
def _parent_type(self):
Expand Down
Loading