|
23 | 23 | from tsfc.parameters import default_parameters, is_complex |
24 | 24 | from tsfc.ufl_utils import apply_mapping, extract_firedrake_constants |
25 | 25 | import tsfc.kernel_interface.firedrake_loopy as firedrake_interface_loopy |
| 26 | +from tsfc.exceptions import MismatchingDomainError |
| 27 | + |
26 | 28 |
|
27 | 29 | # To handle big forms. The various transformations might need a deeper stack |
28 | 30 | sys.setrecursionlimit(3000) |
@@ -90,6 +92,9 @@ def compile_form(form, prefix="form", parameters=None, dont_split_numbers=(), di |
90 | 92 | complex_mode=complex_mode, |
91 | 93 | ) |
92 | 94 | logger.info(GREEN % "compute_form_data finished in %g seconds.", time.time() - cpu_time) |
| 95 | + |
| 96 | + validate_domains(form_data.preprocessed_form) |
| 97 | + |
93 | 98 | # Create local kernels. |
94 | 99 | kernels = [] |
95 | 100 | for integral_data in form_data.integral_data: |
@@ -135,27 +140,16 @@ def compile_integral(integral_data, form_data, prefix, parameters, *, diagonal=F |
135 | 140 | if coeff in form_data.coefficient_split: |
136 | 141 | coefficient_split[coeff] = form_data.coefficient_split[coeff] |
137 | 142 | coefficient_numbers.append(form_data.original_coefficient_positions[i]) |
138 | | - |
139 | 143 | mesh = integral_data.domain |
140 | 144 | all_meshes = extract_domains(form_data.original_form) |
141 | 145 | domain_number = all_meshes.index(mesh) |
142 | | - coefficient_meshes = chain.from_iterable(map(extract_domains, coefficients)) |
143 | | - |
144 | | - domain_integral_type_map = dict.fromkeys(all_meshes, None) |
145 | | - domain_integral_type_map.update(dict.fromkeys(coefficient_meshes, "cell")) |
146 | | - domain_integral_type_map.update(integral_data.domain_integral_type_map) |
147 | | - |
148 | | - for arg in arguments: |
149 | | - if domain_integral_type_map[extract_unique_domain(arg)] is None: |
150 | | - raise NotImplementedError("Assembly of forms over unrelated meshes is not supported. " |
151 | | - "Try using Submeshes or cross-mesh interpolation.") |
152 | 146 |
|
153 | 147 | integral_data_info = TSFCIntegralDataInfo( |
154 | 148 | domain=integral_data.domain, |
155 | 149 | integral_type=integral_data.integral_type, |
156 | 150 | subdomain_id=integral_data.subdomain_id, |
157 | 151 | domain_number=domain_number, |
158 | | - domain_integral_type_map=domain_integral_type_map, |
| 152 | + domain_integral_type_map={mesh: integral_data.domain_integral_type_map.get(mesh, None) for mesh in all_meshes}, |
159 | 153 | arguments=arguments, |
160 | 154 | coefficients=coefficients, |
161 | 155 | coefficient_split=coefficient_split, |
@@ -186,6 +180,31 @@ def compile_integral(integral_data, form_data, prefix, parameters, *, diagonal=F |
186 | 180 | return builder.construct_kernel(kernel_name, ctx, parameters["add_petsc_events"]) |
187 | 181 |
|
188 | 182 |
|
| 183 | +def validate_domains(form): |
| 184 | + if len(extract_domains(form)) == 1: |
| 185 | + # Not a multi-domain form, we do not need to keep checking |
| 186 | + return |
| 187 | + |
| 188 | + for itg in form.integrals(): |
| 189 | + # Check that all domains are related to each other |
| 190 | + domain = itg.ufl_domain() |
| 191 | + for other_domain in itg.extra_domain_integral_type_map(): |
| 192 | + if domain.submesh_youngest_common_ancester(other_domain) is None: |
| 193 | + raise MismatchingDomainError("Assembly of forms over unrelated meshes is not supported. " |
| 194 | + "Try using Submeshes or cross-mesh interpolation.") |
| 195 | + |
| 196 | + # Check that all Arguments and Coefficients are defined on the valid domains |
| 197 | + valid_domains = set(itg.extra_domain_integral_type_map()) |
| 198 | + valid_domains.add(domain) |
| 199 | + |
| 200 | + itg_domains = set(extract_domains(itg)) |
| 201 | + if len(itg_domains - valid_domains) > 0: |
| 202 | + raise MismatchingDomainError("Argument or Coefficient domain not found in integral. " |
| 203 | + "Possibly, the form contains coefficients on different meshes " |
| 204 | + "and requires measure intersection, for example: " |
| 205 | + 'Measure("dx", argument_mesh, intersect_measures=[Measure("dx", coefficient_mesh)]).') |
| 206 | + |
| 207 | + |
189 | 208 | def preprocess_parameters(parameters): |
190 | 209 | if parameters is None: |
191 | 210 | parameters = default_parameters() |
|
0 commit comments