@@ -253,6 +253,26 @@ def test_add_duplicate_field(dupobject):
253253 assert error_thrown
254254
255255
256+ @pytest .mark .parametrize ('fieldtype' , ['normal' , 'vector' ])
257+ def test_add_field_after_pset (fieldtype ):
258+ data , dimensions = generate_fieldset (100 , 100 )
259+ fieldset = FieldSet .from_data (data , dimensions )
260+ pset = ParticleSet (fieldset , ScipyParticle , lon = 0 , lat = 0 ) # noqa ; to trigger fieldset.check_complete
261+ field1 = Field ('field1' , fieldset .U .data , lon = fieldset .U .lon , lat = fieldset .U .lat )
262+ field2 = Field ('field2' , fieldset .U .data , lon = fieldset .U .lon , lat = fieldset .U .lat )
263+ vfield = VectorField ('vfield' , field1 , field2 )
264+ error_thrown = False
265+ try :
266+ if fieldtype == 'normal' :
267+ fieldset .add_field (field1 )
268+ elif fieldtype == 'vector' :
269+ fieldset .add_vector_field (vfield )
270+ except RuntimeError :
271+ error_thrown = True
272+
273+ assert error_thrown
274+
275+
256276def test_fieldset_samegrids_from_file (tmpdir , filename = 'test_subsets' ):
257277 """ Test for subsetting fieldset from file using indices dict. """
258278 data , dimensions = generate_fieldset (100 , 100 )
@@ -463,6 +483,36 @@ def test_vector_fields(mode, swapUV):
463483 assert abs (pset .lat [0 ] - .5 ) < 1e-9
464484
465485
486+ @pytest .mark .parametrize ('mode' , ['scipy' , 'jit' ])
487+ def test_add_second_vector_field (mode ):
488+ lon = np .linspace (0. , 10. , 12 , dtype = np .float32 )
489+ lat = np .linspace (0. , 10. , 10 , dtype = np .float32 )
490+ U = np .ones ((10 , 12 ), dtype = np .float32 )
491+ V = np .zeros ((10 , 12 ), dtype = np .float32 )
492+ data = {'U' : U , 'V' : V }
493+ dimensions = {'U' : {'lat' : lat , 'lon' : lon },
494+ 'V' : {'lat' : lat , 'lon' : lon }}
495+ fieldset = FieldSet .from_data (data , dimensions , mesh = 'flat' )
496+
497+ data2 = {'U2' : U , 'V2' : V }
498+ dimensions2 = {'lon' : [ln + 0.1 for ln in lon ], 'lat' : [lt - 0.1 for lt in lat ]}
499+ fieldset2 = FieldSet .from_data (data2 , dimensions2 , mesh = 'flat' )
500+
501+ UV2 = VectorField ('UV2' , fieldset2 .U2 , fieldset2 .V2 )
502+ fieldset .add_vector_field (UV2 )
503+
504+ def SampleUV2 (particle , fieldset , time ):
505+ u , v = fieldset .UV2 [time , particle .depth , particle .lat , particle .lon ]
506+ particle .lon += u * particle .dt
507+ particle .lat += v * particle .dt
508+
509+ pset = ParticleSet (fieldset , pclass = ptype [mode ], lon = 0.5 , lat = 0.5 )
510+ pset .execute (AdvectionRK4 + pset .Kernel (SampleUV2 ), dt = 1 , runtime = 1 )
511+
512+ assert abs (pset .lon [0 ] - 2.5 ) < 1e-9
513+ assert abs (pset .lat [0 ] - .5 ) < 1e-9
514+
515+
466516@pytest .mark .parametrize ('mode' , ['scipy' , 'jit' ])
467517@pytest .mark .parametrize ('time_periodic' , [4 * 86400.0 , False ])
468518@pytest .mark .parametrize ('field_chunksize' , [False , 'auto' , (1 , 32 , 32 )])
0 commit comments