Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions benchmarks/mpas_ocean.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,32 @@ def time_kd_tree(self, resolution):
def time_ball_tree(self, resolution):
self.uxds.uxgrid.get_ball_tree()

def time_r_tree(self, resolution):
self.uxds.uxgrid.get_rtree()


class QueryTreeStructures:
def setup(self):
# Load dataset
self.uxds_120 = ux.open_dataset(file_path_dict['120km'][0], file_path_dict['120km'][1])

# Construct Trees
self.kd_tree = self.uxds_120.uxgrid.get_kd_tree(coordinate_system='spherical')
self.ball_tree = self.uxds_120.uxgrid.get_ball_tree(coordinate_system='spherical')
self.r_tree = self.uxds_120.uxgrid.get_rtree()
self.x = self.uxds_120.uxgrid.face_x[0].values
self.y = self.uxds_120.uxgrid.face_y[0].values
self.z = self.uxds_120.uxgrid.face_z[0].values

def time_kd_tree(self):
_, _ = self.kd_tree.query([0.0, 0.0], return_distance=True, k=1)

def time_ball_tree(self):
_, _ = self.ball_tree.query([0.0, 0.0], return_distance=True, k=1)

def time_r_tree(self):
_ = self.r_tree.intersects((self.x, self.y, self.z, self.x, self.y, self.z))


class RemapDownsample:

Expand Down
25 changes: 25 additions & 0 deletions test/test_rtree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import uxarray as ux

import numpy.testing as nt
import os
from pathlib import Path

current_path = Path(os.path.dirname(os.path.realpath(__file__)))

CSne30_data_path = current_path / "meshfiles" / "ugrid" / "outCSne30" / "outCSne30_vortex.nc"
quad_hex_grid_path = current_path / "meshfiles" / "ugrid" / "quad-hexagon" / "grid.nc"


def test_quad_hex_face_centers():
"""Tests a face center query into the RTree, which expects the same index
to be returned."""
uxgrid = ux.open_grid(quad_hex_grid_path)
rt = uxgrid.get_rtree()

for i in range(uxgrid.n_face):
x = uxgrid.face_x[i].values
y = uxgrid.face_y[i].values
z = uxgrid.face_z[i].values
res = rt.intersects((x, y, z, x, y, z))
assert len(res) == 1
assert res[0] == i
10 changes: 10 additions & 0 deletions uxarray/grid/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
_populate_edge_face_distances,
_populate_edge_node_distances,
)
from uxarray.grid.rtree import _construct_rtree
from uxarray.grid.utils import _get_cartesian_face_edge_nodes_array
from uxarray.grid.validation import (
_check_area,
Expand Down Expand Up @@ -230,6 +231,7 @@ def __init__(
self._raster_data_id = None

# initialize cached data structures (nearest neighbor operations)
self._rtree = None
self._ball_tree = None
self._kd_tree = None
self._spatialhash = None
Expand Down Expand Up @@ -1643,6 +1645,14 @@ def chunk(self, n_node="auto", n_edge="auto", n_face="auto"):
else:
setattr(self, var_name, grid_var.chunk())

def get_rtree(self, p: int = 10, page_size: int = 512, reconstruct: bool = False):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we name this get_r_tree just to be consistent with how we name our other tree functions?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good suggestions.

"""TODO:"""

if self._rtree is None or reconstruct:
self._rtree = _construct_rtree(self.bounds, p, page_size)

return self._rtree

def get_ball_tree(
self,
coordinates: Optional[str] = "face centers",
Expand Down
74 changes: 74 additions & 0 deletions uxarray/grid/rtree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import numpy as np
from numba import njit, prange
from spatialpandas.spatialindex import HilbertRtree


def _construct_rtree(bounds, p=10, page_size=512):
lat_bounds = bounds.sel(lon_lat=0, min_max=[0, 1]).values
lon_bounds = bounds.sel(lon_lat=1, min_max=[0, 1]).values

boxes = face_aabb_xyz(lat_bounds, lon_bounds)

return HilbertRtree(boxes, p, page_size)


@njit(cache=True)
def face_aabb_xyz_kernel(lat0, lat1, lon0, lon1, eps=1e-12):
two_pi = 2 * np.pi

# if it crosses the antimeridian, unwrap lon1
if lon1 < lon0:
lon1 += two_pi

# build list of theta samples: ends and any cardinal meridians inside
samples = [lon0, lon1]
for theta_c in (0.0, np.pi / 2, np.pi, 3 * np.pi / 2):
t = theta_c
if t < lon0:
t += two_pi
if lon0 <= t <= lon1:
samples.append(t)

# build list of phi samples: bounds and equator if spanned
phis = [lat0, lat1]
if lat0 <= 0.0 <= lat1:
phis.append(0.0)

# initialize extremes TODO
xmin = ymin = zmin = 1e20
xmax = ymax = zmax = -1e20

# sample all (phi, theta)
for phi in phis:
sin_phi = np.sin(phi)
cos_phi = np.cos(phi)
for theta in samples:
x = cos_phi * np.cos(theta)
y = cos_phi * np.sin(theta)
z = sin_phi

if x < xmin:
xmin = x
if x > xmax:
xmax = x
if y < ymin:
ymin = y
if y > ymax:
ymax = y
if z < zmin:
zmin = z
if z > zmax:
zmax = z

return (xmin - eps, ymin - eps, zmin - eps, xmax + eps, ymax + eps, zmax + eps)


@njit(cache=True, parallel=True)
def face_aabb_xyz(lat_bounds, lon_bounds, eps=1e-12):
n = lat_bounds.shape[0]
boxes = np.empty((n, 6), dtype=np.float64)
for i in prange(n):
boxes[i, :] = face_aabb_xyz_kernel(
lat_bounds[i, 0], lat_bounds[i, 1], lon_bounds[i, 0], lon_bounds[i, 1], eps
)
return boxes
Loading