@@ -129,7 +129,7 @@ def plot(
129129 raise Exception ("unsupported visualization mode" )
130130
131131
132- def __2d_plot_cmp (collections , dims , width , height , xlim , ylim , cs , filled ):
132+ def __2d_plot_cmp (collections , dims , width , height , xlim , ylim , cs , filled , show , save_file_name ):
133133 assert len (dims ) == 2
134134 if cs is not None :
135135 assert len (collections ) == len (cs )
@@ -172,8 +172,10 @@ def __2d_plot_cmp(collections, dims, width, height, xlim, ylim, cs, filled):
172172
173173 if ylim is not None :
174174 plt .ylim (ylim )
175-
176- plt .show ()
175+ if show :
176+ plt .show ()
177+ if save_file_name is not None :
178+ plt .savefig (save_file_name , format = "png" )
177179
178180
179181def __3d_plot_cmp (collections , dims , width , height , cs ):
@@ -191,9 +193,11 @@ def plot_cmp(
191193 ylim = None ,
192194 cs = None ,
193195 filled = False ,
196+ show = False ,
197+ save_file_name = None
194198):
195199 if mod == "2d" :
196- return __2d_plot_cmp (collections , dims , width , height , xlim , ylim , cs , filled )
200+ return __2d_plot_cmp (collections , dims , width , height , xlim , ylim , cs , filled , show , save_file_name )
197201 elif mod == "3d" :
198202 return __3d_plot_cmp (collections , dims , width , height , cs )
199203 else :
0 commit comments