@@ -116,299 +116,6 @@ def add_variable(self, variable: Variable | list[Variable]):
116116 return ParticleClass (variables = self .variables + variable )
117117
118118
119- class ParticleSetView :
120- """Class to be used in a kernel that links a View of the ParticleSet (on the kernel level) to a ParticleSet."""
121-
122- def __init__ (self , data , index ):
123- self ._data = data
124- self ._index = index
125-
126- def __getattr__ (self , name ):
127- # Return a proxy that behaves like the underlying numpy array but
128- # writes back into the parent arrays when sliced/modified. This
129- # enables constructs like `particles.dlon[mask] += vals` to update
130- # the parent arrays rather than temporary copies.
131- if name in self ._data :
132- # If this ParticleSetView represents a single particle (integer
133- # index), return the underlying scalar directly to preserve
134- # user-facing semantics (e.g., `pset[0].time` should be a number).
135- if isinstance (self ._index , (int , np .integer )):
136- return self ._data [name ][self ._index ]
137- if isinstance (self ._index , np .ndarray ) and self ._index .ndim == 0 :
138- return self ._data [name ][int (self ._index )]
139- return ParticleSetViewArray (self ._data , self ._index , name )
140- return self ._data [name ][self ._index ]
141-
142- def __setattr__ (self , name , value ):
143- if name in ["_data" , "_index" ]:
144- object .__setattr__ (self , name , value )
145- else :
146- self ._data [name ][self ._index ] = value
147-
148- def __getitem__ (self , index ):
149- # normalize single-element tuple indexing (e.g., (inds,))
150- if isinstance (index , tuple ) and len (index ) == 1 :
151- index = index [0 ]
152-
153- base = self ._index
154- new_index = np .zeros_like (base , dtype = bool )
155-
156- # Boolean mask (could be local-length or global-length)
157- if isinstance (index , (np .ndarray , list )) and np .asarray (index ).dtype == bool :
158- arr = np .asarray (index )
159- if arr .size == base .size :
160- # global mask
161- new_index = arr
162- elif arr .size == int (np .sum (base )):
163- new_index [base ] = arr
164- else :
165- raise ValueError (
166- f"Boolean index has incompatible length { arr .size } for selection of size { int (np .sum (base ))} "
167- )
168- return ParticleSetView (self ._data , new_index )
169-
170- # Integer array/list, slice or single integer relative to the local view
171- # (boolean masks were handled above). Normalize and map to global
172- # particle indices for both boolean-base and integer-base `self._index`.
173- if isinstance (index , (np .ndarray , list , slice , int )):
174- # convert list/ndarray to ndarray, keep slice/int as-is
175- idx = np .asarray (index ) if isinstance (index , (np .ndarray , list )) else index
176- if base .dtype == bool :
177- particle_idxs = np .flatnonzero (base )
178- sel = particle_idxs [idx ]
179- else :
180- base_arr = np .asarray (base )
181- sel = base_arr [idx ]
182- new_index [sel ] = True
183- return ParticleSetView (self ._data , new_index )
184-
185- # Fallback: try to assign directly (preserves previous behaviour for other index types)
186- try :
187- new_index [base ] = index
188- return ParticleSetView (self ._data , new_index )
189- except Exception as e :
190- raise TypeError (f"Unsupported index type for ParticleSetView.__getitem__: { type (index )!r} " ) from e
191-
192- def __len__ (self ):
193- return len (self ._index )
194-
195-
196- def _unwrap (other ):
197- """Return ndarray for ParticleSetViewArray or the value unchanged."""
198- return other .__array__ () if isinstance (other , ParticleSetViewArray ) else other
199-
200-
201- def _asarray (other ):
202- """Return numpy array for ParticleSetViewArray, otherwise return argument."""
203- return np .asarray (other .__array__ ()) if isinstance (other , ParticleSetViewArray ) else other
204-
205-
206- class ParticleSetViewArray :
207- """Array-like proxy for a ParticleSetView that writes through to the parent arrays when mutated."""
208-
209- def __init__ (self , data , index , name ):
210- self ._data = data
211- self ._index = index
212- self ._name = name
213-
214- def __array__ (self , dtype = None ):
215- arr = self ._data [self ._name ][self ._index ]
216- return arr .astype (dtype ) if dtype is not None else arr
217-
218- def __repr__ (self ):
219- return repr (self .__array__ ())
220-
221- def __len__ (self ):
222- return len (self .__array__ ())
223-
224- def _to_global_index (self , subindex = None ):
225- """Return a global index (boolean mask or integer indices) that
226- addresses the parent arrays. If `subindex` is provided it selects
227- within the current local view and maps back to the global index.
228- """
229- base = self ._index
230- if subindex is None :
231- return base
232-
233- # If subindex is a boolean array, support both local-length masks
234- # (length == base.sum()) and global-length masks (length == base.size).
235- if isinstance (subindex , (np .ndarray , list )) and np .asarray (subindex ).dtype == bool :
236- arr = np .asarray (subindex )
237- if arr .size == base .size :
238- # already a global mask
239- return arr
240- if arr .size == int (np .sum (base )):
241- global_mask = np .zeros_like (base , dtype = bool )
242- global_mask [base ] = arr
243- return global_mask
244- raise ValueError (
245- f"Boolean index has incompatible length { arr .size } for selection of size { int (np .sum (base ))} "
246- )
247-
248- # Handle tuple indexing where the first axis indexes particles
249- # and later axes index into the per-particle array shape (e.g. ei[:, igrid])
250- if isinstance (subindex , tuple ):
251- first , * rest = subindex
252- # map the first index (local selection) to global particle indices
253- if base .dtype == bool :
254- particle_idxs = np .flatnonzero (base )
255- first_arr = np .asarray (first ) if isinstance (first , (np .ndarray , list )) else first
256- sel = particle_idxs [first_arr ]
257- else :
258- base_arr = np .asarray (base )
259- sel = base_arr [first ]
260-
261- # if rest contains a single int (e.g., column), return tuple index
262- if len (rest ) == 1 :
263- return (sel , rest [0 ])
264- # return full tuple (sel, ...) for higher-dim cases
265- return tuple ([sel ] + rest )
266-
267- # If base is a boolean mask over the parent array and subindex is
268- # an integer or slice relative to the local view, map it to integer
269- # indices in the parent array.
270- if base .dtype == bool :
271- if isinstance (subindex , (slice , int )):
272- rel = np .flatnonzero (base )[subindex ]
273- return rel
274- # If subindex is an integer/array selection (relative to the
275- # local view) map those to global integer indices.
276- arr = np .asarray (subindex )
277- if arr .dtype != bool :
278- particle_idxs = np .flatnonzero (base )
279- sel = particle_idxs [arr ]
280- return sel
281- # Otherwise treat subindex as a boolean mask relative to the
282- # local view and expand to a global boolean mask.
283- global_mask = np .zeros_like (base , dtype = bool )
284- global_mask [base ] = arr
285- return global_mask
286-
287- # If base is an array of integer indices
288- base_arr = np .asarray (base )
289- try :
290- return base_arr [subindex ]
291- except Exception :
292- return base_arr [np .asarray (subindex , dtype = bool )]
293-
294- def __getitem__ (self , subindex ):
295- # Handle tuple indexing (e.g. [:, igrid]) by applying the tuple
296- # to the local selection first. This covers the common case
297- # `particles.ei[:, igrid]` where `ei` is a 2D parent array and the
298- # second index selects the grid index.
299- if isinstance (subindex , tuple ):
300- local = self ._data [self ._name ][self ._index ]
301- return local [subindex ]
302-
303- new_index = self ._to_global_index (subindex )
304- return ParticleSetViewArray (self ._data , new_index , self ._name )
305-
306- def __setitem__ (self , subindex , value ):
307- tgt = self ._to_global_index (subindex )
308- self ._data [self ._name ][tgt ] = value
309-
310- # in-place ops must write back into the parent array
311- def __iadd__ (self , other ):
312- vals = self ._data [self ._name ][self ._index ] + _unwrap (other )
313- self ._data [self ._name ][self ._index ] = vals
314- return self
315-
316- def __isub__ (self , other ):
317- vals = self ._data [self ._name ][self ._index ] - _unwrap (other )
318- self ._data [self ._name ][self ._index ] = vals
319- return self
320-
321- def __imul__ (self , other ):
322- vals = self ._data [self ._name ][self ._index ] * _unwrap (other )
323- self ._data [self ._name ][self ._index ] = vals
324- return self
325-
326- # Provide simple numpy-like evaluation for binary ops by delegating to ndarray
327- def __add__ (self , other ):
328- return self .__array__ () + _unwrap (other )
329-
330- def __sub__ (self , other ):
331- return self .__array__ () - _unwrap (other )
332-
333- def __mul__ (self , other ):
334- return self .__array__ () * _unwrap (other )
335-
336- def __truediv__ (self , other ):
337- return self .__array__ () / _unwrap (other )
338-
339- def __floordiv__ (self , other ):
340- return self .__array__ () // _unwrap (other )
341-
342- def __pow__ (self , other ):
343- return self .__array__ () ** _unwrap (other )
344-
345- def __neg__ (self ):
346- return - self .__array__ ()
347-
348- def __pos__ (self ):
349- return + self .__array__ ()
350-
351- def __abs__ (self ):
352- return abs (self .__array__ ())
353-
354- # Right-hand operations to handle cases like `scalar - ParticleSetViewArray`
355- def __radd__ (self , other ):
356- return _unwrap (other ) + self .__array__ ()
357-
358- def __rsub__ (self , other ):
359- return _unwrap (other ) - self .__array__ ()
360-
361- def __rmul__ (self , other ):
362- return _unwrap (other ) * self .__array__ ()
363-
364- def __rtruediv__ (self , other ):
365- return _unwrap (other ) / self .__array__ ()
366-
367- def __rfloordiv__ (self , other ):
368- return _unwrap (other ) // self .__array__ ()
369-
370- def __rpow__ (self , other ):
371- return _unwrap (other ) ** self .__array__ ()
372-
373- # Comparison operators should return plain numpy boolean arrays so that
374- # expressions like `mask = particles.gridID == gid` produce an ndarray
375- # usable for indexing (rather than another ParticleSetViewArray).
376- def __eq__ (self , other ):
377- left = np .asarray (self .__array__ ())
378- right = _asarray (other )
379- return left == right
380-
381- def __ne__ (self , other ):
382- left = np .asarray (self .__array__ ())
383- right = _asarray (other )
384- return left != right
385-
386- def __lt__ (self , other ):
387- left = np .asarray (self .__array__ ())
388- right = _asarray (other )
389- return left < right
390-
391- def __le__ (self , other ):
392- left = np .asarray (self .__array__ ())
393- right = _asarray (other )
394- return left <= right
395-
396- def __gt__ (self , other ):
397- left = np .asarray (self .__array__ ())
398- right = _asarray (other )
399- return left > right
400-
401- def __ge__ (self , other ):
402- left = np .asarray (self .__array__ ())
403- right = _asarray (other )
404- return left >= right
405-
406- # Allow attribute access like .dtype etc. by forwarding to the ndarray
407- def __getattr__ (self , item ):
408- arr = self .__array__ ()
409- return getattr (arr , item )
410-
411-
412119def _assert_no_duplicate_variable_names (* , existing_vars : list [Variable ], new_vars : list [Variable ]):
413120 existing_names = {var .name for var in existing_vars }
414121 for var in new_vars :
0 commit comments