diff --git a/src/e3sm_quickview/pipeline.py b/src/e3sm_quickview/pipeline.py index b808aa3..8846ae2 100644 --- a/src/e3sm_quickview/pipeline.py +++ b/src/e3sm_quickview/pipeline.py @@ -7,6 +7,16 @@ from vtkmodules.vtkRenderingCore import vtkActor, vtkPolyDataMapper +def range_to_trim(range, max): + """ + Convert an range interval contained in [-180, 180] (such as [-170, 160]) + into an interval specifying how much to trim (for our example: [10, 20]). + 'max' is 180 for longitude or 90 for latitude + """ + (min_range, max_range) = range + return [min_range + max, max - max_range] + + def load_plugins(): try: plugin_dir = Path(__file__).with_name("plugins") @@ -134,8 +144,8 @@ def __init__(self, projection="Mollweide"): ) self._crop = simple.EAMExtract( Input=self.center_meridian, - LongitudeRange=[-180, 180], - LatitudeRange=[-90.0, 90.0], + TrimLongitude=[0, 0], + TrimLatitude=[0, 0], ) self.proj = simple.EAMProject( # noqa: F821 Input=self._crop, @@ -233,8 +243,8 @@ def update(self, time=0.0): self.geometry.UpdatePipeline(time) def crop(self, longitude_min_max, latitude_min_max): - self._crop.LongitudeRange = longitude_min_max - self._crop.LatitudeRange = latitude_min_max + self._crop.TrimLongitude = range_to_trim(longitude_min_max, 180) + self._crop.TrimLatitude = range_to_trim(latitude_min_max, 90) class EAMVisSource: diff --git a/src/e3sm_quickview/plugins/eam_projection.py b/src/e3sm_quickview/plugins/eam_projection.py index cb2e600..f9b2ac1 100644 --- a/src/e3sm_quickview/plugins/eam_projection.py +++ b/src/e3sm_quickview/plugins/eam_projection.py @@ -1,23 +1,25 @@ from paraview.simple import * from paraview.util.vtkAlgorithm import * -from vtkmodules.numpy_interface import dataset_adapter as dsa from vtkmodules.vtkCommonCore import ( vtkPoints, ) from vtkmodules.vtkCommonDataModel import ( - vtkPolyData, vtkCellArray, + vtkDataSetAttributes, vtkPlane, + vtkPolyData, ) from vtkmodules.vtkCommonTransforms import vtkTransform from vtkmodules.vtkFiltersCore import ( vtkAppendFilter, + vtkCellCenters, vtkGenerateIds, ) from vtkmodules.vtkFiltersGeneral import ( - vtkTransformFilter, vtkTableBasedClipDataSet, + vtkTransformFilter, ) +from vtkmodules.vtkFiltersGeometry import vtkGeometryFilter try: from paraview.modules.vtkPVVTKExtensionsFiltersGeneral import vtkPVClipDataSet @@ -104,19 +106,10 @@ def RequestData(self, request, inInfo, outInfo): else: outData.DeepCopy(inData) - inWrap = dsa.WrapDataObject(inData) - outWrap = dsa.WrapDataObject(outData) - - inPoints = np.array(inWrap.Points) + inPoints = inData.points pRadius = (self.radius + 1) if self.isData else self.radius outPoints = np.array(list(map(lambda x: ProcessPoint(x, pRadius), inPoints))) - - _coords = numpy_support.numpy_to_vtk( - outPoints, deep=True, array_type=vtkConstants.VTK_FLOAT - ) - vtk_coords = vtkPoints() - vtk_coords.SetData(_coords) - outWrap.SetPoints(vtk_coords) + outData.points = outPoints return 1 @@ -157,10 +150,7 @@ def RequestData(self, request, inInfo, outInfo): outData = self.GetOutputData(outInfo, 0) outData.DeepCopy(inData) - inWrap = dsa.WrapDataObject(inData) - outWrap = dsa.WrapDataObject(outData) - - inPoints = np.array(inWrap.Points) + inPoints = inData.points pRadius = (self.radius + 1) if self.isData else self.radius outPoints = np.array(list(map(lambda x: ProcessPoint(x, pRadius), inPoints))) # outPoints = np.array(list(map(ProcessPoint,inPoints))) @@ -170,7 +160,7 @@ def RequestData(self, request, inInfo, outInfo): ) vtk_coords = vtkPoints() vtk_coords.SetData(_coords) - outWrap.SetPoints(vtk_coords) + outData.points = vtk_coords return 1 @@ -255,6 +245,7 @@ def __init__(self): self.__Dims = -1 self.project = 0 self.translate = False + self.cached_points = None def SetTranslation(self, translate): if self.translate != translate: @@ -273,49 +264,53 @@ def RequestData(self, request, inInfo, outInfo): afilter = vtkAppendFilter() afilter.AddInputData(inData) afilter.Update() - outData.DeepCopy(afilter.GetOutput()) + outData.ShallowCopy(afilter.GetOutput()) else: - outData.DeepCopy(inData) - + outData.ShallowCopy(inData) if self.project == 0: return 1 - - inWrap = dsa.WrapDataObject(inData) - outWrap = dsa.WrapDataObject(outData) - inPoints = np.array(inWrap.Points) - - flat = inPoints.flatten() - x = flat[0::3] - 180.0 if self.translate else flat[0::3] - y = flat[1::3] - - try: - # Use proj4 string for WGS84 instead of EPSG code to avoid database dependency - latlon = Proj(proj="latlong", datum="WGS84") - if self.project == 1: - proj = Proj(proj="robin") - elif self.project == 2: - proj = Proj(proj="moll") - else: - # Should not reach here, but return without transformation + if ( + self.cached_points + and self.cached_points.GetMTime() >= inData.GetPoints().GetMTime() + ): + outData.SetPoints(self.cached_points) + else: + # we modify the points, so copy them + out_points_vtk = vtkPoints() + out_points_vtk.DeepCopy(inData.GetPoints()) + outData.SetPoints(out_points_vtk) + out_points_np = outData.points + + flat = out_points_np.flatten() + x = flat[0::3] - 180.0 if self.translate else flat[0::3] + y = flat[1::3] + + try: + # Use proj4 string for WGS84 instead of EPSG code to avoid database dependency + latlon = Proj(proj="latlong", datum="WGS84") + if self.project == 1: + proj = Proj(proj="robin") + elif self.project == 2: + proj = Proj(proj="moll") + else: + # Should not reach here, but return without transformation + return 1 + + xformer = Transformer.from_proj(latlon, proj, always_xy=True) + res = xformer.transform(x, y) + except Exception as e: + print(f"Projection error: {e}") + # If projection fails, return without modifying coordinates return 1 - - xformer = Transformer.from_proj(latlon, proj, always_xy=True) - res = xformer.transform(x, y) - except Exception as e: - print(f"Projection error: {e}") - # If projection fails, return without modifying coordinates - return 1 - flat[0::3] = np.array(res[0]) - flat[1::3] = np.array(res[1]) - - outPoints = flat.reshape(inPoints.shape) - _coords = numpy_support.numpy_to_vtk( - outPoints, deep=True, array_type=vtkConstants.VTK_FLOAT - ) - vtk_coords = vtkPoints() - vtk_coords.SetData(_coords) - outWrap.SetPoints(vtk_coords) - + flat[0::3] = np.array(res[0]) + flat[1::3] = np.array(res[1]) + + outPoints = flat.reshape(out_points_np.shape) + _coords = numpy_support.numpy_to_vtk(outPoints, deep=True) + outData.GetPoints().SetData(_coords) + # the previous cached_points, if any, is available for + # garbage collection after this assignment + self.cached_points = out_points_vtk return 1 @@ -415,15 +410,15 @@ def RequestData(self, request, inInfo, outInfo): @smdomain.datatype(dataTypes=["vtkPolyData"], composite_data_supported=False) @smproperty.xml( """ - + default_values="0 0"> - + default_values="0 0"> """ ) @@ -432,44 +427,95 @@ def __init__(self): super().__init__( nInputPorts=1, nOutputPorts=1, outputType="vtkUnstructuredGrid" ) - self.longrange = [-180.0, 180.0] - self.latrange = [-90.0, 90.0] + self.trim_lon = [0, 0] + self.trim_lat = [0, 0] + self.cached_cell_centers = None + self.cached_ghosts = None - def SetLongitudeRange(self, min, max): - if self.longrange[0] != min or self.longrange[1] != max: - self.longrange = [min, max] + def SetTrimLongitude(self, left, right): + if left < 0 or left > 180 or right < 0 or right > 180: + print_error( + f"SetTrimLongitude called with parameters outside [0, 180]: {left=}, {right=}" + ) + return + if self.trim_lon[0] != left or self.trim_lon[1] != right: + self.trim_lon = [left, right] self.Modified() - def SetLatitudeRange(self, min, max): - if self.latrange[0] != min or self.latrange[1] != max: - self.latrange = [min, max] + def SetTrimLatitude(self, left, right): + if left < 0 or left > 90 or right < 0 or right > 90: + print_error( + f"SetTrimLatitude called with parameters outside [0, 180]: {left=}, {right=}" + ) + return + if self.trim_lat[0] != left or self.trim_lat[1] != right: + self.trim_lat = [left, right] self.Modified() def RequestData(self, request, inInfo, outInfo): inData = self.GetInputData(inInfo, 0, 0) outData = self.GetOutputData(outInfo, 0) - if self.longrange == [-180.0, 180] and self.latrange == [-90, 90]: + if self.trim_lon == [0, 0] and self.trim_lat == [0, 0]: outData.ShallowCopy(inData) return 1 - box = vtkPVBox() - box.SetReferenceBounds( - self.longrange[0], - self.longrange[1], - self.latrange[0], - self.latrange[1], - -1.0, - 1.0, - ) - box.SetUseReferenceBounds(True) - extract = vtkPVClipDataSet() - extract.SetClipFunction(box) - extract.InsideOutOn() - extract.ExactBoxClipOn() - extract.SetInputData(inData) - extract.Update() + outData.ShallowCopy(inData) + if self.cached_cell_centers and self.cached_cell_centers.GetMTime() >= max( + inData.GetPoints().GetMTime(), inData.GetCells().GetMTime() + ): + cell_centers = self.cached_cell_centers + else: + # convert to polydata, as vtkCellCenters only works on polydata + # import pdb;pdb.set_trace() + to_poly = vtkGeometryFilter() + to_poly.SetInputData(inData) + + # get cell centers + compute_centers = vtkCellCenters() + compute_centers.SetInputConnection(to_poly.GetOutputPort()) + compute_centers.Update() + cell_centers = compute_centers.GetOutput().GetPoints().GetData() + # previous cached_cell_centers, if any, + # is available for garbage collection after this assignment + self.cached_cell_centers = cell_centers + + # get the numpy array for cell centers + cc = numpy_support.vtk_to_numpy(cell_centers) + + if self.cached_ghosts and self.cached_ghosts.GetMTime() >= max( + self.GetMTime(), inData.GetPoints().GetMTime(), cell_centers.GetMTime() + ): + ghost = self.cached_ghosts + else: + # import pdb;pdb.set_trace() + # compute the new bounds by trimming the inData bounds + bounds = list(inData.GetBounds()) + bounds[0] = bounds[0] + self.trim_lon[0] + bounds[1] = bounds[1] - self.trim_lon[1] + bounds[2] = bounds[2] + self.trim_lat[0] + bounds[3] = bounds[3] - self.trim_lat[1] + + # add hidden cells based on bounds + outside_mask = ( + (cc[:, 0] < bounds[0]) + | (cc[:, 0] > bounds[1]) + | (cc[:, 1] < bounds[2]) + | (cc[:, 1] > bounds[3]) + ) + + # Create ghost array (0 = visible, HIDDENCELL = invisible) + ghost_np = np.where( + outside_mask, vtkDataSetAttributes.HIDDENCELL, 0 + ).astype(np.uint8) + + # Convert to VTK and add to output + ghost = numpy_support.numpy_to_vtk(ghost_np) + ghost.SetName(vtkDataSetAttributes.GhostArrayName()) + # the previous cached_ghosts, if any, + # is available for garbage collection after this assignment + self.cached_ghosts = ghost + outData.GetCellData().AddArray(ghost) - outData.ShallowCopy(extract.GetOutput()) return 1 @@ -535,14 +581,14 @@ def GetMeridian(self): def RequestData(self, request, inInfo, outInfo): inData = self.GetInputData(inInfo, 0, 0) - inPoints = inData.GetPoints() - inCellArray = inData.GetCells() outData = self.GetOutputData(outInfo, 0) if ( self._cached_output - and self._cached_output.GetMTime() > inPoints.GetMTime() - and self._cached_output.GetMTime() > inCellArray.GetMTime() + and self._cached_output.GetPoints().GetMTime() + >= inData.GetPoints().GetMTime() + and self._cached_output.GetCells().GetMTime() + >= inData.GetCells().GetMTime() ): # only scalars have been added or removed cached_cell_data = self._cached_output.GetCellData() @@ -606,6 +652,7 @@ def RequestData(self, request, inInfo, outInfo): append.AddInputData(transform.GetOutput()) append.Update() outData.ShallowCopy(append.GetOutput()) + # previous _cached_output is available for garbage collection self._cached_output = outData.NewInstance() self._cached_output.ShallowCopy(outData) return 1