Skip to content

Commit ca4df74

Browse files
Merge pull request #2443 from Parcels-code/fix_kernelparticle_bug
Implementing View for KernelParticle/ParticleSet
2 parents d0cb317 + 00db22a commit ca4df74

File tree

9 files changed

+538
-118
lines changed

9 files changed

+538
-118
lines changed

docs/user_guide/examples/tutorial_Argofloats.ipynb

Lines changed: 50 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -26,70 +26,62 @@
2626
"source": [
2727
"import numpy as np\n",
2828
"\n",
29-
"# Define the new Kernels that mimic Argo vertical movement\n",
29+
"# Define the new Kernel that mimics Argo vertical movement\n",
3030
"driftdepth = 1000  # drift depth in m\n",
3131
"maxdepth = 2000 # maximum depth in m\n",
3232
"vertical_speed = 0.10 # sink and rise speed in m/s\n",
3333
"cycletime = 10 * 86400 # total time of cycle in seconds\n",
3434
"drifttime = 9 * 86400 # time of deep drift in seconds\n",
3535
"\n",
3636
"\n",
37-
"def ArgoPhase1(particles, fieldset):\n",
38-
" def SinkingPhase(p):\n",
39-
" \"\"\"Phase 0: Sinking with vertical_speed until depth is driftdepth\"\"\"\n",
40-
" p.dz += vertical_speed * particles.dt\n",
41-
" p.cycle_phase = np.where(p.z + p.dz >= driftdepth, 1, p.cycle_phase)\n",
42-
" p.dz = np.where(p.z + p.dz >= driftdepth, driftdepth - p.z, p.dz)\n",
37+
"def ArgoVerticalMovement(particles, fieldset):\n",
38+
" # Split particles based on their current cycle_phase\n",
39+
" ptcls0 = particles[particles.cycle_phase == 0]\n",
40+
" ptcls1 = particles[particles.cycle_phase == 1]\n",
41+
" ptcls2 = particles[particles.cycle_phase == 2]\n",
42+
" ptcls3 = particles[particles.cycle_phase == 3]\n",
43+
" ptcls4 = particles[particles.cycle_phase == 4]\n",
44+
"\n",
45+
" # Phase 0: Sinking with vertical_speed until depth is driftdepth\n",
46+
" ptcls0.dz += vertical_speed * ptcls0.dt\n",
47+
" ptcls0.cycle_phase = np.where(\n",
48+
" ptcls0.z + ptcls0.dz >= driftdepth, 1, ptcls0.cycle_phase\n",
49+
" )\n",
50+
" ptcls0.dz = np.where(\n",
51+
" ptcls0.z + ptcls0.dz >= driftdepth, driftdepth - ptcls0.z, ptcls0.dz\n",
52+
" )\n",
53+
"\n",
54+
" # Phase 1: Drifting at depth for drifttime seconds\n",
55+
" ptcls1.drift_age += ptcls1.dt\n",
56+
" ptcls1.cycle_phase = np.where(ptcls1.drift_age >= drifttime, 2, ptcls1.cycle_phase)\n",
57+
" ptcls1.drift_age = np.where(ptcls1.drift_age >= drifttime, 0, ptcls1.drift_age)\n",
58+
"\n",
59+
" # Phase 2: Sinking further to maxdepth\n",
60+
" ptcls2.dz += vertical_speed * ptcls2.dt\n",
61+
" ptcls2.cycle_phase = np.where(\n",
62+
" ptcls2.z + ptcls2.dz >= maxdepth, 3, ptcls2.cycle_phase\n",
63+
" )\n",
64+
" ptcls2.dz = np.where(\n",
65+
" ptcls2.z + ptcls2.dz >= maxdepth, maxdepth - ptcls2.z, ptcls2.dz\n",
66+
" )\n",
67+
"\n",
68+
" # Phase 3: Rising with vertical_speed until at surface\n",
69+
" ptcls3.dz -= vertical_speed * ptcls3.dt\n",
70+
" ptcls3.temp = fieldset.thetao[ptcls3.time, ptcls3.z, ptcls3.lat, ptcls3.lon]\n",
71+
" ptcls3.cycle_phase = np.where(\n",
72+
" ptcls3.z + ptcls3.dz <= fieldset.mindepth, 4, ptcls3.cycle_phase\n",
73+
" )\n",
74+
" ptcls3.dz = np.where(\n",
75+
" ptcls3.z + ptcls3.dz <= fieldset.mindepth,\n",
76+
" fieldset.mindepth - ptcls3.z,\n",
77+
" ptcls3.dz,\n",
78+
" )\n",
79+
"\n",
80+
" # Phase 4: Transmitting at surface until cycletime is reached\n",
81+
" ptcls4.cycle_phase = np.where(ptcls4.cycle_age >= cycletime, 0, ptcls4.cycle_phase)\n",
82+
" ptcls4.cycle_age = np.where(ptcls4.cycle_age >= cycletime, 0, ptcls4.cycle_age)\n",
83+
" ptcls4.temp = np.nan # no temperature measurement when at surface\n",
4384
"\n",
44-
" SinkingPhase(particles[particles.cycle_phase == 0])\n",
45-
"\n",
46-
"\n",
47-
"def ArgoPhase2(particles, fieldset):\n",
48-
" def DriftingPhase(p):\n",
49-
" \"\"\"Phase 1: Drifting at depth for drifttime seconds\"\"\"\n",
50-
" p.drift_age += particles.dt\n",
51-
" p.cycle_phase = np.where(p.drift_age >= drifttime, 2, p.cycle_phase)\n",
52-
" p.drift_age = np.where(p.drift_age >= drifttime, 0, p.drift_age)\n",
53-
"\n",
54-
" DriftingPhase(particles[particles.cycle_phase == 1])\n",
55-
"\n",
56-
"\n",
57-
"def ArgoPhase3(particles, fieldset):\n",
58-
" def SecondSinkingPhase(p):\n",
59-
" \"\"\"Phase 2: Sinking further to maxdepth\"\"\"\n",
60-
" p.dz += vertical_speed * particles.dt\n",
61-
" p.cycle_phase = np.where(p.z + p.dz >= maxdepth, 3, p.cycle_phase)\n",
62-
" p.dz = np.where(p.z + p.dz >= maxdepth, maxdepth - p.z, p.dz)\n",
63-
"\n",
64-
" SecondSinkingPhase(particles[particles.cycle_phase == 2])\n",
65-
"\n",
66-
"\n",
67-
"def ArgoPhase4(particles, fieldset):\n",
68-
" def RisingPhase(p):\n",
69-
" \"\"\"Phase 3: Rising with vertical_speed until at surface\"\"\"\n",
70-
" p.dz -= vertical_speed * particles.dt\n",
71-
" p.temp = fieldset.thetao[p.time, p.z, p.lat, p.lon]\n",
72-
" p.cycle_phase = np.where(p.z + p.dz <= fieldset.mindepth, 4, p.cycle_phase)\n",
73-
" p.dz = np.where(\n",
74-
" p.z + p.dz <= fieldset.mindepth,\n",
75-
" fieldset.mindepth - p.z,\n",
76-
" p.dz,\n",
77-
" )\n",
78-
"\n",
79-
" RisingPhase(particles[particles.cycle_phase == 3])\n",
80-
"\n",
81-
"\n",
82-
"def ArgoPhase5(particles, fieldset):\n",
83-
" def TransmittingPhase(p):\n",
84-
" \"\"\"Phase 4: Transmitting at surface until cycletime is reached\"\"\"\n",
85-
" p.cycle_phase = np.where(p.cycle_age >= cycletime, 0, p.cycle_phase)\n",
86-
" p.cycle_age = np.where(p.cycle_age >= cycletime, 0, p.cycle_age)\n",
87-
" p.temp = np.nan # no temperature measurement when at surface\n",
88-
"\n",
89-
" TransmittingPhase(particles[particles.cycle_phase == 4])\n",
90-
"\n",
91-
"\n",
92-
"def ArgoPhase6(particles, fieldset):\n",
9385
" particles.cycle_age += particles.dt # update cycle_age"
9486
]
9587
},
@@ -136,9 +128,7 @@
136128
"ArgoParticle = parcels.Particle.add_variable(\n",
137129
" [\n",
138130
" parcels.Variable(\"cycle_phase\", dtype=np.int32, initial=0.0),\n",
139-
" parcels.Variable(\n",
140-
" \"cycle_age\", dtype=np.float32, initial=0.0\n",
141-
" ), # TODO update to \"timedelta64[s]\"\n",
131+
" parcels.Variable(\"cycle_age\", dtype=np.float32, initial=0.0),\n",
142132
" parcels.Variable(\"drift_age\", dtype=np.float32, initial=0.0),\n",
143133
" parcels.Variable(\"temp\", dtype=np.float32, initial=np.nan),\n",
144134
" ]\n",
@@ -155,12 +145,7 @@
155145
"\n",
156146
"# combine Argo vertical movement kernel with built-in Advection kernel\n",
157147
"kernels = [\n",
158-
" ArgoPhase1,\n",
159-
" ArgoPhase2,\n",
160-
" ArgoPhase3,\n",
161-
" ArgoPhase4,\n",
162-
" ArgoPhase5,\n",
163-
" ArgoPhase6,\n",
148+
" ArgoVerticalMovement,\n",
164149
" parcels.kernels.AdvectionRK4,\n",
165150
"]\n",
166151
"\n",

docs/user_guide/examples/tutorial_interaction.ipynb

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -293,18 +293,9 @@
293293
" larger_idx = np.where(mass_j > mass_i, pair_j, pair_i)\n",
294294
" smaller_idx = np.where(mass_j > mass_i, pair_i, pair_j)\n",
295295
"\n",
296-
" # perform transfer and mark deletions\n",
297-
" # TODO note that we use temporary arrays for indexing because of KernelParticle bug (GH #2143)\n",
298-
" masses = particles.mass\n",
299-
" states = particles.state\n",
300-
"\n",
301296
" # transfer mass from smaller to larger and mark smaller for deletion\n",
302-
" masses[larger_idx] += particles.mass[smaller_idx]\n",
303-
" states[smaller_idx] = parcels.StatusCode.Delete\n",
304-
"\n",
305-
" # TODO use particle variables directly after KernelParticle bug (GH #2143) is fixed\n",
306-
" particles.mass = masses\n",
307-
" particles.state = states"
297+
" particles.mass[larger_idx] += particles.mass[smaller_idx]\n",
298+
" particles.state[smaller_idx] = parcels.StatusCode.Delete"
308299
]
309300
},
310301
{

src/parcels/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
Variable,
1818
Particle,
1919
ParticleClass,
20-
KernelParticle, # ? remove?
2120
)
2221
from parcels._core.field import Field, VectorField
2322
from parcels._core.basegrid import BaseGrid
@@ -87,8 +86,6 @@
8786
"logger",
8887
"download_example_dataset",
8988
"list_example_datasets",
90-
# (marked for potential removal)
91-
"KernelParticle",
9289
]
9390

9491
_stdlib_warnings.warn(

src/parcels/_core/field.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
_unitconverters_map,
1515
)
1616
from parcels._core.index_search import GRID_SEARCH_ERROR, LEFT_OUT_OF_BOUNDS, RIGHT_OUT_OF_BOUNDS, _search_time_index
17-
from parcels._core.particle import KernelParticle
17+
from parcels._core.particlesetview import ParticleSetView
1818
from parcels._core.statuscodes import (
1919
AllParcelsErrorCodes,
2020
StatusCode,
@@ -35,9 +35,9 @@
3535

3636

3737
def _deal_with_errors(error, key, vector_type: VectorType):
38-
if isinstance(key, KernelParticle):
38+
if isinstance(key, ParticleSetView):
3939
key.state = AllParcelsErrorCodes[type(error)]
40-
elif isinstance(key[-1], KernelParticle):
40+
elif isinstance(key[-1], ParticleSetView):
4141
key[-1].state = AllParcelsErrorCodes[type(error)]
4242
else:
4343
raise RuntimeError(f"{error}. Error could not be handled because particles was not part of the Field Sampling.")
@@ -229,7 +229,7 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):
229229
def __getitem__(self, key):
230230
self._check_velocitysampling()
231231
try:
232-
if isinstance(key, KernelParticle):
232+
if isinstance(key, ParticleSetView):
233233
return self.eval(key.time, key.z, key.lat, key.lon, key)
234234
else:
235235
return self.eval(*key)
@@ -330,7 +330,7 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):
330330

331331
def __getitem__(self, key):
332332
try:
333-
if isinstance(key, KernelParticle):
333+
if isinstance(key, ParticleSetView):
334334
return self.eval(key.time, key.z, key.lat, key.lon, key)
335335
else:
336336
return self.eval(*key)

src/parcels/_core/particle.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from parcels._core.utils.time import TimeInterval
1212
from parcels._reprs import _format_list_items_multiline
1313

14-
__all__ = ["KernelParticle", "Particle", "ParticleClass", "Variable"]
14+
__all__ = ["Particle", "ParticleClass", "Variable"]
1515
_TO_WRITE_OPTIONS = [True, False, "once"]
1616

1717

@@ -116,30 +116,6 @@ def add_variable(self, variable: Variable | list[Variable]):
116116
return ParticleClass(variables=self.variables + variable)
117117

118118

119-
class KernelParticle:
120-
"""Simple class to be used in a kernel that links a particle (on the kernel level) to a particle dataset."""
121-
122-
def __init__(self, data, index):
123-
self._data = data
124-
self._index = index
125-
126-
def __getattr__(self, name):
127-
return self._data[name][self._index]
128-
129-
def __setattr__(self, name, value):
130-
if name in ["_data", "_index"]:
131-
object.__setattr__(self, name, value)
132-
else:
133-
self._data[name][self._index] = value
134-
135-
def __getitem__(self, index):
136-
self._index = index
137-
return self
138-
139-
def __len__(self):
140-
return len(self._index)
141-
142-
143119
def _assert_no_duplicate_variable_names(*, existing_vars: list[Variable], new_vars: list[Variable]):
144120
existing_names = {var.name for var in existing_vars}
145121
for var in new_vars:

src/parcels/_core/particleset.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111

1212
from parcels._core.converters import _convert_to_flat_array
1313
from parcels._core.kernel import Kernel
14-
from parcels._core.particle import KernelParticle, Particle, create_particle_data
14+
from parcels._core.particle import Particle, create_particle_data
15+
from parcels._core.particlesetview import ParticleSetView
1516
from parcels._core.statuscodes import StatusCode
1617
from parcels._core.utils.time import (
1718
TimeInterval,
@@ -166,7 +167,7 @@ def __getattr__(self, name):
166167

167168
def __getitem__(self, index):
168169
"""Get a single particle by index."""
169-
return KernelParticle(self._data, index=index)
170+
return ParticleSetView(self._data, index=index)
170171

171172
def __setattr__(self, name, value):
172173
if name in ["_data"]:

0 commit comments

Comments
 (0)