Skip to content
237 changes: 157 additions & 80 deletions src/e3sm_quickview/plugins/eam_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,27 @@
vtkPoints,
)
from vtkmodules.vtkCommonDataModel import (
vtkPolyData,
vtkCellArray,
vtkDataSetAttributes,
vtkPlane,
vtkPolyData,
)
from vtkmodules.vtkCommonTransforms import vtkTransform
from vtkmodules.vtkFiltersCore import (
vtkAppendFilter,
vtkCellCenters,
vtkGenerateIds,
vtkPolyDataToUnstructuredGrid,
)
from vtkmodules.vtkFiltersGeneral import (
vtkTransformFilter,
vtkTableBasedClipDataSet,
vtkTransformFilter,
)
from vtkmodules.vtkFiltersPoints import (
vtkExtractSurface
)
from vtkmodules.vtkFiltersGeometry import (
vtkGeometryFilter
)

try:
Expand Down Expand Up @@ -255,6 +264,11 @@ def __init__(self):
self.__Dims = -1
self.project = 0
self.translate = False
self.cached_points = None

def __del__(self):
if self.cached_points:
self.cached_points.Unregister()
Comment thread
danlipsa marked this conversation as resolved.
Outdated

def SetTranslation(self, translate):
if self.translate != translate:
Expand All @@ -267,55 +281,58 @@ def SetProjection(self, project):
self.Modified()

def RequestData(self, request, inInfo, outInfo):
if self.project == 0:
return 1
inData = self.GetInputData(inInfo, 0, 0)
outData = self.GetOutputData(outInfo, 0)
if inData.IsA("vtkPolyData"):
afilter = vtkAppendFilter()
afilter.AddInputData(inData)
afilter.Update()
outData.DeepCopy(afilter.GetOutput())
outData.ShallowCopy(afilter.GetOutput())
else:
outData.DeepCopy(inData)

if self.project == 0:
return 1

inWrap = dsa.WrapDataObject(inData)
outWrap = dsa.WrapDataObject(outData)
inPoints = np.array(inWrap.Points)

flat = inPoints.flatten()
x = flat[0::3] - 180.0 if self.translate else flat[0::3]
y = flat[1::3]

try:
# Use proj4 string for WGS84 instead of EPSG code to avoid database dependency
latlon = Proj(proj="latlong", datum="WGS84")
if self.project == 1:
proj = Proj(proj="robin")
elif self.project == 2:
proj = Proj(proj="moll")
else:
# Should not reach here, but return without transformation
outData.ShallowCopy(inData)
if self.cached_points and \
self.cached_points.GetMTime() >= inData.GetPoints().GetMTime():
outData.SetPoints(self.cached_points)
else:
# we modify the points, so copy them
out_points_vtk = vtkPoints()
out_points_vtk.DeepCopy(inData.GetPoints())
outData.SetPoints(out_points_vtk)
out_points_np = outData.points

flat = out_points_np.flatten()
x = flat[0::3] - 180.0 if self.translate else flat[0::3]
y = flat[1::3]

try:
# Use proj4 string for WGS84 instead of EPSG code to avoid database dependency
latlon = Proj(proj="latlong", datum="WGS84")
if self.project == 1:
proj = Proj(proj="robin")
elif self.project == 2:
proj = Proj(proj="moll")
else:
# Should not reach here, but return without transformation
return 1

xformer = Transformer.from_proj(latlon, proj, always_xy=True)
res = xformer.transform(x, y)
except Exception as e:
print(f"Projection error: {e}")
# If projection fails, return without modifying coordinates
return 1

xformer = Transformer.from_proj(latlon, proj, always_xy=True)
res = xformer.transform(x, y)
except Exception as e:
print(f"Projection error: {e}")
# If projection fails, return without modifying coordinates
return 1
flat[0::3] = np.array(res[0])
flat[1::3] = np.array(res[1])

outPoints = flat.reshape(inPoints.shape)
_coords = numpy_support.numpy_to_vtk(
outPoints, deep=True, array_type=vtkConstants.VTK_FLOAT
)
vtk_coords = vtkPoints()
vtk_coords.SetData(_coords)
outWrap.SetPoints(vtk_coords)

flat[0::3] = np.array(res[0])
flat[1::3] = np.array(res[1])

outPoints = flat.reshape(out_points_np.shape)
_coords = numpy_support.numpy_to_vtk(outPoints, deep=True)
outData.GetPoints().SetData(_coords)
Comment thread
danlipsa marked this conversation as resolved.
if self.cached_points:
self.cached_points.Unregister(self)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than Unregister, should you just set self.cached_points = None?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I need to do this at all? If the python object is deleted, the c++ object should be unregistered as well, isn't it?

self.cached_points = out_points_vtk
self.cached_points.Register(self)
Comment thread
danlipsa marked this conversation as resolved.
Outdated
return 1


Expand Down Expand Up @@ -415,15 +432,15 @@ def RequestData(self, request, inInfo, outInfo):
@smdomain.datatype(dataTypes=["vtkPolyData"], composite_data_supported=False)
@smproperty.xml(
"""
<DoubleVectorProperty name="Longitude Range"
command="SetLongitudeRange"
<DoubleVectorProperty name="Trim Longitude"
Comment thread
danlipsa marked this conversation as resolved.
command="SetTrimLongitude"
number_of_elements="2"
default_values="-180 180">
default_values="0 0">
</DoubleVectorProperty>
<DoubleVectorProperty name="Latitude Range"
command="SetLatitudeRange"
<DoubleVectorProperty name="Trim Latitude"
command="SetTrimLatitude"
number_of_elements="2"
default_values="-90 90">
default_values="0 0">
</DoubleVectorProperty>
"""
)
Expand All @@ -432,44 +449,99 @@ def __init__(self):
super().__init__(
nInputPorts=1, nOutputPorts=1, outputType="vtkUnstructuredGrid"
)
self.longrange = [-180.0, 180.0]
self.latrange = [-90.0, 90.0]

def SetLongitudeRange(self, min, max):
if self.longrange[0] != min or self.longrange[1] != max:
self.longrange = [min, max]
self.trim_lon = [0, 0]
self.trim_lat = [0, 0]
self.cached_cell_centers = None
self.cached_ghosts = None

def __del__(self):
if self.cached_cell_centers:
self.cached_cell_centers.Unregister(self)
Comment thread
danlipsa marked this conversation as resolved.
Outdated
if self.cached_ghosts:
self.cached_ghosts.Unregister(self)

def SetTrimLongitude(self, left, right):
if left < 0 or left > 180 or right < 0 or right > 180:
print_error(f"SetTrimLongitude called with parameters outside [0, 180]: {left=}, {right=}")
return
if self.trim_lon[0] != left or self.trim_lon[1] != right:
self.trim_lon = [left, right]
self.Modified()

def SetLatitudeRange(self, min, max):
if self.latrange[0] != min or self.latrange[1] != max:
self.latrange = [min, max]
def SetTrimLatitude(self, left, right):
if left < 0 or left > 90 or right < 0 or right > 90:
print_error(f"SetTrimLatitude called with parameters outside [0, 180]: {left=}, {right=}")
return
if self.trim_lat[0] != left or self.trim_lat[1] != right:
self.trim_lat = [left, right]
self.Modified()

def RequestData(self, request, inInfo, outInfo):
inData = self.GetInputData(inInfo, 0, 0)
outData = self.GetOutputData(outInfo, 0)
if self.longrange == [-180.0, 180] and self.latrange == [-90, 90]:
if self.trim_lon == [0, 0] and self.trim_lat == [0, 0]:
outData.ShallowCopy(inData)
return 1

box = vtkPVBox()
box.SetReferenceBounds(
self.longrange[0],
self.longrange[1],
self.latrange[0],
self.latrange[1],
-1.0,
1.0,
)
box.SetUseReferenceBounds(True)
extract = vtkPVClipDataSet()
extract.SetClipFunction(box)
extract.InsideOutOn()
extract.ExactBoxClipOn()
extract.SetInputData(inData)
extract.Update()
outData.ShallowCopy(inData)
if self.cached_cell_centers and self.cached_cell_centers.GetMTime() >= max(
inData.GetPoints().GetMTime(), inData.GetCells().GetMTime()
):
cell_centers = self.cached_cell_centers
else:
# convert to polydata, as vtkCellCenters only works on polydata
# import pdb;pdb.set_trace()
to_poly = vtkGeometryFilter()
to_poly.SetInputData(inData)

# get cell centers
compute_centers = vtkCellCenters()
compute_centers.SetInputConnection(to_poly.GetOutputPort())
compute_centers.Update()
cell_centers = compute_centers.GetOutput().GetPoints().GetData()
if self.cached_cell_centers:
self.cached_cell_centers.Unregister(self)
self.cached_cell_centers = cell_centers
self.cached_cell_centers.Register(self)

# get the numpy array for cell centers
cc = numpy_support.vtk_to_numpy(cell_centers)

if self.cached_ghosts and self.cached_ghosts.GetMTime() >= max(
self.GetMTime(), inData.GetPoints().GetMTime(), cell_centers.GetMTime()
):
ghost = self.cached_ghosts
else:
# import pdb;pdb.set_trace()
# compute the new bounds by trimming the inData bounds
bounds = list(inData.GetBounds())
bounds[0] = bounds[0] + self.trim_lon[0]
bounds[1] = bounds[1] - self.trim_lon[1]
bounds[2] = bounds[2] + self.trim_lat[0]
bounds[3] = bounds[3] - self.trim_lat[1]

# add hidden cells based on bounds
outside_mask = (
(cc[:, 0] < bounds[0])
| (cc[:, 0] > bounds[1])
| (cc[:, 1] < bounds[2])
| (cc[:, 1] > bounds[3])
)

# Create ghost array (0 = visible, HIDDENCELL = invisible)
ghost_np = np.where(
outside_mask, vtkDataSetAttributes.HIDDENCELL, 0
).astype(np.uint8)

# Convert to VTK and add to output
ghost = numpy_support.numpy_to_vtk(ghost_np)
ghost.SetName(vtkDataSetAttributes.GhostArrayName())
if self.cached_ghosts:
self.cached_ghosts.Unregister(self)
self.cached_ghosts = ghost
self.cached_ghosts.Register(self)
outData.GetCellData().AddArray(ghost)

outData.ShallowCopy(extract.GetOutput())
return 1


Expand Down Expand Up @@ -513,6 +585,10 @@ def __init__(self):
self._center_meridian = 0
self._cached_output = None

def __del__(self):
if self._cached_output:
self._cached_output.Unregister(self)

def SetMeridian(self, meridian_):
"""
Specifies the central meridian (longitude in the middle of the map)
Expand All @@ -535,14 +611,12 @@ def GetMeridian(self):

def RequestData(self, request, inInfo, outInfo):
inData = self.GetInputData(inInfo, 0, 0)
inPoints = inData.GetPoints()
inCellArray = inData.GetCells()

outData = self.GetOutputData(outInfo, 0)
if (
self._cached_output
and self._cached_output.GetMTime() > inPoints.GetMTime()
and self._cached_output.GetMTime() > inCellArray.GetMTime()
and self._cached_output.GetPoints().GetMTime() >= inData.GetPoints().GetMTime()
and self._cached_output.GetCells().GetMTime() >= inData.GetCells().GetMTime()
):
# only scalars have been added or removed
cached_cell_data = self._cached_output.GetCellData()
Expand Down Expand Up @@ -606,6 +680,9 @@ def RequestData(self, request, inInfo, outInfo):
append.AddInputData(transform.GetOutput())
append.Update()
outData.ShallowCopy(append.GetOutput())
if self._cached_output:
self._cached_output.Unregister(self)
self._cached_output = outData.NewInstance()
self._cached_output.ShallowCopy(outData)
self._cached_output.Register(self)
return 1
Loading