Skip to content

Commit 92c4937

Browse files
Moving ParticleSetView to its own file
1 parent 4b20ad5 commit 92c4937

File tree

4 files changed

+297
-295
lines changed

4 files changed

+297
-295
lines changed

src/parcels/_core/field.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
_unitconverters_map,
1515
)
1616
from parcels._core.index_search import GRID_SEARCH_ERROR, LEFT_OUT_OF_BOUNDS, RIGHT_OUT_OF_BOUNDS, _search_time_index
17-
from parcels._core.particle import ParticleSetView
17+
from parcels._core.particlesetview import ParticleSetView
1818
from parcels._core.statuscodes import (
1919
AllParcelsErrorCodes,
2020
StatusCode,

src/parcels/_core/particle.py

Lines changed: 0 additions & 293 deletions
Original file line numberDiff line numberDiff line change
@@ -116,299 +116,6 @@ def add_variable(self, variable: Variable | list[Variable]):
116116
return ParticleClass(variables=self.variables + variable)
117117

118118

119-
class ParticleSetView:
120-
"""Class to be used in a kernel that links a View of the ParticleSet (on the kernel level) to a ParticleSet."""
121-
122-
def __init__(self, data, index):
123-
self._data = data
124-
self._index = index
125-
126-
def __getattr__(self, name):
127-
# Return a proxy that behaves like the underlying numpy array but
128-
# writes back into the parent arrays when sliced/modified. This
129-
# enables constructs like `particles.dlon[mask] += vals` to update
130-
# the parent arrays rather than temporary copies.
131-
if name in self._data:
132-
# If this ParticleSetView represents a single particle (integer
133-
# index), return the underlying scalar directly to preserve
134-
# user-facing semantics (e.g., `pset[0].time` should be a number).
135-
if isinstance(self._index, (int, np.integer)):
136-
return self._data[name][self._index]
137-
if isinstance(self._index, np.ndarray) and self._index.ndim == 0:
138-
return self._data[name][int(self._index)]
139-
return ParticleSetViewArray(self._data, self._index, name)
140-
return self._data[name][self._index]
141-
142-
def __setattr__(self, name, value):
143-
if name in ["_data", "_index"]:
144-
object.__setattr__(self, name, value)
145-
else:
146-
self._data[name][self._index] = value
147-
148-
def __getitem__(self, index):
149-
# normalize single-element tuple indexing (e.g., (inds,))
150-
if isinstance(index, tuple) and len(index) == 1:
151-
index = index[0]
152-
153-
base = self._index
154-
new_index = np.zeros_like(base, dtype=bool)
155-
156-
# Boolean mask (could be local-length or global-length)
157-
if isinstance(index, (np.ndarray, list)) and np.asarray(index).dtype == bool:
158-
arr = np.asarray(index)
159-
if arr.size == base.size:
160-
# global mask
161-
new_index = arr
162-
elif arr.size == int(np.sum(base)):
163-
new_index[base] = arr
164-
else:
165-
raise ValueError(
166-
f"Boolean index has incompatible length {arr.size} for selection of size {int(np.sum(base))}"
167-
)
168-
return ParticleSetView(self._data, new_index)
169-
170-
# Integer array/list, slice or single integer relative to the local view
171-
# (boolean masks were handled above). Normalize and map to global
172-
# particle indices for both boolean-base and integer-base `self._index`.
173-
if isinstance(index, (np.ndarray, list, slice, int)):
174-
# convert list/ndarray to ndarray, keep slice/int as-is
175-
idx = np.asarray(index) if isinstance(index, (np.ndarray, list)) else index
176-
if base.dtype == bool:
177-
particle_idxs = np.flatnonzero(base)
178-
sel = particle_idxs[idx]
179-
else:
180-
base_arr = np.asarray(base)
181-
sel = base_arr[idx]
182-
new_index[sel] = True
183-
return ParticleSetView(self._data, new_index)
184-
185-
# Fallback: try to assign directly (preserves previous behaviour for other index types)
186-
try:
187-
new_index[base] = index
188-
return ParticleSetView(self._data, new_index)
189-
except Exception as e:
190-
raise TypeError(f"Unsupported index type for ParticleSetView.__getitem__: {type(index)!r}") from e
191-
192-
def __len__(self):
193-
return len(self._index)
194-
195-
196-
def _unwrap(other):
197-
"""Return ndarray for ParticleSetViewArray or the value unchanged."""
198-
return other.__array__() if isinstance(other, ParticleSetViewArray) else other
199-
200-
201-
def _asarray(other):
202-
"""Return numpy array for ParticleSetViewArray, otherwise return argument."""
203-
return np.asarray(other.__array__()) if isinstance(other, ParticleSetViewArray) else other
204-
205-
206-
class ParticleSetViewArray:
207-
"""Array-like proxy for a ParticleSetView that writes through to the parent arrays when mutated."""
208-
209-
def __init__(self, data, index, name):
210-
self._data = data
211-
self._index = index
212-
self._name = name
213-
214-
def __array__(self, dtype=None):
215-
arr = self._data[self._name][self._index]
216-
return arr.astype(dtype) if dtype is not None else arr
217-
218-
def __repr__(self):
219-
return repr(self.__array__())
220-
221-
def __len__(self):
222-
return len(self.__array__())
223-
224-
def _to_global_index(self, subindex=None):
225-
"""Return a global index (boolean mask or integer indices) that
226-
addresses the parent arrays. If `subindex` is provided it selects
227-
within the current local view and maps back to the global index.
228-
"""
229-
base = self._index
230-
if subindex is None:
231-
return base
232-
233-
# If subindex is a boolean array, support both local-length masks
234-
# (length == base.sum()) and global-length masks (length == base.size).
235-
if isinstance(subindex, (np.ndarray, list)) and np.asarray(subindex).dtype == bool:
236-
arr = np.asarray(subindex)
237-
if arr.size == base.size:
238-
# already a global mask
239-
return arr
240-
if arr.size == int(np.sum(base)):
241-
global_mask = np.zeros_like(base, dtype=bool)
242-
global_mask[base] = arr
243-
return global_mask
244-
raise ValueError(
245-
f"Boolean index has incompatible length {arr.size} for selection of size {int(np.sum(base))}"
246-
)
247-
248-
# Handle tuple indexing where the first axis indexes particles
249-
# and later axes index into the per-particle array shape (e.g. ei[:, igrid])
250-
if isinstance(subindex, tuple):
251-
first, *rest = subindex
252-
# map the first index (local selection) to global particle indices
253-
if base.dtype == bool:
254-
particle_idxs = np.flatnonzero(base)
255-
first_arr = np.asarray(first) if isinstance(first, (np.ndarray, list)) else first
256-
sel = particle_idxs[first_arr]
257-
else:
258-
base_arr = np.asarray(base)
259-
sel = base_arr[first]
260-
261-
# if rest contains a single int (e.g., column), return tuple index
262-
if len(rest) == 1:
263-
return (sel, rest[0])
264-
# return full tuple (sel, ...) for higher-dim cases
265-
return tuple([sel] + rest)
266-
267-
# If base is a boolean mask over the parent array and subindex is
268-
# an integer or slice relative to the local view, map it to integer
269-
# indices in the parent array.
270-
if base.dtype == bool:
271-
if isinstance(subindex, (slice, int)):
272-
rel = np.flatnonzero(base)[subindex]
273-
return rel
274-
# If subindex is an integer/array selection (relative to the
275-
# local view) map those to global integer indices.
276-
arr = np.asarray(subindex)
277-
if arr.dtype != bool:
278-
particle_idxs = np.flatnonzero(base)
279-
sel = particle_idxs[arr]
280-
return sel
281-
# Otherwise treat subindex as a boolean mask relative to the
282-
# local view and expand to a global boolean mask.
283-
global_mask = np.zeros_like(base, dtype=bool)
284-
global_mask[base] = arr
285-
return global_mask
286-
287-
# If base is an array of integer indices
288-
base_arr = np.asarray(base)
289-
try:
290-
return base_arr[subindex]
291-
except Exception:
292-
return base_arr[np.asarray(subindex, dtype=bool)]
293-
294-
def __getitem__(self, subindex):
295-
# Handle tuple indexing (e.g. [:, igrid]) by applying the tuple
296-
# to the local selection first. This covers the common case
297-
# `particles.ei[:, igrid]` where `ei` is a 2D parent array and the
298-
# second index selects the grid index.
299-
if isinstance(subindex, tuple):
300-
local = self._data[self._name][self._index]
301-
return local[subindex]
302-
303-
new_index = self._to_global_index(subindex)
304-
return ParticleSetViewArray(self._data, new_index, self._name)
305-
306-
def __setitem__(self, subindex, value):
307-
tgt = self._to_global_index(subindex)
308-
self._data[self._name][tgt] = value
309-
310-
# in-place ops must write back into the parent array
311-
def __iadd__(self, other):
312-
vals = self._data[self._name][self._index] + _unwrap(other)
313-
self._data[self._name][self._index] = vals
314-
return self
315-
316-
def __isub__(self, other):
317-
vals = self._data[self._name][self._index] - _unwrap(other)
318-
self._data[self._name][self._index] = vals
319-
return self
320-
321-
def __imul__(self, other):
322-
vals = self._data[self._name][self._index] * _unwrap(other)
323-
self._data[self._name][self._index] = vals
324-
return self
325-
326-
# Provide simple numpy-like evaluation for binary ops by delegating to ndarray
327-
def __add__(self, other):
328-
return self.__array__() + _unwrap(other)
329-
330-
def __sub__(self, other):
331-
return self.__array__() - _unwrap(other)
332-
333-
def __mul__(self, other):
334-
return self.__array__() * _unwrap(other)
335-
336-
def __truediv__(self, other):
337-
return self.__array__() / _unwrap(other)
338-
339-
def __floordiv__(self, other):
340-
return self.__array__() // _unwrap(other)
341-
342-
def __pow__(self, other):
343-
return self.__array__() ** _unwrap(other)
344-
345-
def __neg__(self):
346-
return -self.__array__()
347-
348-
def __pos__(self):
349-
return +self.__array__()
350-
351-
def __abs__(self):
352-
return abs(self.__array__())
353-
354-
# Right-hand operations to handle cases like `scalar - ParticleSetViewArray`
355-
def __radd__(self, other):
356-
return _unwrap(other) + self.__array__()
357-
358-
def __rsub__(self, other):
359-
return _unwrap(other) - self.__array__()
360-
361-
def __rmul__(self, other):
362-
return _unwrap(other) * self.__array__()
363-
364-
def __rtruediv__(self, other):
365-
return _unwrap(other) / self.__array__()
366-
367-
def __rfloordiv__(self, other):
368-
return _unwrap(other) // self.__array__()
369-
370-
def __rpow__(self, other):
371-
return _unwrap(other) ** self.__array__()
372-
373-
# Comparison operators should return plain numpy boolean arrays so that
374-
# expressions like `mask = particles.gridID == gid` produce an ndarray
375-
# usable for indexing (rather than another ParticleSetViewArray).
376-
def __eq__(self, other):
377-
left = np.asarray(self.__array__())
378-
right = _asarray(other)
379-
return left == right
380-
381-
def __ne__(self, other):
382-
left = np.asarray(self.__array__())
383-
right = _asarray(other)
384-
return left != right
385-
386-
def __lt__(self, other):
387-
left = np.asarray(self.__array__())
388-
right = _asarray(other)
389-
return left < right
390-
391-
def __le__(self, other):
392-
left = np.asarray(self.__array__())
393-
right = _asarray(other)
394-
return left <= right
395-
396-
def __gt__(self, other):
397-
left = np.asarray(self.__array__())
398-
right = _asarray(other)
399-
return left > right
400-
401-
def __ge__(self, other):
402-
left = np.asarray(self.__array__())
403-
right = _asarray(other)
404-
return left >= right
405-
406-
# Allow attribute access like .dtype etc. by forwarding to the ndarray
407-
def __getattr__(self, item):
408-
arr = self.__array__()
409-
return getattr(arr, item)
410-
411-
412119
def _assert_no_duplicate_variable_names(*, existing_vars: list[Variable], new_vars: list[Variable]):
413120
existing_names = {var.name for var in existing_vars}
414121
for var in new_vars:

src/parcels/_core/particleset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111

1212
from parcels._core.converters import _convert_to_flat_array
1313
from parcels._core.kernel import Kernel
14-
from parcels._core.particle import Particle, ParticleSetView, create_particle_data
14+
from parcels._core.particle import Particle, create_particle_data
15+
from parcels._core.particlesetview import ParticleSetView
1516
from parcels._core.statuscodes import StatusCode
1617
from parcels._core.utils.time import (
1718
TimeInterval,

0 commit comments

Comments
 (0)