Skip to content

Commit c9d7e09

Browse files
committed
refine plot functions
1 parent 749031e commit c9d7e09

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

pybdr/util/visualization/plot.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def plot(
129129
raise Exception("unsupported visualization mode")
130130

131131

132-
def __2d_plot_cmp(collections, dims, width, height, xlim, ylim, cs, filled):
132+
def __2d_plot_cmp(collections, dims, width, height, xlim, ylim, cs, filled, show, save_file_name):
133133
assert len(dims) == 2
134134
if cs is not None:
135135
assert len(collections) == len(cs)
@@ -172,8 +172,10 @@ def __2d_plot_cmp(collections, dims, width, height, xlim, ylim, cs, filled):
172172

173173
if ylim is not None:
174174
plt.ylim(ylim)
175-
176-
plt.show()
175+
if show:
176+
plt.show()
177+
if save_file_name is not None:
178+
plt.savefig(save_file_name, format="png")
177179

178180

179181
def __3d_plot_cmp(collections, dims, width, height, cs):
@@ -191,9 +193,11 @@ def plot_cmp(
191193
ylim=None,
192194
cs=None,
193195
filled=False,
196+
show=False,
197+
save_file_name=None
194198
):
195199
if mod == "2d":
196-
return __2d_plot_cmp(collections, dims, width, height, xlim, ylim, cs, filled)
200+
return __2d_plot_cmp(collections, dims, width, height, xlim, ylim, cs, filled, show, save_file_name)
197201
elif mod == "3d":
198202
return __3d_plot_cmp(collections, dims, width, height, cs)
199203
else:

0 commit comments

Comments
 (0)