|
17 | 17 | "UXPiecewiseConstantFace", |
18 | 18 | "UXPiecewiseLinearNode", |
19 | 19 | "XLinear", |
| 20 | + "XNearest", |
20 | 21 | "ZeroInterpolator", |
21 | 22 | ] |
22 | 23 |
|
@@ -111,6 +112,69 @@ def XLinear( |
111 | 112 | return value.compute() if isinstance(value, dask.Array) else value |
112 | 113 |
|
113 | 114 |
|
| 115 | +def XNearest( |
| 116 | + field: Field, |
| 117 | + ti: int, |
| 118 | + position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]], |
| 119 | + tau: np.float32 | np.float64, |
| 120 | + t: np.float32 | np.float64, |
| 121 | + z: np.float32 | np.float64, |
| 122 | + y: np.float32 | np.float64, |
| 123 | + x: np.float32 | np.float64, |
| 124 | +): |
| 125 | + """ |
| 126 | + Nearest-Neighbour spatial interpolation on a regular grid. |
| 127 | + Note that this still uses linear interpolation in time. |
| 128 | + """ |
| 129 | + xi, xsi = position["X"] |
| 130 | + yi, eta = position["Y"] |
| 131 | + zi, zeta = position["Z"] |
| 132 | + |
| 133 | + axis_dim = field.grid.get_axis_dim_mapping(field.data.dims) |
| 134 | + data = field.data |
| 135 | + |
| 136 | + lenT = 2 if np.any(tau > 0) else 1 |
| 137 | + |
| 138 | + # Spatial coordinates: left if barycentric < 0.5, otherwise right |
| 139 | + zi_1 = np.clip(zi + 1, 0, data.shape[1] - 1) |
| 140 | + zi_full = np.where(zeta < 0.5, zi, zi_1) |
| 141 | + |
| 142 | + yi_1 = np.clip(yi + 1, 0, data.shape[2] - 1) |
| 143 | + yi_full = np.where(eta < 0.5, yi, yi_1) |
| 144 | + |
| 145 | + xi_1 = np.clip(xi + 1, 0, data.shape[3] - 1) |
| 146 | + xi_full = np.where(xsi < 0.5, xi, xi_1) |
| 147 | + |
| 148 | + # Time coordinates: 1 point at ti, then 1 point at ti+1 |
| 149 | + if lenT == 1: |
| 150 | + ti_full = ti |
| 151 | + else: |
| 152 | + ti_1 = np.clip(ti + 1, 0, data.shape[0] - 1) |
| 153 | + ti_full = np.concatenate([ti, ti_1]) |
| 154 | + xi_full = np.repeat(xi_full, 2) |
| 155 | + yi_full = np.repeat(yi_full, 2) |
| 156 | + zi_full = np.repeat(zi_full, 2) |
| 157 | + |
| 158 | + # Create DataArrays for indexing |
| 159 | + selection_dict = { |
| 160 | + axis_dim["X"]: xr.DataArray(xi_full, dims=("points")), |
| 161 | + axis_dim["Y"]: xr.DataArray(yi_full, dims=("points")), |
| 162 | + } |
| 163 | + if "Z" in axis_dim: |
| 164 | + selection_dict[axis_dim["Z"]] = xr.DataArray(zi_full, dims=("points")) |
| 165 | + if "time" in data.dims: |
| 166 | + selection_dict["time"] = xr.DataArray(ti_full, dims=("points")) |
| 167 | + |
| 168 | + corner_data = data.isel(selection_dict).data.reshape(lenT, len(xsi)) |
| 169 | + |
| 170 | + if lenT == 2: |
| 171 | + value = corner_data[0, :] * (1 - tau) + corner_data[1, :] * tau |
| 172 | + else: |
| 173 | + value = corner_data[0, :] |
| 174 | + |
| 175 | + return value.compute() if isinstance(value, dask.Array) else value |
| 176 | + |
| 177 | + |
114 | 178 | def UXPiecewiseConstantFace( |
115 | 179 | field: Field, |
116 | 180 | ti: int, |
|
0 commit comments