diff --git a/src/e3sm_quickview/pipeline.py b/src/e3sm_quickview/pipeline.py
index b808aa3..8846ae2 100644
--- a/src/e3sm_quickview/pipeline.py
+++ b/src/e3sm_quickview/pipeline.py
@@ -7,6 +7,16 @@
from vtkmodules.vtkRenderingCore import vtkActor, vtkPolyDataMapper
+def range_to_trim(range, max):
+ """
+ Convert an range interval contained in [-180, 180] (such as [-170, 160])
+ into an interval specifying how much to trim (for our example: [10, 20]).
+ 'max' is 180 for longitude or 90 for latitude
+ """
+ (min_range, max_range) = range
+ return [min_range + max, max - max_range]
+
+
def load_plugins():
try:
plugin_dir = Path(__file__).with_name("plugins")
@@ -134,8 +144,8 @@ def __init__(self, projection="Mollweide"):
)
self._crop = simple.EAMExtract(
Input=self.center_meridian,
- LongitudeRange=[-180, 180],
- LatitudeRange=[-90.0, 90.0],
+ TrimLongitude=[0, 0],
+ TrimLatitude=[0, 0],
)
self.proj = simple.EAMProject( # noqa: F821
Input=self._crop,
@@ -233,8 +243,8 @@ def update(self, time=0.0):
self.geometry.UpdatePipeline(time)
def crop(self, longitude_min_max, latitude_min_max):
- self._crop.LongitudeRange = longitude_min_max
- self._crop.LatitudeRange = latitude_min_max
+ self._crop.TrimLongitude = range_to_trim(longitude_min_max, 180)
+ self._crop.TrimLatitude = range_to_trim(latitude_min_max, 90)
class EAMVisSource:
diff --git a/src/e3sm_quickview/plugins/eam_projection.py b/src/e3sm_quickview/plugins/eam_projection.py
index cb2e600..f9b2ac1 100644
--- a/src/e3sm_quickview/plugins/eam_projection.py
+++ b/src/e3sm_quickview/plugins/eam_projection.py
@@ -1,23 +1,25 @@
from paraview.simple import *
from paraview.util.vtkAlgorithm import *
-from vtkmodules.numpy_interface import dataset_adapter as dsa
from vtkmodules.vtkCommonCore import (
vtkPoints,
)
from vtkmodules.vtkCommonDataModel import (
- vtkPolyData,
vtkCellArray,
+ vtkDataSetAttributes,
vtkPlane,
+ vtkPolyData,
)
from vtkmodules.vtkCommonTransforms import vtkTransform
from vtkmodules.vtkFiltersCore import (
vtkAppendFilter,
+ vtkCellCenters,
vtkGenerateIds,
)
from vtkmodules.vtkFiltersGeneral import (
- vtkTransformFilter,
vtkTableBasedClipDataSet,
+ vtkTransformFilter,
)
+from vtkmodules.vtkFiltersGeometry import vtkGeometryFilter
try:
from paraview.modules.vtkPVVTKExtensionsFiltersGeneral import vtkPVClipDataSet
@@ -104,19 +106,10 @@ def RequestData(self, request, inInfo, outInfo):
else:
outData.DeepCopy(inData)
- inWrap = dsa.WrapDataObject(inData)
- outWrap = dsa.WrapDataObject(outData)
-
- inPoints = np.array(inWrap.Points)
+ inPoints = inData.points
pRadius = (self.radius + 1) if self.isData else self.radius
outPoints = np.array(list(map(lambda x: ProcessPoint(x, pRadius), inPoints)))
-
- _coords = numpy_support.numpy_to_vtk(
- outPoints, deep=True, array_type=vtkConstants.VTK_FLOAT
- )
- vtk_coords = vtkPoints()
- vtk_coords.SetData(_coords)
- outWrap.SetPoints(vtk_coords)
+ outData.points = outPoints
return 1
@@ -157,10 +150,7 @@ def RequestData(self, request, inInfo, outInfo):
outData = self.GetOutputData(outInfo, 0)
outData.DeepCopy(inData)
- inWrap = dsa.WrapDataObject(inData)
- outWrap = dsa.WrapDataObject(outData)
-
- inPoints = np.array(inWrap.Points)
+ inPoints = inData.points
pRadius = (self.radius + 1) if self.isData else self.radius
outPoints = np.array(list(map(lambda x: ProcessPoint(x, pRadius), inPoints)))
# outPoints = np.array(list(map(ProcessPoint,inPoints)))
@@ -170,7 +160,7 @@ def RequestData(self, request, inInfo, outInfo):
)
vtk_coords = vtkPoints()
vtk_coords.SetData(_coords)
- outWrap.SetPoints(vtk_coords)
+ outData.points = vtk_coords
return 1
@@ -255,6 +245,7 @@ def __init__(self):
self.__Dims = -1
self.project = 0
self.translate = False
+ self.cached_points = None
def SetTranslation(self, translate):
if self.translate != translate:
@@ -273,49 +264,53 @@ def RequestData(self, request, inInfo, outInfo):
afilter = vtkAppendFilter()
afilter.AddInputData(inData)
afilter.Update()
- outData.DeepCopy(afilter.GetOutput())
+ outData.ShallowCopy(afilter.GetOutput())
else:
- outData.DeepCopy(inData)
-
+ outData.ShallowCopy(inData)
if self.project == 0:
return 1
-
- inWrap = dsa.WrapDataObject(inData)
- outWrap = dsa.WrapDataObject(outData)
- inPoints = np.array(inWrap.Points)
-
- flat = inPoints.flatten()
- x = flat[0::3] - 180.0 if self.translate else flat[0::3]
- y = flat[1::3]
-
- try:
- # Use proj4 string for WGS84 instead of EPSG code to avoid database dependency
- latlon = Proj(proj="latlong", datum="WGS84")
- if self.project == 1:
- proj = Proj(proj="robin")
- elif self.project == 2:
- proj = Proj(proj="moll")
- else:
- # Should not reach here, but return without transformation
+ if (
+ self.cached_points
+ and self.cached_points.GetMTime() >= inData.GetPoints().GetMTime()
+ ):
+ outData.SetPoints(self.cached_points)
+ else:
+ # we modify the points, so copy them
+ out_points_vtk = vtkPoints()
+ out_points_vtk.DeepCopy(inData.GetPoints())
+ outData.SetPoints(out_points_vtk)
+ out_points_np = outData.points
+
+ flat = out_points_np.flatten()
+ x = flat[0::3] - 180.0 if self.translate else flat[0::3]
+ y = flat[1::3]
+
+ try:
+ # Use proj4 string for WGS84 instead of EPSG code to avoid database dependency
+ latlon = Proj(proj="latlong", datum="WGS84")
+ if self.project == 1:
+ proj = Proj(proj="robin")
+ elif self.project == 2:
+ proj = Proj(proj="moll")
+ else:
+ # Should not reach here, but return without transformation
+ return 1
+
+ xformer = Transformer.from_proj(latlon, proj, always_xy=True)
+ res = xformer.transform(x, y)
+ except Exception as e:
+ print(f"Projection error: {e}")
+ # If projection fails, return without modifying coordinates
return 1
-
- xformer = Transformer.from_proj(latlon, proj, always_xy=True)
- res = xformer.transform(x, y)
- except Exception as e:
- print(f"Projection error: {e}")
- # If projection fails, return without modifying coordinates
- return 1
- flat[0::3] = np.array(res[0])
- flat[1::3] = np.array(res[1])
-
- outPoints = flat.reshape(inPoints.shape)
- _coords = numpy_support.numpy_to_vtk(
- outPoints, deep=True, array_type=vtkConstants.VTK_FLOAT
- )
- vtk_coords = vtkPoints()
- vtk_coords.SetData(_coords)
- outWrap.SetPoints(vtk_coords)
-
+ flat[0::3] = np.array(res[0])
+ flat[1::3] = np.array(res[1])
+
+ outPoints = flat.reshape(out_points_np.shape)
+ _coords = numpy_support.numpy_to_vtk(outPoints, deep=True)
+ outData.GetPoints().SetData(_coords)
+ # the previous cached_points, if any, is available for
+ # garbage collection after this assignment
+ self.cached_points = out_points_vtk
return 1
@@ -415,15 +410,15 @@ def RequestData(self, request, inInfo, outInfo):
@smdomain.datatype(dataTypes=["vtkPolyData"], composite_data_supported=False)
@smproperty.xml(
"""
-
+ default_values="0 0">
-
+ default_values="0 0">
"""
)
@@ -432,44 +427,95 @@ def __init__(self):
super().__init__(
nInputPorts=1, nOutputPorts=1, outputType="vtkUnstructuredGrid"
)
- self.longrange = [-180.0, 180.0]
- self.latrange = [-90.0, 90.0]
+ self.trim_lon = [0, 0]
+ self.trim_lat = [0, 0]
+ self.cached_cell_centers = None
+ self.cached_ghosts = None
- def SetLongitudeRange(self, min, max):
- if self.longrange[0] != min or self.longrange[1] != max:
- self.longrange = [min, max]
+ def SetTrimLongitude(self, left, right):
+ if left < 0 or left > 180 or right < 0 or right > 180:
+ print_error(
+ f"SetTrimLongitude called with parameters outside [0, 180]: {left=}, {right=}"
+ )
+ return
+ if self.trim_lon[0] != left or self.trim_lon[1] != right:
+ self.trim_lon = [left, right]
self.Modified()
- def SetLatitudeRange(self, min, max):
- if self.latrange[0] != min or self.latrange[1] != max:
- self.latrange = [min, max]
+ def SetTrimLatitude(self, left, right):
+ if left < 0 or left > 90 or right < 0 or right > 90:
+ print_error(
+ f"SetTrimLatitude called with parameters outside [0, 180]: {left=}, {right=}"
+ )
+ return
+ if self.trim_lat[0] != left or self.trim_lat[1] != right:
+ self.trim_lat = [left, right]
self.Modified()
def RequestData(self, request, inInfo, outInfo):
inData = self.GetInputData(inInfo, 0, 0)
outData = self.GetOutputData(outInfo, 0)
- if self.longrange == [-180.0, 180] and self.latrange == [-90, 90]:
+ if self.trim_lon == [0, 0] and self.trim_lat == [0, 0]:
outData.ShallowCopy(inData)
return 1
- box = vtkPVBox()
- box.SetReferenceBounds(
- self.longrange[0],
- self.longrange[1],
- self.latrange[0],
- self.latrange[1],
- -1.0,
- 1.0,
- )
- box.SetUseReferenceBounds(True)
- extract = vtkPVClipDataSet()
- extract.SetClipFunction(box)
- extract.InsideOutOn()
- extract.ExactBoxClipOn()
- extract.SetInputData(inData)
- extract.Update()
+ outData.ShallowCopy(inData)
+ if self.cached_cell_centers and self.cached_cell_centers.GetMTime() >= max(
+ inData.GetPoints().GetMTime(), inData.GetCells().GetMTime()
+ ):
+ cell_centers = self.cached_cell_centers
+ else:
+ # convert to polydata, as vtkCellCenters only works on polydata
+ # import pdb;pdb.set_trace()
+ to_poly = vtkGeometryFilter()
+ to_poly.SetInputData(inData)
+
+ # get cell centers
+ compute_centers = vtkCellCenters()
+ compute_centers.SetInputConnection(to_poly.GetOutputPort())
+ compute_centers.Update()
+ cell_centers = compute_centers.GetOutput().GetPoints().GetData()
+ # previous cached_cell_centers, if any,
+ # is available for garbage collection after this assignment
+ self.cached_cell_centers = cell_centers
+
+ # get the numpy array for cell centers
+ cc = numpy_support.vtk_to_numpy(cell_centers)
+
+ if self.cached_ghosts and self.cached_ghosts.GetMTime() >= max(
+ self.GetMTime(), inData.GetPoints().GetMTime(), cell_centers.GetMTime()
+ ):
+ ghost = self.cached_ghosts
+ else:
+ # import pdb;pdb.set_trace()
+ # compute the new bounds by trimming the inData bounds
+ bounds = list(inData.GetBounds())
+ bounds[0] = bounds[0] + self.trim_lon[0]
+ bounds[1] = bounds[1] - self.trim_lon[1]
+ bounds[2] = bounds[2] + self.trim_lat[0]
+ bounds[3] = bounds[3] - self.trim_lat[1]
+
+ # add hidden cells based on bounds
+ outside_mask = (
+ (cc[:, 0] < bounds[0])
+ | (cc[:, 0] > bounds[1])
+ | (cc[:, 1] < bounds[2])
+ | (cc[:, 1] > bounds[3])
+ )
+
+ # Create ghost array (0 = visible, HIDDENCELL = invisible)
+ ghost_np = np.where(
+ outside_mask, vtkDataSetAttributes.HIDDENCELL, 0
+ ).astype(np.uint8)
+
+ # Convert to VTK and add to output
+ ghost = numpy_support.numpy_to_vtk(ghost_np)
+ ghost.SetName(vtkDataSetAttributes.GhostArrayName())
+ # the previous cached_ghosts, if any,
+ # is available for garbage collection after this assignment
+ self.cached_ghosts = ghost
+ outData.GetCellData().AddArray(ghost)
- outData.ShallowCopy(extract.GetOutput())
return 1
@@ -535,14 +581,14 @@ def GetMeridian(self):
def RequestData(self, request, inInfo, outInfo):
inData = self.GetInputData(inInfo, 0, 0)
- inPoints = inData.GetPoints()
- inCellArray = inData.GetCells()
outData = self.GetOutputData(outInfo, 0)
if (
self._cached_output
- and self._cached_output.GetMTime() > inPoints.GetMTime()
- and self._cached_output.GetMTime() > inCellArray.GetMTime()
+ and self._cached_output.GetPoints().GetMTime()
+ >= inData.GetPoints().GetMTime()
+ and self._cached_output.GetCells().GetMTime()
+ >= inData.GetCells().GetMTime()
):
# only scalars have been added or removed
cached_cell_data = self._cached_output.GetCellData()
@@ -606,6 +652,7 @@ def RequestData(self, request, inInfo, outInfo):
append.AddInputData(transform.GetOutput())
append.Update()
outData.ShallowCopy(append.GetOutput())
+ # previous _cached_output is available for garbage collection
self._cached_output = outData.NewInstance()
self._cached_output.ShallowCopy(outData)
return 1