7777 default = False ,
7878 help = "Disable LEAPP graph visualization during compile_graph()." ,
7979)
80- parser .add_argument (
81- "--disable_automatic_module_annotation" ,
82- action = "store_true" ,
83- default = False ,
84- help = "Disables automatic detection and annotation of modules that have internal states" ,
85- )
8680
8781# append RSL-RL cli arguments
8882cli_args .add_rsl_rl_args (parser )
123117
124118
125119def get_actor_memory_module (policy_nn ):
120+ """Return the actor-side recurrent memory module when the policy exposes one."""
126121 if hasattr (policy_nn , "memory_a" ):
127122 return policy_nn .memory_a
128123 if hasattr (policy_nn , "memory_s" ):
@@ -131,6 +126,7 @@ def get_actor_memory_module(policy_nn):
131126
132127
133128def ensure_actor_hidden_state_initialized (policy_nn , batch_size : int , device : torch .device , dtype : torch .dtype ):
129+ """Initialize and return the actor hidden state when a recurrent policy has not created it yet."""
134130 actor_state , _ = policy_nn .get_hidden_states ()
135131 if actor_state is not None :
136132 return actor_state
@@ -151,6 +147,7 @@ def ensure_actor_hidden_state_initialized(policy_nn, batch_size: int, device: to
151147
152148
153149def state_dict_from_actor_hidden (actor_hidden ):
150+ """Convert the actor hidden state into the named tensor mapping expected by LEAPP state APIs."""
154151 if actor_hidden is None :
155152 return {}
156153 if isinstance (actor_hidden , tuple ):
@@ -159,6 +156,7 @@ def state_dict_from_actor_hidden(actor_hidden):
159156
160157
161158def actor_hidden_from_registered (registered_state , original_hidden ):
159+ """Restore the registered LEAPP state to the hidden-state structure expected by the actor memory module."""
162160 if isinstance (original_hidden , tuple ):
163161 if isinstance (registered_state , tuple ):
164162 return registered_state
0 commit comments