1- # Copyright 2023 OmniSafe Team. All Rights Reserved.
1+ # Copyright 2024 OmniSafe Team. All Rights Reserved.
22#
33# Licensed under the Apache License, Version 2.0 (the "License");
44# you may not use this file except in compliance with the License.
@@ -164,17 +164,25 @@ def plot_data(
164164
165165 plt .tight_layout (pad = 0.5 )
166166
167- def get_datasets (self , logdir : str , condition : str | None = None ) -> list [DataFrame ]:
167+ def get_datasets (
168+ self ,
169+ logdir : str ,
170+ condition : str | None = None ,
171+ reward_metrics : str = 'Metrics/EpReward' ,
172+ cost_metrics : str = 'Metrics/EpCost' ,
173+ ) -> list [DataFrame ]:
168174 """Recursively look through logdir for files named "progress.txt".
169175
170176 Assumes that any file "progress.txt" is a valid hit.
171177
172178 Args:
173179 logdir (str): The directory to search for progress.txt files
174180 condition (str or None, optional): The condition label. Defaults to None.
181+ reward_metrics (str, optional): The column name for reward metrics. Defaults to 'Metrics/EpReward'.
182+ cost_metrics (str, optional): The column name for cost metrics. Defaults to 'Metrics/EpCost'.
175183
176184 Returns:
177- The datasets.
185+ list[DataFrame]: A list of DataFrame objects containing the datasets.
178186
179187 Raise:
180188 FileNotFoundError: If the config file is not found.
@@ -204,21 +212,21 @@ def get_datasets(self, logdir: str, condition: str | None = None) -> list[DataFr
204212 self .units [condition1 ] += 1
205213 try :
206214 exp_data = pd .read_csv (os .path .join (root , 'progress.csv' ))
207-
208215 except FileNotFoundError as error :
209216 progress_path = os .path .join (root , 'progress.csv' )
210217 raise FileNotFoundError (f'Could not read from { progress_path } ' ) from error
211- performance = (
212- 'Metrics/TestEpRet' if 'Metrics/TestEpRet' in exp_data else 'Metrics/EpRet'
213- )
214- cost_performance = (
215- 'Metrics/TestEpCost' if 'Metrics/TestEpCost' in exp_data else 'Metrics/EpCost'
216- )
218+
219+ if reward_metrics not in exp_data :
220+ raise KeyError (f'{ reward_metrics } is not in data to plot!' )
221+
222+ if cost_metrics not in exp_data :
223+ raise KeyError (f'{ cost_metrics } is not in data to plot!' )
224+
217225 exp_data .insert (len (exp_data .columns ), 'Unit' , unit )
218226 exp_data .insert (len (exp_data .columns ), 'Condition1' , condition1 )
219227 exp_data .insert (len (exp_data .columns ), 'Condition2' , condition2 )
220- exp_data .insert (len (exp_data .columns ), 'Rewards' , exp_data [performance ])
221- exp_data .insert (len (exp_data .columns ), 'Costs' , exp_data [cost_performance ])
228+ exp_data .insert (len (exp_data .columns ), 'Rewards' , exp_data [reward_metrics ])
229+ exp_data .insert (len (exp_data .columns ), 'Costs' , exp_data [cost_metrics ])
222230 epoch = exp_data .get ('Train/Epoch' )
223231 if epoch is None or steps_per_epoch is None :
224232 raise ValueError ('No Train/Epoch column in progress.csv' )
@@ -236,6 +244,8 @@ def get_all_datasets(
236244 legend : list [str ] | None = None ,
237245 select : str | None = None ,
238246 exclude : str | None = None ,
247+ reward_metrics : str = 'Metrics/EpCost' ,
248+ cost_metrics : str = 'Metrics/EpCost' ,
239249 ) -> list [DataFrame ]:
240250 """Get all the data from all the log directories.
241251
@@ -248,6 +258,8 @@ def get_all_datasets(
248258 legend (list of str or None, optional): List of legend names. Defaults to None.
249259 select (str or None, optional): Select logdirs that contain this string. Defaults to None.
250260 exclude (str or None, optional): Exclude logdirs that contain this string. Defaults to None.
261+ reward_metrics (str, optional): The column name for reward metrics. Defaults to 'Metrics/EpReward'.
262+ cost_metrics (str, optional): The column name for cost metrics. Defaults to 'Metrics/EpCost'.
251263
252264 Returns:
253265 All the data stored in a list of DataFrames.
@@ -285,13 +297,22 @@ def get_all_datasets(
285297 data = []
286298 if legend :
287299 for log , leg in zip (logdirs , legend ):
288- data += self .get_datasets (log , leg )
300+ data += self .get_datasets (
301+ log ,
302+ leg ,
303+ cost_metrics = cost_metrics ,
304+ reward_metrics = reward_metrics ,
305+ )
289306 else :
290307 for log in logdirs :
291- data += self .get_datasets (log )
308+ data += self .get_datasets (
309+ log ,
310+ cost_metrics = cost_metrics ,
311+ reward_metrics = reward_metrics ,
312+ )
292313 return data
293314
294- # pylint: disable-next=too-many-arguments
315+ # pylint: disable-next=too-many-arguments, too-many-locals
295316 def make_plots (
296317 self ,
297318 all_logdirs : list [str ],
@@ -308,6 +329,8 @@ def make_plots(
308329 save_name : str | None = None ,
309330 save_format : str = 'png' ,
310331 show_image : bool = False ,
332+ reward_metrics : str = 'Metrics/EpCost' ,
333+ cost_metrics : str = 'Metrics/EpCost' ,
311334 ) -> None :
312335 """Make plots from the data in the specified log directories.
313336
@@ -355,9 +378,18 @@ def make_plots(
355378 to ``png``.
356379 show_image (bool, optional): Optional flag. If set, the plot will be displayed on screen.
357380 Defaults to ``False``.
381+ reward_metrics (str, optional): The column name for reward metrics. Defaults to 'Metrics/EpReward'.
382+ cost_metrics (str, optional): The column name for cost metrics. Defaults to 'Metrics/EpCost'.
358383 """
359384 assert xaxis is not None , 'Must specify xaxis'
360- data = self .get_all_datasets (all_logdirs , legend , select , exclude )
385+ data = self .get_all_datasets (
386+ all_logdirs ,
387+ legend ,
388+ select ,
389+ exclude ,
390+ cost_metrics = cost_metrics ,
391+ reward_metrics = reward_metrics ,
392+ )
361393 condition = 'Condition2' if count else 'Condition1'
362394 # choose what to show on main curve: mean? max? min?
363395 estimator = getattr (np , estimator )
0 commit comments