Skip to content

Commit 1a04410

Browse files
authored
Refactor Policy args with a dataclass (#332)
## Summary Enables users to specify policy-related configs either through the CLI args parser (policy runner) or through a dict (multi-task eval). ## Detailed description - Before this change, policy configs could only be specified using CLI args. - Multi-task eval expects users to submit eval jobs as a JSON dict. - Instead of the single path JSON dict -> CLI args list -> parser, two paths are now supported, tried in the following order: 1. JSON dict -> PolicyConfigClass -> Policy.from_dict() 2. JSON dict -> CLI args list -> args parser -> Policy.from_args() The path taken depends on which of these the policy class provides.
1 parent b69da35 commit 1a04410

8 files changed

Lines changed: 398 additions & 79 deletions

File tree

isaaclab_arena/evaluation/eval_runner.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,23 @@ def load_env(arena_env_args: list[str], job_name: str):
4242

4343

4444
def get_policy_from_job(job: Job) -> "PolicyBase":
45-
45+
"""
46+
Create a policy from a job configuration. Two paths are supported:
47+
1. JSON → dict → ConfigDataclass → init cls (preferred, if policy has config_class)
48+
2. JSON → dict → CLI args → init cls (if policy has add_args_to_parser() and from_args())
49+
"""
4650
# Each job can be evaluated with a different policy checkpoint, or even a different policy type
4751
policy_cls = get_policy_cls(job.policy_type)
4852

49-
# As jobs may run diff policies, create a new parser for each job avoiding data fields conflicts
50-
policy_args_parser = get_isaaclab_arena_cli_parser()
51-
policy_added_args_parser = policy_cls.add_args_to_parser(policy_args_parser)
52-
# only for policy related arguments
53-
policy_args = policy_added_args_parser.parse_args(job.policy_args)
54-
policy = policy_cls.from_args(policy_args)
53+
# Use direct from_dict if the policy class has config_class defined
54+
if hasattr(policy_cls, "config_class") and policy_cls.config_class is not None:
55+
# Use the inherited from_dict() method from PolicyBase
56+
policy = policy_cls.from_dict(job.policy_config_dict)
57+
else:
58+
policy_args_parser = get_isaaclab_arena_cli_parser()
59+
policy_added_args_parser = policy_cls.add_args_to_parser(policy_args_parser)
60+
policy_args = policy_added_args_parser.parse_args(job.policy_config_dict)
61+
policy = policy_cls.from_args(policy_args)
5562
return policy
5663

5764

isaaclab_arena/evaluation/job_manager.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def __init__(
2323
arena_env_args: dict,
2424
policy_type: str,
2525
num_steps: int = None,
26-
policy_args: dict = {},
27-
status: Status = Status.PENDING,
26+
policy_config_dict: dict = None,
27+
status: Status = None,
2828
):
2929
"""Initialize a Job instance.
3030
@@ -33,15 +33,15 @@ def __init__(
3333
arena_env_args: Dictionary of arguments for configuring the arena environment
3434
num_steps: Number of steps to run the policy for
3535
policy_type: Type of policy to use
36-
policy_args: Dictionary of arguments for the policy. These are passed to the policy class's from_args method.
36+
policy_config_dict: Dictionary configuration for the policy.
3737
status: Job status (defaults to PENDING)
3838
"""
3939
self.name = name
4040
self.arena_env_args = arena_env_args
4141
self.num_steps = num_steps
4242
self.policy_type = policy_type
43-
self.policy_args = policy_args
44-
self.status = status
43+
self.policy_config_dict = policy_config_dict if policy_config_dict is not None else {}
44+
self.status = status if status is not None else Status.PENDING
4545
self.start_time = None
4646
self.end_time = None
4747
self.metrics = {}
@@ -83,7 +83,7 @@ def from_dict(cls, data: dict) -> "Job":
8383
arena_env_args=cls.convert_args_dict_to_cli_args_list(data["arena_env_args"]),
8484
policy_type=data["policy_type"],
8585
num_steps=num_steps,
86-
policy_args=cls.convert_args_dict_to_cli_args_list(data["policy_args"]),
86+
policy_config_dict=data["policy_args"],
8787
status=status,
8888
)
8989

isaaclab_arena/policy/policy_base.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,50 @@
88
import torch
99
from abc import ABC, abstractmethod
1010
from gymnasium.spaces.dict import Dict as GymSpacesDict
11+
from typing import Any
1112

1213

1314
class PolicyBase(ABC):
14-
def __init__(self):
15+
"""
16+
Base class for policies.
17+
18+
Subclasses should define a `config_class` class variable pointing to their configuration dataclass
19+
to enable configuration from dictionaries via the from_dict() method.
20+
"""
21+
22+
# Optional: Subclasses can define this to enable from_dict()
23+
config_class: type | None = None
24+
25+
def __init__(self, config: Any):
1526
"""
1627
Initialize the policy and store its configuration object.
1728
"""
29+
self.config = config
30+
31+
@classmethod
32+
def from_dict(cls, config_dict: dict[str, Any]) -> "PolicyBase":
33+
"""
34+
Create a policy instance from a configuration dictionary.
35+
36+
This method instantiates the policy's config_class from the dict and then
37+
creates the policy from that config.
38+
39+
Path: dict → ConfigDataclass → Policy instance
40+
41+
Args:
42+
config_dict: Dictionary containing the configuration fields
43+
44+
Returns:
45+
Policy instance
46+
"""
47+
if cls.config_class is None:
48+
raise NotImplementedError(f"{cls.__name__} must define 'config_class' to use from_dict()")
49+
50+
# Create config from dict
51+
config = cls.config_class(**config_dict) # type: ignore[misc]
52+
53+
# Create policy from config
54+
return cls(config) # type: ignore[call-arg]
1855

1956
@abstractmethod
2057
def get_action(self, env: gym.Env, observation: GymSpacesDict) -> torch.Tensor:

isaaclab_arena/policy/replay_action_policy.py

Lines changed: 89 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import argparse
77
import gymnasium as gym
88
import torch
9+
from dataclasses import dataclass, field
910
from gymnasium.spaces.dict import Dict as GymSpacesDict
1011

1112
from isaaclab.utils.datasets import HDF5DatasetFileHandler
@@ -14,6 +15,57 @@
1415
from isaaclab_arena.policy.policy_base import PolicyBase
1516

1617

18+
@dataclass
19+
class ReplayActionPolicyArgs:
20+
"""
21+
Configuration dataclass for ReplayActionPolicy.
22+
23+
This dataclass serves as the single source of truth for policy configuration,
24+
supporting both dict-based (from JSON) and CLI-based configuration paths.
25+
26+
Field metadata is used to auto-generate argparse arguments, ensuring consistency
27+
between the dataclass definition and CLI argument parsing.
28+
"""
29+
30+
replay_file_path: str = field(
31+
metadata={
32+
"help": "Path to the HDF5 file containing the episode",
33+
"required": True,
34+
}
35+
)
36+
37+
device: str = field(
38+
default="cuda",
39+
metadata={
40+
"help": "Device to use for loading the dataset",
41+
},
42+
)
43+
44+
episode_name: str | None = field(
45+
default=None,
46+
metadata={
47+
"help": "Name of the episode to replay. If not provided, the first episode will be replayed",
48+
},
49+
)
50+
51+
@classmethod
52+
def from_cli_args(cls, args: argparse.Namespace) -> "ReplayActionPolicyArgs":
53+
"""
54+
Create configuration from parsed CLI arguments.
55+
56+
Args:
57+
args: Parsed command line arguments
58+
59+
Returns:
60+
ReplayActionPolicyArgs instance
61+
"""
62+
return cls(
63+
replay_file_path=args.replay_file_path,
64+
device=getattr(args, "device", "cuda"),
65+
episode_name=args.episode_name,
66+
)
67+
68+
1769
@register_policy
1870
class ReplayActionPolicy(PolicyBase):
1971
"""
@@ -22,24 +74,32 @@ class ReplayActionPolicy(PolicyBase):
2274
"""
2375

2476
name = "replay"
25-
26-
def __init__(self, replay_file_path: str, device: str = "cuda", episode_name: str | None = None):
27-
super().__init__()
28-
self.episode_name = episode_name
77+
# enable from_dict() from policy_base.PolicyBase
78+
config_class = ReplayActionPolicyArgs
79+
80+
def __init__(self, config: ReplayActionPolicyArgs):
81+
"""
82+
Initialize ReplayActionPolicy from a configuration dataclass.
83+
84+
Args:
85+
config: ReplayActionPolicyArgs configuration dataclass
86+
"""
87+
super().__init__(config)
88+
self.episode_name = config.episode_name
2989
self.dataset_file_handler = HDF5DatasetFileHandler()
30-
self.dataset_file_handler.open(replay_file_path)
90+
self.dataset_file_handler.open(config.replay_file_path)
3191
self.available_episode_names = list(self.dataset_file_handler.get_episode_names())
3292

3393
# Take the first episode if no episode name is provided
3494
if self.episode_name is None:
3595
self.episode_name = self.available_episode_names[0]
3696
else:
3797
assert self.episode_name in self.available_episode_names, (
38-
f"Episode {self.episode_name} not found in {replay_file_path}."
98+
f"Episode {self.episode_name} not found in {config.replay_file_path}."
3999
f"Available episodes: {self.available_episode_names}"
40100
)
41101

42-
self.episode_data = self.dataset_file_handler.load_episode(self.episode_name, device=device)
102+
self.episode_data = self.dataset_file_handler.load_episode(self.episode_name, device=config.device)
43103
self.current_action_index = 0
44104

45105
def __len__(self) -> int:
@@ -84,23 +144,35 @@ def add_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentPars
84144
replay_group.add_argument(
85145
"--replay_file_path",
86146
type=str,
87-
help="Path to the HDF5 file containing the episode (required with --policy_type replay)",
147+
required=True,
148+
help="Path to the HDF5 file containing the episode",
149+
)
150+
replay_group.add_argument(
151+
"--device",
152+
type=str,
153+
default="cuda",
154+
help="Device to use for loading the dataset (default: cuda)",
88155
)
89156
replay_group.add_argument(
90157
"--episode_name",
91158
type=str,
92159
default=None,
93-
help=(
94-
"Name of the episode to replay. If not provided, the first episode will be"
95-
"replayed (only used with --policy_type replay)"
96-
),
160+
help="Name of the episode to replay. If not provided, the first episode will be replayed",
97161
)
98162
return parser
99163

100164
@staticmethod
101165
def from_args(args: argparse.Namespace) -> "ReplayActionPolicy":
102-
"""Create a replay action policy from the arguments."""
103-
return ReplayActionPolicy(
104-
replay_file_path=args.replay_file_path,
105-
episode_name=args.episode_name,
106-
)
166+
"""
167+
Create a ReplayActionPolicy instance from parsed CLI arguments.
168+
169+
Path: CLI args → ConfigDataclass → init cls
170+
171+
Args:
172+
args: Parsed command line arguments
173+
174+
Returns:
175+
ReplayActionPolicy instance
176+
"""
177+
config = ReplayActionPolicyArgs.from_cli_args(args)
178+
return ReplayActionPolicy(config)

isaaclab_arena/policy/zero_action_policy.py

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,52 @@
66
import argparse
77
import gymnasium as gym
88
import torch
9+
from dataclasses import dataclass
910
from gymnasium.spaces.dict import Dict as GymSpacesDict
1011

1112
from isaaclab_arena.assets.register import register_policy
1213
from isaaclab_arena.policy.policy_base import PolicyBase
1314

1415

16+
@dataclass
17+
class ZeroActionPolicyArgs:
18+
"""
19+
Configuration dataclass for ZeroActionPolicy.
20+
21+
This policy has no configuration parameters, but the dataclass is provided
22+
for consistency with other policies following the unified configuration pattern.
23+
"""
24+
25+
@classmethod
26+
def from_cli_args(cls, args: argparse.Namespace) -> "ZeroActionPolicyArgs":
27+
"""
28+
Create configuration from parsed CLI arguments.
29+
30+
Args:
31+
args: Parsed command line arguments
32+
33+
Returns:
34+
ZeroActionPolicyArgs instance
35+
"""
36+
_ = args # Unused, but kept for API consistency
37+
return cls()
38+
39+
1540
@register_policy
1641
class ZeroActionPolicy(PolicyBase):
1742

1843
name = "zero_action"
44+
# enable from_dict() from policy_base.PolicyBase
45+
config_class = ZeroActionPolicyArgs
1946

20-
def __init__(self):
21-
super().__init__()
47+
def __init__(self, config: ZeroActionPolicyArgs):
48+
"""
49+
Initialize ZeroActionPolicy.
50+
51+
Args:
52+
config: ZeroActionPolicyArgs configuration dataclass (required; it has no fields and is only stored by the base class)
53+
"""
54+
super().__init__(config)
2255

2356
def get_action(self, env: gym.Env, observation: GymSpacesDict) -> torch.Tensor:
2457
"""
@@ -28,11 +61,32 @@ def get_action(self, env: gym.Env, observation: GymSpacesDict) -> torch.Tensor:
2861

2962
@staticmethod
3063
def add_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
31-
"""Add zero action policy specific arguments to the parser."""
64+
"""
65+
Add zero action policy specific arguments to the parser.
66+
67+
This policy has no configuration parameters, so no arguments are added.
68+
69+
Args:
70+
parser: The argument parser to add arguments to
71+
72+
Returns:
73+
The updated argument parser (unchanged)
74+
"""
3275
# No additional command line arguments for zero action policy
3376
return parser
3477

3578
@staticmethod
3679
def from_args(args: argparse.Namespace) -> "ZeroActionPolicy":
37-
"""Create a zero action policy from the arguments."""
38-
return ZeroActionPolicy()
80+
"""
81+
Create a ZeroActionPolicy instance from parsed CLI arguments.
82+
83+
Path: CLI args → ConfigDataclass → init cls
84+
85+
Args:
86+
args: Parsed command line arguments
87+
88+
Returns:
89+
ZeroActionPolicy instance
90+
"""
91+
config = ZeroActionPolicyArgs.from_cli_args(args)
92+
return ZeroActionPolicy(config)

isaaclab_arena_environments/eval_jobs_configs/gr00t_jobs_config.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"name": "gr1_open_microwave_cracker_box",
55
"arena_env_args": {
66
"enable_cameras": true,
7-
"env_name":"gr1_open_microwave",
7+
"environment":"gr1_open_microwave",
88
"object":"cracker_box",
99
"embodiment":"gr1_joint"
1010
},
@@ -19,7 +19,7 @@
1919
"name": "g1_locomanip_pick_and_place_brown_box",
2020
"arena_env_args": {
2121
"enable_cameras": true,
22-
"env_name":"galileo_g1_locomanip_pick_and_place",
22+
"environment":"galileo_g1_locomanip_pick_and_place",
2323
"object":"brown_box",
2424
"embodiment":"g1_wbc_joint"
2525
},

0 commit comments

Comments
 (0)