Skip to content

Commit b51d98d

Browse files
authored
Do not set domains of all active coefficients in domain_integral_type (#4803)
* Raise useful exception * Revert #4775
1 parent c5ef5d1 commit b51d98d

8 files changed

Lines changed: 50 additions & 22 deletions

File tree

firedrake/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def init_petsc():
6767
from firedrake.cofunction import Cofunction, RieszMap # noqa: F401
6868
from firedrake.constant import Constant # noqa: F401
6969
from firedrake.deflation import DeflatedSNES, Deflation # noqa: F401
70-
from firedrake.exceptions import ConvergenceError # noqa: F401
70+
from firedrake.exceptions import ConvergenceError, MismatchingDomainError # noqa: F401
7171
from firedrake.function import ( # noqa: F401
7272
Function, PointNotInDomainError,
7373
CoordinatelessFunction, PointEvaluator

firedrake/exceptions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from tsfc.exceptions import MismatchingDomainError # noqa: F401
12

23

34
class ConvergenceError(Exception):

tests/firedrake/regression/test_multiple_domains.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ def test_mismatching_meshes_indexed_function(mesh1, mesh3):
4141
with pytest.raises(NotImplementedError):
4242
project(d1, target)
4343

44-
with pytest.raises(NotImplementedError):
44+
with pytest.raises(MismatchingDomainError):
4545
assemble(inner(d1, TestFunction(V2))*dx(domain=mesh3))
4646

47-
with pytest.raises(NotImplementedError):
47+
with pytest.raises(MismatchingDomainError):
4848
assemble(inner(d1, TestFunction(V2))*dx(domain=mesh1))
4949

5050

@@ -177,29 +177,29 @@ def test_multi_domain_assemble():
177177

178178
for i, j in [(0, 1), (1, 0)]:
179179
a1 = inner(u[i], v[j])*dx(domain=mesh1)
180-
with pytest.raises(NotImplementedError):
180+
with pytest.raises(MismatchingDomainError):
181181
assemble(a1)
182182
a2 = inner(u[i], v[j])*dx(domain=mesh2)
183-
with pytest.raises(NotImplementedError):
183+
with pytest.raises(MismatchingDomainError):
184184
assemble(a2)
185185
l1 = inner(f[i], v[j])*dx(domain=mesh1)
186-
with pytest.raises(NotImplementedError):
186+
with pytest.raises(MismatchingDomainError):
187187
assemble(l1)
188188
l2 = inner(f[i], v[j])*dx(domain=mesh2)
189-
with pytest.raises(NotImplementedError):
189+
with pytest.raises(MismatchingDomainError):
190190
assemble(l2)
191191

192192
for i, j in [(0, 0), (1, 1)]:
193193
a = inner(u[i], v[j])*dx(domain=mesh1)
194194
if i == 1:
195-
with pytest.raises(NotImplementedError):
195+
with pytest.raises(MismatchingDomainError):
196196
assemble(a)
197197
continue
198198
A = assemble(a)
199199
assert A.M.values.shape == (V.dim(), V.dim())
200200

201201
a = inner(u[0], v[0])*dx(domain=mesh1) + inner(u[0], v[1])*dx(domain=mesh2)
202-
with pytest.raises(NotImplementedError):
202+
with pytest.raises(MismatchingDomainError):
203203
assemble(a)
204204

205205
a = inner(u[0], v[0])*dx(domain=mesh1) + inner(u[1], v[1])*dx(domain=mesh2)

tests/firedrake/submesh/test_submesh_assemble.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,9 @@ def expr(m):
555555

556556
Q = FunctionSpace(mesh, "DG", 0)
557557
q = Function(Q).interpolate(expr(mesh))
558-
A = assemble(inner(grad(usub) * q, grad(vsub))*dx(domain=submesh))
558+
559+
subdx = Measure("dx", submesh, intersect_measures=(Measure("dx", mesh),))
560+
A = assemble(inner(grad(usub) * q, grad(vsub))*subdx)
559561

560562
Qsub = FunctionSpace(submesh, "DG", 0)
561563
qsub = Function(Qsub).interpolate(expr(submesh))

tsfc/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from tsfc.driver import compile_form, compile_expression_dual_evaluation # noqa: F401
22
from tsfc.parameters import default_parameters # noqa: F401
3+
from tsfc.exceptions import MismatchingDomainError # noqa: F401
34

45

56
def register_citations():

tsfc/driver.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from tsfc.parameters import default_parameters, is_complex
2424
from tsfc.ufl_utils import apply_mapping, extract_firedrake_constants
2525
import tsfc.kernel_interface.firedrake_loopy as firedrake_interface_loopy
26+
from tsfc.exceptions import MismatchingDomainError
27+
2628

2729
# To handle big forms. The various transformations might need a deeper stack
2830
sys.setrecursionlimit(3000)
@@ -90,6 +92,9 @@ def compile_form(form, prefix="form", parameters=None, dont_split_numbers=(), di
9092
complex_mode=complex_mode,
9193
)
9294
logger.info(GREEN % "compute_form_data finished in %g seconds.", time.time() - cpu_time)
95+
96+
validate_domains(form_data.preprocessed_form)
97+
9398
# Create local kernels.
9499
kernels = []
95100
for integral_data in form_data.integral_data:
@@ -135,27 +140,16 @@ def compile_integral(integral_data, form_data, prefix, parameters, *, diagonal=F
135140
if coeff in form_data.coefficient_split:
136141
coefficient_split[coeff] = form_data.coefficient_split[coeff]
137142
coefficient_numbers.append(form_data.original_coefficient_positions[i])
138-
139143
mesh = integral_data.domain
140144
all_meshes = extract_domains(form_data.original_form)
141145
domain_number = all_meshes.index(mesh)
142-
coefficient_meshes = chain.from_iterable(map(extract_domains, coefficients))
143-
144-
domain_integral_type_map = dict.fromkeys(all_meshes, None)
145-
domain_integral_type_map.update(dict.fromkeys(coefficient_meshes, "cell"))
146-
domain_integral_type_map.update(integral_data.domain_integral_type_map)
147-
148-
for arg in arguments:
149-
if domain_integral_type_map[extract_unique_domain(arg)] is None:
150-
raise NotImplementedError("Assembly of forms over unrelated meshes is not supported. "
151-
"Try using Submeshes or cross-mesh interpolation.")
152146

153147
integral_data_info = TSFCIntegralDataInfo(
154148
domain=integral_data.domain,
155149
integral_type=integral_data.integral_type,
156150
subdomain_id=integral_data.subdomain_id,
157151
domain_number=domain_number,
158-
domain_integral_type_map=domain_integral_type_map,
152+
domain_integral_type_map={mesh: integral_data.domain_integral_type_map.get(mesh, None) for mesh in all_meshes},
159153
arguments=arguments,
160154
coefficients=coefficients,
161155
coefficient_split=coefficient_split,
@@ -186,6 +180,31 @@ def compile_integral(integral_data, form_data, prefix, parameters, *, diagonal=F
186180
return builder.construct_kernel(kernel_name, ctx, parameters["add_petsc_events"])
187181

188182

183+
def validate_domains(form):
184+
if len(extract_domains(form)) == 1:
185+
# Not a multi-domain form, we do not need to keep checking
186+
return
187+
188+
for itg in form.integrals():
189+
# Check that all domains are related to each other
190+
domain = itg.ufl_domain()
191+
for other_domain in itg.extra_domain_integral_type_map():
192+
if domain.submesh_youngest_common_ancester(other_domain) is None:
193+
raise MismatchingDomainError("Assembly of forms over unrelated meshes is not supported. "
194+
"Try using Submeshes or cross-mesh interpolation.")
195+
196+
# Check that all Arguments and Coefficients are defined on the valid domains
197+
valid_domains = set(itg.extra_domain_integral_type_map())
198+
valid_domains.add(domain)
199+
200+
itg_domains = set(extract_domains(itg))
201+
if len(itg_domains - valid_domains) > 0:
202+
raise MismatchingDomainError("Argument or Coefficient domain not found in integral. "
203+
"Possibly, the form contains coefficients on different meshes "
204+
"and requires measure intersection, for example: "
205+
'Measure("dx", argument_mesh, intersect_measures=[Measure("dx", coefficient_mesh)]).')
206+
207+
189208
def preprocess_parameters(parameters):
190209
if parameters is None:
191210
parameters = default_parameters()

tsfc/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2+
3+
class MismatchingDomainError(Exception):
4+
"""Error raised for unsupported multidomain problems"""

tsfc/kernel_interface/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def coefficient(self, ufl_coefficient, restriction):
5757
expressions."""
5858
kernel_arg = self.coefficient_map[ufl_coefficient]
5959
domain = extract_unique_domain(ufl_coefficient)
60+
assert self._domain_integral_type_map[domain] is not None
6061
if ufl_coefficient.ufl_element().family() == 'Real':
6162
return kernel_arg
6263
elif not self._domain_integral_type_map[domain].startswith("interior_facet"):

0 commit comments

Comments
 (0)