diff --git a/tools/template/cli.py b/tools/template/cli.py
index d922025e070f..eea46ec0388a 100644
--- a/tools/template/cli.py
+++ b/tools/template/cli.py
@@ -7,6 +7,7 @@
 import importlib
 import os
 from collections.abc import Callable
+from textwrap import fill
 
 import rich.console
 import rich.table
@@ -213,21 +214,21 @@ def main() -> None:
     # - show supported RL libraries and features
     rl_library_table = rich.table.Table(title="Supported RL libraries")
     rl_library_table.add_column("RL/training feature", no_wrap=True)
-    rl_library_table.add_column("rl_games")
-    rl_library_table.add_column("rsl_rl")
-    rl_library_table.add_column("skrl")
-    rl_library_table.add_column("sb3")
+    rl_library_table.add_column("rl_games", overflow="fold")
+    rl_library_table.add_column("rsl_rl", overflow="fold")
+    rl_library_table.add_column("skrl", overflow="fold")
+    rl_library_table.add_column("sb3", overflow="fold")
     rl_library_table.add_row("ML frameworks", "PyTorch", "PyTorch", "PyTorch, JAX", "PyTorch")
     rl_library_table.add_row("Relative performance", "~1X", "~1X", "~1X", "~0.03X")
     rl_library_table.add_row(
         "Algorithms",
-        ", ".join(algorithms_per_rl_library.get("rl_games", [])),
-        ", ".join(algorithms_per_rl_library.get("rsl_rl", [])),
-        ", ".join(algorithms_per_rl_library.get("skrl", [])),
-        ", ".join(algorithms_per_rl_library.get("sb3", [])),
+        fill(", ".join(algorithms_per_rl_library.get("rl_games", [])), width=12, break_long_words=False),
+        fill(", ".join(algorithms_per_rl_library.get("rsl_rl", [])), width=12, break_long_words=False),
+        fill(", ".join(algorithms_per_rl_library.get("skrl", [])), width=12, break_long_words=False),
+        fill(", ".join(algorithms_per_rl_library.get("sb3", [])), width=12, break_long_words=False),
     )
     rl_library_table.add_row("Multi-agent support", State.No, State.No, State.Yes, State.No)
-    rl_library_table.add_row("Distributed training", State.Yes, State.No, State.Yes, State.No)
+    rl_library_table.add_row("Distributed training", State.Yes, State.Yes, State.Yes, State.No)
     rl_library_table.add_row("Vectorized training", State.Yes, State.Yes, State.Yes, State.No)
     rl_library_table.add_row("Fundamental/composite spaces", State.No, State.No, State.Yes, State.No)
     cli_handler.output_table(rl_library_table)
diff --git a/tools/template/common.py b/tools/template/common.py
index 08d2732a1911..b6f375b130b6 100644
--- a/tools/template/common.py
+++ b/tools/template/common.py
@@ -11,5 +11,5 @@
 TEMPLATE_DIR = os.path.join(ROOT_DIR, "tools", "template", "templates")
 
 # RL algorithms
-SINGLE_AGENT_ALGORITHMS = ["AMP", "PPO"]
+SINGLE_AGENT_ALGORITHMS = ["AMP", "PPO", "DISTILLATION"]
 MULTI_AGENT_ALGORITHMS = ["IPPO", "MAPPO"]
diff --git a/tools/template/templates/agents/rsl_rl_distillation_cfg b/tools/template/templates/agents/rsl_rl_distillation_cfg
new file mode 100644
index 000000000000..020e2dbc21ac
--- /dev/null
+++ b/tools/template/templates/agents/rsl_rl_distillation_cfg
@@ -0,0 +1,34 @@
+# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from isaaclab.utils import configclass
+
+from isaaclab_rl.rsl_rl import RslRlDistillationAlgorithmCfg, RslRlDistillationRunnerCfg, RslRlMLPModelCfg
+
+
+@configclass
+class DistillationRunnerCfg(RslRlDistillationRunnerCfg):
+    num_steps_per_env = 60
+    max_iterations = 150
+    save_interval = 50
+    experiment_name = "cartpole_direct"
+    obs_groups = {"student": ["policy"], "teacher": ["policy"]}
+    student = RslRlMLPModelCfg(
+        hidden_dims=[32, 32],
+        activation="elu",
+        obs_normalization=False,
+        distribution_cfg=RslRlMLPModelCfg.GaussianDistributionCfg(init_std=1.0),
+    )
+    teacher = RslRlMLPModelCfg(
+        hidden_dims=[32, 32],
+        activation="elu",
+        obs_normalization=False,
+        distribution_cfg=RslRlMLPModelCfg.GaussianDistributionCfg(init_std=0.0),
+    )
+    algorithm = RslRlDistillationAlgorithmCfg(
+        num_learning_epochs=2,
+        learning_rate=1.0e-3,
+        gradient_length=15,
+    )
diff --git a/tools/template/templates/agents/rsl_rl_ppo_cfg b/tools/template/templates/agents/rsl_rl_ppo_cfg
index 85970dfc2ce4..d83409033f6d 100644
--- a/tools/template/templates/agents/rsl_rl_ppo_cfg
+++ b/tools/template/templates/agents/rsl_rl_ppo_cfg
@@ -5,7 +5,7 @@
 
 from isaaclab.utils import configclass
 
-from isaaclab_rl.rsl_rl import RslRlOnPolicyRunnerCfg, RslRlPpoActorCriticCfg, RslRlPpoAlgorithmCfg
+from isaaclab_rl.rsl_rl import RslRlMLPModelCfg, RslRlOnPolicyRunnerCfg, RslRlPpoAlgorithmCfg
 
 
 @configclass
@@ -14,13 +14,17 @@ class PPORunnerCfg(RslRlOnPolicyRunnerCfg):
     max_iterations = 150
     save_interval = 50
     experiment_name = "cartpole_direct"
-    policy = RslRlPpoActorCriticCfg(
-        init_noise_std=1.0,
-        actor_obs_normalization=False,
-        critic_obs_normalization=False,
-        actor_hidden_dims=[32, 32],
-        critic_hidden_dims=[32, 32],
+    obs_groups = {"actor": ["policy"], "critic": ["policy"]}
+    actor = RslRlMLPModelCfg(
+        hidden_dims=[32, 32],
         activation="elu",
+        obs_normalization=False,
+        distribution_cfg=RslRlMLPModelCfg.GaussianDistributionCfg(init_std=1.0),
+    )
+    critic = RslRlMLPModelCfg(
+        hidden_dims=[32, 32],
+        activation="elu",
+        obs_normalization=False,
     )
     algorithm = RslRlPpoAlgorithmCfg(
         value_loss_coef=1.0,