Skip to content

Commit 00db22a

Browse files
Create test_particlesetview.py
1 parent 92c4937 commit 00db22a

File tree

1 file changed

+182
-0
lines changed

1 file changed

+182
-0
lines changed

tests/test_particlesetview.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
import numpy as np
2+
import pytest
3+
4+
from parcels import Field, FieldSet, Particle, ParticleSet, Variable, VectorField, XGrid
5+
from parcels._core.statuscodes import StatusCode
6+
from parcels._datasets.structured.generic import datasets as datasets_structured
7+
from parcels.interpolators import XLinear
8+
9+
10+
@pytest.fixture
11+
def fieldset() -> FieldSet:
12+
ds = datasets_structured["ds_2d_left"]
13+
grid = XGrid.from_dataset(ds, mesh="flat")
14+
U = Field("U", ds["U_A_grid"], grid, interp_method=XLinear)
15+
V = Field("V", ds["V_A_grid"], grid, interp_method=XLinear)
16+
UV = VectorField("UV", U, V)
17+
return FieldSet([U, V, UV])
18+
19+
20+
def test_execution_changing_particle_mask(fieldset):
21+
"""Test that particle masks can change during kernel execution."""
22+
npart = 10
23+
initial_lons = np.linspace(0, 1, npart)
24+
pset = ParticleSet(fieldset, lon=initial_lons.copy(), lat=np.zeros(npart))
25+
26+
def IncrementLowLon(particles, fieldset): # pragma: no cover
27+
# Increment lon for particles with lon < 0.5
28+
# The mask changes as particles cross the threshold
29+
particles[particles.lon < 0.5].dlon += 0.1
30+
31+
pset.execute(IncrementLowLon, runtime=np.timedelta64(5, "s"), dt=np.timedelta64(1, "s"))
32+
33+
# Particles that started below 0.5 should have moved more
34+
# Particles that started above 0.5 should not have moved
35+
particles_started_low = initial_lons < 0.5
36+
particles_started_high = initial_lons >= 0.5
37+
38+
# Low particles should have increased lon
39+
assert np.all(pset.lon[particles_started_low] > initial_lons[particles_started_low])
40+
# High particles should not have moved
41+
assert np.allclose(pset.lon[particles_started_high], initial_lons[particles_started_high], atol=1e-6)
42+
43+
44+
def test_particle_mask_conditional_state_changes(fieldset):
45+
"""Test setting particle state based on a condition using particle masks."""
46+
npart = 10
47+
initial_lons = np.linspace(0, 1, npart)
48+
pset = ParticleSet(fieldset, lon=initial_lons.copy(), lat=np.zeros(npart))
49+
50+
def StopFastParticles(particles, fieldset): # pragma: no cover
51+
# Stop particles that have moved beyond lon=0.5
52+
particles[particles.lon > 0.5].state = StatusCode.StopExecution
53+
54+
def AdvanceLon(particles, fieldset): # pragma: no cover
55+
particles.dlon += 0.2
56+
57+
pset.execute([AdvanceLon, StopFastParticles], runtime=np.timedelta64(5, "s"), dt=np.timedelta64(1, "s"))
58+
59+
# All particles should have stopped when they crossed lon > 0.5
60+
# Verify all final positions are > 0.5 (since they stop after crossing)
61+
assert np.all(pset.lon > 0.5)
62+
# Particles that started closer to 0.5 should have stopped sooner (lower final lon)
63+
# while particles that started farther should have moved more before stopping
64+
assert pset.lon[0] < pset.lon[-1] # First particle stopped earliest, last stopped latest
65+
66+
67+
def test_particle_mask_conditional_updates(fieldset):
68+
"""Test applying different updates to different particle subsets using masks."""
69+
npart = 20
70+
MyParticle = Particle.add_variable(Variable("temp", initial=10.0))
71+
pset = ParticleSet(fieldset, lon=np.linspace(0, 1, npart), lat=np.zeros(npart), pclass=MyParticle)
72+
73+
def ConditionalHeating(particles, fieldset): # pragma: no cover
74+
# Warm particles on the left, cool particles on the right
75+
particles[particles.lon < 0.5].temp += 1.0
76+
particles[particles.lon >= 0.5].temp -= 0.5
77+
78+
pset.execute(ConditionalHeating, runtime=np.timedelta64(4, "s"), dt=np.timedelta64(1, "s"))
79+
80+
# After 5 timesteps (0, 1, 2, 3, 4): left particles should be at 15.0, right at 7.5
81+
left_particles = pset.lon < 0.5
82+
right_particles = pset.lon >= 0.5
83+
assert np.allclose(pset.temp[left_particles], 15.0, atol=1e-6)
84+
assert np.allclose(pset.temp[right_particles], 7.5, atol=1e-6)
85+
86+
87+
def test_particle_mask_progressive_changes(fieldset):
88+
"""Test masks that change dynamically as particle properties change during execution."""
89+
npart = 10
90+
# Start all particles at lon=0, they will progressively move right
91+
pset = ParticleSet(fieldset, lon=np.zeros(npart), lat=np.linspace(0, 1, npart))
92+
93+
def MoveAndStopAtBoundary(particles, fieldset): # pragma: no cover
94+
# Move all particles right
95+
particles.dlon += 0.15
96+
# Stop particles that cross lon=0.5
97+
particles[particles.lon + particles.dlon > 0.5].state = StatusCode.StopExecution
98+
99+
pset.execute(MoveAndStopAtBoundary, runtime=np.timedelta64(10, "s"), dt=np.timedelta64(1, "s"))
100+
101+
# All particles should have stopped at or before lon=0.5
102+
# After first step: all reach 0.15
103+
# After second step: all reach 0.30
104+
# After third step: all reach 0.45
105+
# After fourth step: all would reach 0.60, so they stop
106+
assert np.all(pset.lon <= 0.6)
107+
assert np.all(pset.lon >= 0.45) # At least 3 steps completed
108+
109+
110+
def test_particle_mask_multiple_sequential_operations(fieldset):
111+
"""Test applying multiple different mask operations in sequence within one kernel."""
112+
npart = 30
113+
MyParticle = Particle.add_variable([Variable("group", initial=0), Variable("counter", initial=0)])
114+
115+
# Divide particles into three groups by initial position
116+
lons = np.linspace(0, 1, npart)
117+
pset = ParticleSet(fieldset, lon=lons, lat=np.zeros(npart), pclass=MyParticle)
118+
119+
def MultiMaskOperations(particles, fieldset): # pragma: no cover
120+
# Classify particles into groups based on lon
121+
particles[particles.lon < 0.33].group = 1
122+
particles[(particles.lon >= 0.33) & (particles.lon < 0.67)].group = 2
123+
particles[particles.lon >= 0.67].group = 3
124+
125+
# Apply different operations to each group
126+
particles[particles.group == 1].counter += 1
127+
particles[particles.group == 2].counter += 2
128+
particles[particles.group == 3].counter += 3
129+
130+
pset.execute(MultiMaskOperations, runtime=np.timedelta64(5, "s"), dt=np.timedelta64(1, "s"))
131+
132+
# Verify groups were assigned correctly and counters incremented appropriately
133+
group1 = pset.lon < 0.33
134+
group2 = (pset.lon >= 0.33) & (pset.lon < 0.67)
135+
group3 = pset.lon >= 0.67
136+
137+
assert np.allclose(pset.counter[group1], 6, atol=1e-6) # 6 timesteps * 1
138+
assert np.allclose(pset.counter[group2], 12, atol=1e-6) # 6 timesteps * 2
139+
assert np.allclose(pset.counter[group3], 18, atol=1e-6) # 6 timesteps * 3
140+
141+
142+
def test_particle_mask_empty_mask_handling(fieldset):
143+
"""Test that kernels handle empty masks (no particles matching condition) correctly."""
144+
npart = 10
145+
MyParticle = Particle.add_variable(Variable("modified", initial=0))
146+
# All particles start at lon > 0
147+
pset = ParticleSet(fieldset, lon=np.linspace(0.1, 1.0, npart), lat=np.zeros(npart), pclass=MyParticle)
148+
149+
def ModifyNegativeLon(particles, fieldset): # pragma: no cover
150+
# This mask should be empty (no particles have lon < 0)
151+
particles[particles.lon < 0].modified = 1
152+
# This should affect all particles
153+
particles.dlon += 0.01
154+
155+
# Should execute without errors even though the first mask is always empty
156+
pset.execute(ModifyNegativeLon, runtime=np.timedelta64(3, "s"), dt=np.timedelta64(1, "s"))
157+
158+
# No particles should have been modified
159+
assert np.all(pset.modified == 0)
160+
# But all should have moved
161+
assert np.all(pset.lon > 0.1)
162+
163+
164+
def test_particle_mask_with_delete_state(fieldset):
165+
"""Test using particle masks to delete particles based on conditions."""
166+
npart = 20
167+
pset = ParticleSet(fieldset, lon=np.linspace(0, 1, npart), lat=np.zeros(npart))
168+
initial_size = pset.size
169+
170+
def DeleteEdgeParticles(particles, fieldset): # pragma: no cover
171+
# Delete particles at the edges
172+
particles[(particles.lon < 0.2) | (particles.lon > 0.8)].state = StatusCode.Delete
173+
174+
def MoveLon(particles, fieldset): # pragma: no cover
175+
particles.dlon += 0.01
176+
177+
pset.execute([DeleteEdgeParticles, MoveLon], runtime=np.timedelta64(2, "s"), dt=np.timedelta64(1, "s"))
178+
179+
# Should have deleted edge particles
180+
assert pset.size < initial_size
181+
# Remaining particles should be in the middle range
182+
assert np.all((pset.lon >= 0.2) & (pset.lon <= 0.8))

0 commit comments

Comments
 (0)