Skip to content

Commit d18b1ac

Browse files
committed
feat: support customized plot
1 parent dd9068f commit d18b1ac

4 files changed

Lines changed: 87 additions & 19 deletions

File tree

examples/plot.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 OmniSafe Team. All Rights Reserved.
1+
# Copyright 2024 OmniSafe Team. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -35,6 +35,27 @@
3535
parser.add_argument('--select', nargs='*')
3636
parser.add_argument('--exclude', nargs='*')
3737
parser.add_argument('--estimator', default='mean')
38+
parser.add_argument(
39+
'--reward-metrics',
40+
type=str,
41+
choices=[
42+
'Metrics/TestEpRet',
43+
'Metrics/EpRet',
44+
],
45+
default='Metrics/EpRet',
46+
help='Specify the reward metric to be used.',
47+
)
48+
parser.add_argument(
49+
'--cost-metrics',
50+
type=str,
51+
choices=[
52+
'Metrics/Max_angle_violation',
53+
'Metrics/TestEpCost',
54+
'Metrics/EpCost',
55+
],
56+
default='Metrics/EpCost',
57+
help='Specify the cost metric to be used.',
58+
)
3859
args = parser.parse_args()
3960

4061
plotter = Plotter()
@@ -48,4 +69,6 @@
4869
select=args.select,
4970
exclude=args.exclude,
5071
estimator=args.estimator,
72+
cost_metrics=args.cost_metrics,
73+
reward_metrics=args.reward_metrics,
5174
)

omnisafe/common/experiment_grid.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 OmniSafe Team. All Rights Reserved.
1+
# Copyright 2024 OmniSafe Team. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -544,6 +544,8 @@ def analyze(
544544
compare_num: int | None = None,
545545
cost_limit: float | None = None,
546546
show_image: bool = False,
547+
reward_metrics: str = 'Metrics/EpRet',
548+
cost_metrics: str = 'Metrics/EpCost',
547549
) -> None:
548550
"""Analyze the experiment results.
549551
@@ -559,6 +561,8 @@ def analyze(
559561
cost_limit (float or None, optional): Value for one line showed on graph to indicate
560562
cost. Defaults to None.
561563
show_image (bool): Whether to show graph image in GUI windows.
564+
reward_metrics (str, optional): The column name for reward metrics. Defaults to 'Metrics/EpReward'.
565+
cost_metrics (str, optional): The column name for cost metrics. Defaults to 'Metrics/EpCost'.
562566
"""
563567
assert self._statistical_tools is not None, 'Please run run() first!'
564568
self._statistical_tools.load_source(self.log_dir)
@@ -568,6 +572,8 @@ def analyze(
568572
compare_num,
569573
cost_limit,
570574
show_image=show_image,
575+
reward_metrics=reward_metrics,
576+
cost_metrics=cost_metrics,
571577
)
572578

573579
def evaluate(self, num_episodes: int = 10, cost_criteria: float = 1.0) -> None:

omnisafe/common/statistics_tools.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 OmniSafe Team. All Rights Reserved.
1+
# Copyright 2024 OmniSafe Team. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -83,6 +83,7 @@ def load_source(self, path: str) -> None:
8383
'The config file is not found in the save directory.',
8484
) from error
8585

86+
# pylint: disable-next=too-many-arguments, too-many-locals
8687
def draw_graph(
8788
self,
8889
parameter: str,
@@ -91,6 +92,8 @@ def draw_graph(
9192
cost_limit: float | None = None,
9293
smooth: int = 1,
9394
show_image: bool = False,
95+
reward_metrics: str = 'Metrics/EpRet',
96+
cost_metrics: str = 'Metrics/EpCost',
9497
) -> None:
9598
"""Draw graph.
9699
@@ -102,6 +105,8 @@ def draw_graph(
102105
cost_limit (float or None, optional): The cost limit of the experiment. Defaults to None.
103106
smooth (int, optional): The smooth window size. Defaults to 1.
104107
show_image (bool): Whether to show graph image in GUI windows.
108+
reward_metrics (str, optional): The column name for reward metrics. Defaults to 'Metrics/EpReward'.
109+
cost_metrics (str, optional): The column name for cost metrics. Defaults to 'Metrics/EpCost'.
105110
106111
.. note::
107112
`values` and `compare_num` cannot be set at the same time.
@@ -161,6 +166,8 @@ def draw_graph(
161166
'mean',
162167
save_name=save_name,
163168
show_image=show_image,
169+
reward_metrics=reward_metrics,
170+
cost_metrics=cost_metrics,
164171
)
165172
except Exception: # noqa # pragma: no cover # pylint: disable=broad-except
166173
print(

omnisafe/utils/plotter.py

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 OmniSafe Team. All Rights Reserved.
1+
# Copyright 2024 OmniSafe Team. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -164,17 +164,25 @@ def plot_data(
164164

165165
plt.tight_layout(pad=0.5)
166166

167-
def get_datasets(self, logdir: str, condition: str | None = None) -> list[DataFrame]:
167+
def get_datasets(
168+
self,
169+
logdir: str,
170+
condition: str | None = None,
171+
reward_metrics: str = 'Metrics/EpReward',
172+
cost_metrics: str = 'Metrics/EpCost',
173+
) -> list[DataFrame]:
168174
"""Recursively look through logdir for files named "progress.txt".
169175
170176
Assumes that any file "progress.txt" is a valid hit.
171177
172178
Args:
173179
logdir (str): The directory to search for progress.txt files
174180
condition (str or None, optional): The condition label. Defaults to None.
181+
reward_metrics (str, optional): The column name for reward metrics. Defaults to 'Metrics/EpReward'.
182+
cost_metrics (str, optional): The column name for cost metrics. Defaults to 'Metrics/EpCost'.
175183
176184
Returns:
177-
The datasets.
185+
list[DataFrame]: A list of DataFrame objects containing the datasets.
178186
179187
Raise:
180188
FileNotFoundError: If the config file is not found.
@@ -204,21 +212,21 @@ def get_datasets(self, logdir: str, condition: str | None = None) -> list[DataFr
204212
self.units[condition1] += 1
205213
try:
206214
exp_data = pd.read_csv(os.path.join(root, 'progress.csv'))
207-
208215
except FileNotFoundError as error:
209216
progress_path = os.path.join(root, 'progress.csv')
210217
raise FileNotFoundError(f'Could not read from {progress_path}') from error
211-
performance = (
212-
'Metrics/TestEpRet' if 'Metrics/TestEpRet' in exp_data else 'Metrics/EpRet'
213-
)
214-
cost_performance = (
215-
'Metrics/TestEpCost' if 'Metrics/TestEpCost' in exp_data else 'Metrics/EpCost'
216-
)
218+
219+
if reward_metrics not in exp_data:
220+
raise KeyError(f'{reward_metrics} is not in data to plot!')
221+
222+
if cost_metrics not in exp_data:
223+
raise KeyError(f'{cost_metrics} is not in data to plot!')
224+
217225
exp_data.insert(len(exp_data.columns), 'Unit', unit)
218226
exp_data.insert(len(exp_data.columns), 'Condition1', condition1)
219227
exp_data.insert(len(exp_data.columns), 'Condition2', condition2)
220-
exp_data.insert(len(exp_data.columns), 'Rewards', exp_data[performance])
221-
exp_data.insert(len(exp_data.columns), 'Costs', exp_data[cost_performance])
228+
exp_data.insert(len(exp_data.columns), 'Rewards', exp_data[reward_metrics])
229+
exp_data.insert(len(exp_data.columns), 'Costs', exp_data[cost_metrics])
222230
epoch = exp_data.get('Train/Epoch')
223231
if epoch is None or steps_per_epoch is None:
224232
raise ValueError('No Train/Epoch column in progress.csv')
@@ -236,6 +244,8 @@ def get_all_datasets(
236244
legend: list[str] | None = None,
237245
select: str | None = None,
238246
exclude: str | None = None,
247+
reward_metrics: str = 'Metrics/EpCost',
248+
cost_metrics: str = 'Metrics/EpCost',
239249
) -> list[DataFrame]:
240250
"""Get all the data from all the log directories.
241251
@@ -248,6 +258,8 @@ def get_all_datasets(
248258
legend (list of str or None, optional): List of legend names. Defaults to None.
249259
select (str or None, optional): Select logdirs that contain this string. Defaults to None.
250260
exclude (str or None, optional): Exclude logdirs that contain this string. Defaults to None.
261+
reward_metrics (str, optional): The column name for reward metrics. Defaults to 'Metrics/EpReward'.
262+
cost_metrics (str, optional): The column name for cost metrics. Defaults to 'Metrics/EpCost'.
251263
252264
Returns:
253265
All the data stored in a list of DataFrames.
@@ -285,13 +297,22 @@ def get_all_datasets(
285297
data = []
286298
if legend:
287299
for log, leg in zip(logdirs, legend):
288-
data += self.get_datasets(log, leg)
300+
data += self.get_datasets(
301+
log,
302+
leg,
303+
cost_metrics=cost_metrics,
304+
reward_metrics=reward_metrics,
305+
)
289306
else:
290307
for log in logdirs:
291-
data += self.get_datasets(log)
308+
data += self.get_datasets(
309+
log,
310+
cost_metrics=cost_metrics,
311+
reward_metrics=reward_metrics,
312+
)
292313
return data
293314

294-
# pylint: disable-next=too-many-arguments
315+
# pylint: disable-next=too-many-arguments, too-many-locals
295316
def make_plots(
296317
self,
297318
all_logdirs: list[str],
@@ -308,6 +329,8 @@ def make_plots(
308329
save_name: str | None = None,
309330
save_format: str = 'png',
310331
show_image: bool = False,
332+
reward_metrics: str = 'Metrics/EpCost',
333+
cost_metrics: str = 'Metrics/EpCost',
311334
) -> None:
312335
"""Make plots from the data in the specified log directories.
313336
@@ -355,9 +378,18 @@ def make_plots(
355378
to ``png``.
356379
show_image (bool, optional): Optional flag. If set, the plot will be displayed on screen.
357380
Defaults to ``False``.
381+
reward_metrics (str, optional): The column name for reward metrics. Defaults to 'Metrics/EpReward'.
382+
cost_metrics (str, optional): The column name for cost metrics. Defaults to 'Metrics/EpCost'.
358383
"""
359384
assert xaxis is not None, 'Must specify xaxis'
360-
data = self.get_all_datasets(all_logdirs, legend, select, exclude)
385+
data = self.get_all_datasets(
386+
all_logdirs,
387+
legend,
388+
select,
389+
exclude,
390+
cost_metrics=cost_metrics,
391+
reward_metrics=reward_metrics,
392+
)
361393
condition = 'Condition2' if count else 'Condition1'
362394
# choose what to show on main curve: mean? max? min?
363395
estimator = getattr(np, estimator)

0 commit comments

Comments
 (0)