|
| 1 | +import numpy as np |
| 2 | +import pytest |
| 3 | + |
| 4 | +from parcels import Field, FieldSet, Particle, ParticleSet, Variable, VectorField, XGrid |
| 5 | +from parcels._core.statuscodes import StatusCode |
| 6 | +from parcels._datasets.structured.generic import datasets as datasets_structured |
| 7 | +from parcels.interpolators import XLinear |
| 8 | + |
| 9 | + |
| 10 | +@pytest.fixture |
| 11 | +def fieldset() -> FieldSet: |
| 12 | + ds = datasets_structured["ds_2d_left"] |
| 13 | + grid = XGrid.from_dataset(ds, mesh="flat") |
| 14 | + U = Field("U", ds["U_A_grid"], grid, interp_method=XLinear) |
| 15 | + V = Field("V", ds["V_A_grid"], grid, interp_method=XLinear) |
| 16 | + UV = VectorField("UV", U, V) |
| 17 | + return FieldSet([U, V, UV]) |
| 18 | + |
| 19 | + |
| 20 | +def test_execution_changing_particle_mask(fieldset): |
| 21 | + """Test that particle masks can change during kernel execution.""" |
| 22 | + npart = 10 |
| 23 | + initial_lons = np.linspace(0, 1, npart) |
| 24 | + pset = ParticleSet(fieldset, lon=initial_lons.copy(), lat=np.zeros(npart)) |
| 25 | + |
| 26 | + def IncrementLowLon(particles, fieldset): # pragma: no cover |
| 27 | + # Increment lon for particles with lon < 0.5 |
| 28 | + # The mask changes as particles cross the threshold |
| 29 | + particles[particles.lon < 0.5].dlon += 0.1 |
| 30 | + |
| 31 | + pset.execute(IncrementLowLon, runtime=np.timedelta64(5, "s"), dt=np.timedelta64(1, "s")) |
| 32 | + |
| 33 | + # Particles that started below 0.5 should have moved more |
| 34 | + # Particles that started above 0.5 should not have moved |
| 35 | + particles_started_low = initial_lons < 0.5 |
| 36 | + particles_started_high = initial_lons >= 0.5 |
| 37 | + |
| 38 | + # Low particles should have increased lon |
| 39 | + assert np.all(pset.lon[particles_started_low] > initial_lons[particles_started_low]) |
| 40 | + # High particles should not have moved |
| 41 | + assert np.allclose(pset.lon[particles_started_high], initial_lons[particles_started_high], atol=1e-6) |
| 42 | + |
| 43 | + |
| 44 | +def test_particle_mask_conditional_state_changes(fieldset): |
| 45 | + """Test setting particle state based on a condition using particle masks.""" |
| 46 | + npart = 10 |
| 47 | + initial_lons = np.linspace(0, 1, npart) |
| 48 | + pset = ParticleSet(fieldset, lon=initial_lons.copy(), lat=np.zeros(npart)) |
| 49 | + |
| 50 | + def StopFastParticles(particles, fieldset): # pragma: no cover |
| 51 | + # Stop particles that have moved beyond lon=0.5 |
| 52 | + particles[particles.lon > 0.5].state = StatusCode.StopExecution |
| 53 | + |
| 54 | + def AdvanceLon(particles, fieldset): # pragma: no cover |
| 55 | + particles.dlon += 0.2 |
| 56 | + |
| 57 | + pset.execute([AdvanceLon, StopFastParticles], runtime=np.timedelta64(5, "s"), dt=np.timedelta64(1, "s")) |
| 58 | + |
| 59 | + # All particles should have stopped when they crossed lon > 0.5 |
| 60 | + # Verify all final positions are > 0.5 (since they stop after crossing) |
| 61 | + assert np.all(pset.lon > 0.5) |
| 62 | + # Particles that started closer to 0.5 should have stopped sooner (lower final lon) |
| 63 | + # while particles that started farther should have moved more before stopping |
| 64 | + assert pset.lon[0] < pset.lon[-1] # First particle stopped earliest, last stopped latest |
| 65 | + |
| 66 | + |
| 67 | +def test_particle_mask_conditional_updates(fieldset): |
| 68 | + """Test applying different updates to different particle subsets using masks.""" |
| 69 | + npart = 20 |
| 70 | + MyParticle = Particle.add_variable(Variable("temp", initial=10.0)) |
| 71 | + pset = ParticleSet(fieldset, lon=np.linspace(0, 1, npart), lat=np.zeros(npart), pclass=MyParticle) |
| 72 | + |
| 73 | + def ConditionalHeating(particles, fieldset): # pragma: no cover |
| 74 | + # Warm particles on the left, cool particles on the right |
| 75 | + particles[particles.lon < 0.5].temp += 1.0 |
| 76 | + particles[particles.lon >= 0.5].temp -= 0.5 |
| 77 | + |
| 78 | + pset.execute(ConditionalHeating, runtime=np.timedelta64(4, "s"), dt=np.timedelta64(1, "s")) |
| 79 | + |
| 80 | + # After 5 timesteps (0, 1, 2, 3, 4): left particles should be at 15.0, right at 7.5 |
| 81 | + left_particles = pset.lon < 0.5 |
| 82 | + right_particles = pset.lon >= 0.5 |
| 83 | + assert np.allclose(pset.temp[left_particles], 15.0, atol=1e-6) |
| 84 | + assert np.allclose(pset.temp[right_particles], 7.5, atol=1e-6) |
| 85 | + |
| 86 | + |
| 87 | +def test_particle_mask_progressive_changes(fieldset): |
| 88 | + """Test masks that change dynamically as particle properties change during execution.""" |
| 89 | + npart = 10 |
| 90 | + # Start all particles at lon=0, they will progressively move right |
| 91 | + pset = ParticleSet(fieldset, lon=np.zeros(npart), lat=np.linspace(0, 1, npart)) |
| 92 | + |
| 93 | + def MoveAndStopAtBoundary(particles, fieldset): # pragma: no cover |
| 94 | + # Move all particles right |
| 95 | + particles.dlon += 0.15 |
| 96 | + # Stop particles that cross lon=0.5 |
| 97 | + particles[particles.lon + particles.dlon > 0.5].state = StatusCode.StopExecution |
| 98 | + |
| 99 | + pset.execute(MoveAndStopAtBoundary, runtime=np.timedelta64(10, "s"), dt=np.timedelta64(1, "s")) |
| 100 | + |
| 101 | + # All particles should have stopped at or before lon=0.5 |
| 102 | + # After first step: all reach 0.15 |
| 103 | + # After second step: all reach 0.30 |
| 104 | + # After third step: all reach 0.45 |
| 105 | + # After fourth step: all would reach 0.60, so they stop |
| 106 | + assert np.all(pset.lon <= 0.6) |
| 107 | + assert np.all(pset.lon >= 0.45) # At least 3 steps completed |
| 108 | + |
| 109 | + |
| 110 | +def test_particle_mask_multiple_sequential_operations(fieldset): |
| 111 | + """Test applying multiple different mask operations in sequence within one kernel.""" |
| 112 | + npart = 30 |
| 113 | + MyParticle = Particle.add_variable([Variable("group", initial=0), Variable("counter", initial=0)]) |
| 114 | + |
| 115 | + # Divide particles into three groups by initial position |
| 116 | + lons = np.linspace(0, 1, npart) |
| 117 | + pset = ParticleSet(fieldset, lon=lons, lat=np.zeros(npart), pclass=MyParticle) |
| 118 | + |
| 119 | + def MultiMaskOperations(particles, fieldset): # pragma: no cover |
| 120 | + # Classify particles into groups based on lon |
| 121 | + particles[particles.lon < 0.33].group = 1 |
| 122 | + particles[(particles.lon >= 0.33) & (particles.lon < 0.67)].group = 2 |
| 123 | + particles[particles.lon >= 0.67].group = 3 |
| 124 | + |
| 125 | + # Apply different operations to each group |
| 126 | + particles[particles.group == 1].counter += 1 |
| 127 | + particles[particles.group == 2].counter += 2 |
| 128 | + particles[particles.group == 3].counter += 3 |
| 129 | + |
| 130 | + pset.execute(MultiMaskOperations, runtime=np.timedelta64(5, "s"), dt=np.timedelta64(1, "s")) |
| 131 | + |
| 132 | + # Verify groups were assigned correctly and counters incremented appropriately |
| 133 | + group1 = pset.lon < 0.33 |
| 134 | + group2 = (pset.lon >= 0.33) & (pset.lon < 0.67) |
| 135 | + group3 = pset.lon >= 0.67 |
| 136 | + |
| 137 | + assert np.allclose(pset.counter[group1], 6, atol=1e-6) # 6 timesteps * 1 |
| 138 | + assert np.allclose(pset.counter[group2], 12, atol=1e-6) # 6 timesteps * 2 |
| 139 | + assert np.allclose(pset.counter[group3], 18, atol=1e-6) # 6 timesteps * 3 |
| 140 | + |
| 141 | + |
| 142 | +def test_particle_mask_empty_mask_handling(fieldset): |
| 143 | + """Test that kernels handle empty masks (no particles matching condition) correctly.""" |
| 144 | + npart = 10 |
| 145 | + MyParticle = Particle.add_variable(Variable("modified", initial=0)) |
| 146 | + # All particles start at lon > 0 |
| 147 | + pset = ParticleSet(fieldset, lon=np.linspace(0.1, 1.0, npart), lat=np.zeros(npart), pclass=MyParticle) |
| 148 | + |
| 149 | + def ModifyNegativeLon(particles, fieldset): # pragma: no cover |
| 150 | + # This mask should be empty (no particles have lon < 0) |
| 151 | + particles[particles.lon < 0].modified = 1 |
| 152 | + # This should affect all particles |
| 153 | + particles.dlon += 0.01 |
| 154 | + |
| 155 | + # Should execute without errors even though the first mask is always empty |
| 156 | + pset.execute(ModifyNegativeLon, runtime=np.timedelta64(3, "s"), dt=np.timedelta64(1, "s")) |
| 157 | + |
| 158 | + # No particles should have been modified |
| 159 | + assert np.all(pset.modified == 0) |
| 160 | + # But all should have moved |
| 161 | + assert np.all(pset.lon > 0.1) |
| 162 | + |
| 163 | + |
| 164 | +def test_particle_mask_with_delete_state(fieldset): |
| 165 | + """Test using particle masks to delete particles based on conditions.""" |
| 166 | + npart = 20 |
| 167 | + pset = ParticleSet(fieldset, lon=np.linspace(0, 1, npart), lat=np.zeros(npart)) |
| 168 | + initial_size = pset.size |
| 169 | + |
| 170 | + def DeleteEdgeParticles(particles, fieldset): # pragma: no cover |
| 171 | + # Delete particles at the edges |
| 172 | + particles[(particles.lon < 0.2) | (particles.lon > 0.8)].state = StatusCode.Delete |
| 173 | + |
| 174 | + def MoveLon(particles, fieldset): # pragma: no cover |
| 175 | + particles.dlon += 0.01 |
| 176 | + |
| 177 | + pset.execute([DeleteEdgeParticles, MoveLon], runtime=np.timedelta64(2, "s"), dt=np.timedelta64(1, "s")) |
| 178 | + |
| 179 | + # Should have deleted edge particles |
| 180 | + assert pset.size < initial_size |
| 181 | + # Remaining particles should be in the middle range |
| 182 | + assert np.all((pset.lon >= 0.2) & (pset.lon <= 0.8)) |
0 commit comments