|
3 | 3 | import numpy |
4 | 4 | import os |
5 | 5 | import ufl |
6 | | -import weakref |
7 | 6 | from itertools import chain |
8 | 7 | from pyop2.mpi import COMM_WORLD, dup_comm |
9 | 8 | from pyop2.datatypes import IntType |
@@ -414,53 +413,41 @@ def __init__(self, filename, project_output=False, comm=None, mode="w", |
414 | 413 |
|
415 | 414 | self._fnames = None |
416 | 415 | self._topology = None |
417 | | - self._output_functions = weakref.WeakKeyDictionary() |
418 | | - self._mappers = weakref.WeakKeyDictionary() |
419 | 416 |
|
420 | 417 | def _prepare_output(self, function, max_elem): |
421 | 418 | from firedrake import FunctionSpace, VectorFunctionSpace, \ |
422 | | - TensorFunctionSpace, Function, Projector, Interpolator |
| 419 | + TensorFunctionSpace, Function |
| 420 | + from tsfc.finatinterface import create_element as create_finat_element |
423 | 421 |
|
424 | 422 | name = function.name() |
425 | 423 | # Need to project/interpolate? |
426 | | - # If space is not the max element, we can do so. |
427 | | - if function.ufl_element == max_elem: |
| 424 | + # If space is not the max element, we must do so. |
| 425 | + finat_elem = function.function_space().finat_element |
| 426 | + if finat_elem == create_finat_element(max_elem): |
428 | 427 | return OFunction(array=get_array(function), |
429 | 428 | name=name, function=function) |
430 | 429 | # OK, let's go and do it. |
| 430 | + # Build appropriate space for output function. |
431 | 431 | shape = function.ufl_shape |
432 | | - output = self._output_functions.get(function) |
433 | | - if output is None: |
434 | | - # Build appropriate space for output function. |
435 | | - shape = function.ufl_shape |
436 | | - if len(shape) == 0: |
437 | | - V = FunctionSpace(function.ufl_domain(), max_elem) |
438 | | - elif len(shape) == 1: |
439 | | - if numpy.prod(shape) > 3: |
440 | | - raise ValueError("Can't write vectors with more than 3 components") |
441 | | - V = VectorFunctionSpace(function.ufl_domain(), max_elem, |
442 | | - dim=shape[0]) |
443 | | - elif len(shape) == 2: |
444 | | - if numpy.prod(shape) > 9: |
445 | | - raise ValueError("Can't write tensors with more than 9 components") |
446 | | - V = TensorFunctionSpace(function.ufl_domain(), max_elem, |
447 | | - shape=shape) |
448 | | - else: |
449 | | - raise ValueError("Unsupported shape %s" % (shape, )) |
450 | | - output = Function(V) |
451 | | - self._output_functions[function] = output |
| 432 | + if len(shape) == 0: |
| 433 | + V = FunctionSpace(function.ufl_domain(), max_elem) |
| 434 | + elif len(shape) == 1: |
| 435 | + if numpy.prod(shape) > 3: |
| 436 | + raise ValueError("Can't write vectors with more than 3 components") |
| 437 | + V = VectorFunctionSpace(function.ufl_domain(), max_elem, |
| 438 | + dim=shape[0]) |
| 439 | + elif len(shape) == 2: |
| 440 | + if numpy.prod(shape) > 9: |
| 441 | + raise ValueError("Can't write tensors with more than 9 components") |
| 442 | + V = TensorFunctionSpace(function.ufl_domain(), max_elem, |
| 443 | + shape=shape) |
| 444 | + else: |
| 445 | + raise ValueError("Unsupported shape %s" % (shape, )) |
| 446 | + output = Function(V) |
452 | 447 | if self.project: |
453 | | - projector = self._mappers.get(function) |
454 | | - if projector is None: |
455 | | - projector = Projector(function, output) |
456 | | - self._mappers[function] = projector |
457 | | - projector.project() |
| 448 | + output.project(function) |
458 | 449 | else: |
459 | | - interpolator = self._mappers.get(function) |
460 | | - if interpolator is None: |
461 | | - interpolator = Interpolator(function, output) |
462 | | - self._mappers[function] = interpolator |
463 | | - interpolator.interpolate() |
| 450 | + output.interpolate(function) |
464 | 451 |
|
465 | 452 | return OFunction(array=get_array(output), name=name, function=output) |
466 | 453 |
|
|
0 commit comments