66import argparse
77import gymnasium as gym
88import torch
9+ from dataclasses import dataclass , field
910from gymnasium .spaces .dict import Dict as GymSpacesDict
1011
1112from isaaclab .utils .datasets import HDF5DatasetFileHandler
1415from isaaclab_arena .policy .policy_base import PolicyBase
1516
1617
18+ @dataclass
19+ class ReplayActionPolicyArgs :
20+ """
21+ Configuration dataclass for ReplayActionPolicy.
22+
23+ This dataclass serves as the single source of truth for policy configuration,
24+ supporting both dict-based (from JSON) and CLI-based configuration paths.
25+
26+ Field metadata is used to auto-generate argparse arguments, ensuring consistency
27+ between the dataclass definition and CLI argument parsing.
28+ """
29+
30+ replay_file_path : str = field (
31+ metadata = {
32+ "help" : "Path to the HDF5 file containing the episode" ,
33+ "required" : True ,
34+ }
35+ )
36+
37+ device : str = field (
38+ default = "cuda" ,
39+ metadata = {
40+ "help" : "Device to use for loading the dataset" ,
41+ },
42+ )
43+
44+ episode_name : str | None = field (
45+ default = None ,
46+ metadata = {
47+ "help" : "Name of the episode to replay. If not provided, the first episode will be replayed" ,
48+ },
49+ )
50+
51+ @classmethod
52+ def from_cli_args (cls , args : argparse .Namespace ) -> "ReplayActionPolicyArgs" :
53+ """
54+ Create configuration from parsed CLI arguments.
55+
56+ Args:
57+ args: Parsed command line arguments
58+
59+ Returns:
60+ ReplayActionPolicyArgs instance
61+ """
62+ return cls (
63+ replay_file_path = args .replay_file_path ,
64+ device = getattr (args , "device" , "cuda" ),
65+ episode_name = args .episode_name ,
66+ )
67+
68+
1769@register_policy
1870class ReplayActionPolicy (PolicyBase ):
1971 """
@@ -22,24 +74,32 @@ class ReplayActionPolicy(PolicyBase):
2274 """
2375
2476 name = "replay"
25-
26- def __init__ (self , replay_file_path : str , device : str = "cuda" , episode_name : str | None = None ):
27- super ().__init__ ()
28- self .episode_name = episode_name
77+ # enable from_dict() from policy_base.PolicyBase
78+ config_class = ReplayActionPolicyArgs
79+
80+ def __init__ (self , config : ReplayActionPolicyArgs ):
81+ """
82+ Initialize ReplayActionPolicy from a configuration dataclass.
83+
84+ Args:
85+ config: ReplayActionPolicyArgs configuration dataclass
86+ """
87+ super ().__init__ (config )
88+ self .episode_name = config .episode_name
2989 self .dataset_file_handler = HDF5DatasetFileHandler ()
30- self .dataset_file_handler .open (replay_file_path )
90+ self .dataset_file_handler .open (config . replay_file_path )
3191 self .available_episode_names = list (self .dataset_file_handler .get_episode_names ())
3292
3393 # Take the first episode if no episode name is provided
3494 if self .episode_name is None :
3595 self .episode_name = self .available_episode_names [0 ]
3696 else :
3797 assert self .episode_name in self .available_episode_names , (
38- f"Episode { self .episode_name } not found in { replay_file_path } ."
98+ f"Episode { self .episode_name } not found in { config . replay_file_path } ."
3999 f"Available episodes: { self .available_episode_names } "
40100 )
41101
42- self .episode_data = self .dataset_file_handler .load_episode (self .episode_name , device = device )
102+ self .episode_data = self .dataset_file_handler .load_episode (self .episode_name , device = config . device )
43103 self .current_action_index = 0
44104
45105 def __len__ (self ) -> int :
@@ -84,23 +144,35 @@ def add_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentPars
84144 replay_group .add_argument (
85145 "--replay_file_path" ,
86146 type = str ,
87- help = "Path to the HDF5 file containing the episode (required with --policy_type replay)" ,
147+ required = True ,
148+ help = "Path to the HDF5 file containing the episode" ,
149+ )
150+ replay_group .add_argument (
151+ "--device" ,
152+ type = str ,
153+ default = "cuda" ,
154+ help = "Device to use for loading the dataset (default: cuda)" ,
88155 )
89156 replay_group .add_argument (
90157 "--episode_name" ,
91158 type = str ,
92159 default = None ,
93- help = (
94- "Name of the episode to replay. If not provided, the first episode will be"
95- "replayed (only used with --policy_type replay)"
96- ),
160+ help = "Name of the episode to replay. If not provided, the first episode will be replayed" ,
97161 )
98162 return parser
99163
100164 @staticmethod
101165 def from_args (args : argparse .Namespace ) -> "ReplayActionPolicy" :
102- """Create a replay action policy from the arguments."""
103- return ReplayActionPolicy (
104- replay_file_path = args .replay_file_path ,
105- episode_name = args .episode_name ,
106- )
166+ """
167+ Create a ReplayActionPolicy instance from parsed CLI arguments.
168+
169+ Path: CLI args → ConfigDataclass → init cls
170+
171+ Args:
172+ args: Parsed command line arguments
173+
174+ Returns:
175+ ReplayActionPolicy instance
176+ """
177+ config = ReplayActionPolicyArgs .from_cli_args (args )
178+ return ReplayActionPolicy (config )
0 commit comments