1111
1212
1313def plotTrajectoriesFile (filename , mode = '2d' , tracerfile = None , tracerfield = 'P' ,
14- tracerlon = 'x' , tracerlat = 'y' , recordedvar = None ):
14+ tracerlon = 'x' , tracerlat = 'y' , recordedvar = None , show_plt = True ):
1515 """Quick and simple plotting of Parcels trajectories
1616
1717 :param filename: Name of Parcels-generated NetCDF file with particle positions
@@ -24,6 +24,7 @@ def plotTrajectoriesFile(filename, mode='2d', tracerfile=None, tracerfield='P',
2424 :param tracerlat: Name of latitude dimension of variable to show as background
2525 :param recordedvar: Name of variable used to color particles in scatter-plot.
2626 Only works in 'movie2d' or 'movie2d_notebook' mode.
27+ :param show_plt: Boolean whether plot should directly be show (for py.test)
2728 """
2829
2930 if plt is None :
@@ -34,10 +35,10 @@ def plotTrajectoriesFile(filename, mode='2d', tracerfile=None, tracerfield='P',
3435 lon = pfile .variables ['lon' ]
3536 lat = pfile .variables ['lat' ]
3637 z = pfile .variables ['z' ]
38+ time = pfile .variables ['time' ][:]
3739 if len (lon .shape ) == 1 :
3840 type = 'indexed'
3941 id = pfile .variables ['trajectory' ][:]
40- time = pfile .variables ['time' ][:]
4142 else :
4243 type = 'array'
4344
@@ -59,7 +60,7 @@ def plotTrajectoriesFile(filename, mode='2d', tracerfile=None, tracerfield='P',
5960 for p in range (len (lon )):
6061 ax .plot (lon [p , :], lat [p , :], z [p , :], '.-' )
6162 elif type == 'indexed' :
62- for t in range ( max ( id ) + 1 ):
63+ for t in np . unique ( id ):
6364 ax .plot (lon [id == t ], lat [id == t ],
6465 z [id == t ], '.-' )
6566 ax .set_xlabel ('Longitude' )
@@ -69,11 +70,18 @@ def plotTrajectoriesFile(filename, mode='2d', tracerfile=None, tracerfield='P',
6970 if type == 'array' :
7071 plt .plot (np .transpose (lon ), np .transpose (lat ), '.-' )
7172 elif type == 'indexed' :
72- for t in range ( max ( id ) + 1 ):
73+ for t in np . unique ( id ):
7374 plt .plot (lon [id == t ], lat [id == t ], '.-' )
7475 plt .xlabel ('Longitude' )
7576 plt .ylabel ('Latitude' )
7677 elif mode == 'movie2d' or 'movie2d_notebook' :
78+ if type == 'array' and any (time [:, 0 ] != time [0 , 0 ]):
79+ # since particles don't start at the same time, treat as indexed
80+ type = 'indexed'
81+ id = pfile .variables ['trajectory' ][:].flatten ()
82+ lon = lon [:].flatten ()
83+ lat = lat [:].flatten ()
84+ time = time .flatten ()
7785
7886 fig = plt .figure ()
7987 ax = plt .axes (xlim = (np .amin (lon ), np .amax (lon )), ylim = (np .amin (lat ), np .amax (lat )))
@@ -84,7 +92,7 @@ def plotTrajectoriesFile(filename, mode='2d', tracerfile=None, tracerfield='P',
8492 mintime = min (time )
8593 scat = ax .scatter (lon [time == mintime ], lat [time == mintime ],
8694 s = 60 , cmap = plt .get_cmap ('autumn' ))
87- frames = np .unique (time )
95+ frames = np .unique (time [ ~ np . isnan ( time )] )
8896
8997 def animate (t ):
9098 if type == 'array' :
@@ -102,7 +110,9 @@ def animate(t):
102110 plt .close ()
103111 return anim
104112 else :
105- plt .show ()
113+ if show_plt :
114+ plt .show ()
115+ return plt
106116
107117
108118if __name__ == "__main__" :
@@ -125,4 +135,5 @@ def animate(t):
125135
126136 plotTrajectoriesFile (args .particlefile , mode = args .mode , tracerfile = args .tracerfile ,
127137 tracerfield = args .tracerfilefield , tracerlon = args .tracerfilelon ,
128- tracerlat = args .tracerfilelat , recordedvar = args .recordedvar )
138+ tracerlat = args .tracerfilelat , recordedvar = args .recordedvar ,
139+ show_plt = True )
0 commit comments