Skip to content

Commit ed26ae9

Browse files
authored
Merge pull request #1764 from firedrakeproject/better-vtk-output
Do not store functions in File object
2 parents cfefb6d + 76754cf commit ed26ae9

File tree

1 file changed

+23
-36
lines changed

1 file changed

+23
-36
lines changed

firedrake/output.py

Lines changed: 23 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import numpy
44
import os
55
import ufl
6-
import weakref
76
from itertools import chain
87
from pyop2.mpi import COMM_WORLD, dup_comm
98
from pyop2.datatypes import IntType
@@ -414,53 +413,41 @@ def __init__(self, filename, project_output=False, comm=None, mode="w",
414413

415414
self._fnames = None
416415
self._topology = None
417-
self._output_functions = weakref.WeakKeyDictionary()
418-
self._mappers = weakref.WeakKeyDictionary()
419416

420417
def _prepare_output(self, function, max_elem):
421418
from firedrake import FunctionSpace, VectorFunctionSpace, \
422-
TensorFunctionSpace, Function, Projector, Interpolator
419+
TensorFunctionSpace, Function
420+
from tsfc.finatinterface import create_element as create_finat_element
423421

424422
name = function.name()
425423
# Need to project/interpolate?
426-
# If space is not the max element, we can do so.
427-
if function.ufl_element == max_elem:
424+
# If space is not the max element, we must do so.
425+
finat_elem = function.function_space().finat_element
426+
if finat_elem == create_finat_element(max_elem):
428427
return OFunction(array=get_array(function),
429428
name=name, function=function)
430429
# OK, let's go and do it.
430+
# Build appropriate space for output function.
431431
shape = function.ufl_shape
432-
output = self._output_functions.get(function)
433-
if output is None:
434-
# Build appropriate space for output function.
435-
shape = function.ufl_shape
436-
if len(shape) == 0:
437-
V = FunctionSpace(function.ufl_domain(), max_elem)
438-
elif len(shape) == 1:
439-
if numpy.prod(shape) > 3:
440-
raise ValueError("Can't write vectors with more than 3 components")
441-
V = VectorFunctionSpace(function.ufl_domain(), max_elem,
442-
dim=shape[0])
443-
elif len(shape) == 2:
444-
if numpy.prod(shape) > 9:
445-
raise ValueError("Can't write tensors with more than 9 components")
446-
V = TensorFunctionSpace(function.ufl_domain(), max_elem,
447-
shape=shape)
448-
else:
449-
raise ValueError("Unsupported shape %s" % (shape, ))
450-
output = Function(V)
451-
self._output_functions[function] = output
432+
if len(shape) == 0:
433+
V = FunctionSpace(function.ufl_domain(), max_elem)
434+
elif len(shape) == 1:
435+
if numpy.prod(shape) > 3:
436+
raise ValueError("Can't write vectors with more than 3 components")
437+
V = VectorFunctionSpace(function.ufl_domain(), max_elem,
438+
dim=shape[0])
439+
elif len(shape) == 2:
440+
if numpy.prod(shape) > 9:
441+
raise ValueError("Can't write tensors with more than 9 components")
442+
V = TensorFunctionSpace(function.ufl_domain(), max_elem,
443+
shape=shape)
444+
else:
445+
raise ValueError("Unsupported shape %s" % (shape, ))
446+
output = Function(V)
452447
if self.project:
453-
projector = self._mappers.get(function)
454-
if projector is None:
455-
projector = Projector(function, output)
456-
self._mappers[function] = projector
457-
projector.project()
448+
output.project(function)
458449
else:
459-
interpolator = self._mappers.get(function)
460-
if interpolator is None:
461-
interpolator = Interpolator(function, output)
462-
self._mappers[function] = interpolator
463-
interpolator.interpolate()
450+
output.interpolate(function)
464451

465452
return OFunction(array=get_array(output), name=name, function=output)
466453

0 commit comments

Comments
 (0)