@@ -54,19 +54,16 @@ def __init__(self, cfg: MPPICfg):
5454 self .u_min = cfg .u_min .to (self .device )
5555 self .u_max = cfg .u_max .to (self .device )
5656 self .cost_function = cfg .cost_function # Callable for computing costs
57-
57+
5858 # Initialize nominal control sequence (mean), shape: (num_envs, T, action_dim)
5959 self .u_nominal = torch .zeros (self .num_envs , self .T , self .action_dim , device = self .device )
6060
61- # If sigma is a scalar, convert to tensor of shape (action_dim,)
62- if self .sigma .dim () == 0 :
63- self .sigma = torch .full ((self .action_dim ,), float (self .sigma ), device = self .device )
64-
65- # If u_min and u_max are scalars, convert to tensors of shape (action_dim,)
66- if self .u_min .dim () == 0 :
67- self .u_min = torch .full ((self .action_dim ,), float (self .u_min ), device = self .device )
68- if self .u_max .dim () == 0 :
69- self .u_max = torch .full ((self .action_dim ,), float (self .u_max ), device = self .device )
61+ # Convert self.sigma to tensor of shape (num_envs, action_dim)
62+ self .sigma = self ._expand_shape (self .sigma )
63+
64+ # Convert u_min and u_max to tensors of shape (num_envs, action_dim)
65+ self .u_min = self ._expand_shape (self .u_min )
66+ self .u_max = self ._expand_shape (self .u_max )
7067
7168 # For rollout simulation, create a new plant with batch_size = num_envs * K using plant's cfg
7269 rollout_plant_cfg : PlantCfg = copy .deepcopy (self .plant .cfg ) # Use deepcopy to avoid modifying the original cfg
@@ -115,8 +112,8 @@ def forward(self, current_state: torch.Tensor, reference: torch.Tensor | None =
115112
116113 # Apply control limits
117114 u_perturbed = torch .clamp (u_perturbed ,
118- self .u_min .view (1 , 1 , 1 , - 1 ),
119- self .u_max .view (1 , 1 , 1 , - 1 ))
115+ self .u_min .view (self . num_envs , 1 , 1 , self . action_dim ),
116+ self .u_max .view (self . num_envs , 1 , 1 , self . action_dim ))
120117
121118 # 3. Simulate rollouts and 4. Compute costs
122119 total_costs = self ._compute_rollout_cost (current_state , u_perturbed , reference ) # Shape: (num_envs, K)
@@ -132,8 +129,7 @@ def forward(self, current_state: torch.Tensor, reference: torch.Tensor | None =
132129 action = self .u_nominal [:, 0 , :].clone () # Shape: (num_envs, action_dim)
133130
134131 # 8. Receding horizon: shift nominal control sequence for next time step
135- self .u_nominal [:, :- 1 , :] = self .u_nominal [:, 1 :, :].clone ()
136- self .u_nominal [:, - 1 , :] = 0 # Re-initialize the last action
132+ self .u_nominal = torch .roll (self .u_nominal , shifts = - 1 , dims = 1 ) # Shift left by 1 time step
137133
138134 return action
139135
@@ -151,10 +147,11 @@ def update(self, *args, **kwargs):
151147 val = kwargs [key ]
152148 # For tensors, check shape if possible
153149 if key in ['sigma' , 'u_min' , 'u_max' ]:
150+ shape = (self .num_envs , self .action_dim )
154151 val = torch .as_tensor (val , dtype = torch .float32 , device = self .device )
155152 if val .dim () == 0 :
156- val = torch .full (( self . action_dim ,) , float (val ), device = self .device )
157- assert val .shape == ( self . action_dim ,) , \
153+ val = torch .full (shape , float (val ), device = self .device )
154+ assert val .shape == shape , \
158155 f"Shape mismatch for { key } : { getattr (self , key ).shape } != { val .shape } "
159156 setattr (self , key , val )
160157 elif key == 'alpha' :
@@ -176,11 +173,11 @@ def reset(self, env_ids: list[int] | None = None):
176173 """
177174 # Resetting the plant is handled by the ControllerBase's reset method if called.
178175 super ().reset (env_ids ) # Call base class reset, which handles plant.reset if plant is set.
179-
176+
180177 # This reset focuses on MPPI's internal state (u_nominal).
181178 if env_ids is None or len (env_ids ) == self .num_envs :
182179 env_ids = self ._ALL_INDICES # Reset all environments
183-
180+
184181 self .u_nominal [env_ids ] = torch .zeros (len (env_ids ), self .T , self .action_dim , device = self .device )
185182
186183 def _sample_control_noise (self ) -> torch .Tensor :
@@ -195,9 +192,9 @@ def _sample_control_noise(self) -> torch.Tensor:
195192 """
196193 # Sample noise: (num_envs, K, T, action_dim)
197194 # Each environment gets K independent noise sequences for its T-step horizon.
198- # self.sigma is (action_dim, ), needs to be broadcast.
195+ # self.sigma is (num_envs, action_dim ), needs to be broadcast.
199196 noise = torch .randn (self .num_envs , self .K , self .T , self .action_dim , device = self .device ) * \
200- self .sigma .view (1 , 1 , 1 , - 1 ) # Broadcast sigma
197+ self .sigma .view (self . num_envs , 1 , 1 , self . action_dim )
201198 return noise
202199
203200 def _compute_rollout_cost (self , current_state : torch .Tensor , u_perturbed : torch .Tensor , reference : torch .Tensor = None ) -> torch .Tensor :
@@ -239,7 +236,8 @@ def _compute_rollout_cost(self, current_state: torch.Tensor, u_perturbed: torch.
239236
240237 # Simulate the rollouts using the perturbed control sequences
241238 state_rollouts = plant .rollout (u = u_for_sim ) # Shape: (num_envs * K, T, state_dim)
242- state_rollouts = state_rollouts .reshape (num_envs , K , T , state_dim ) # Reshape back to (num_envs, K, T, state_dim)
239+ state_rollouts = state_rollouts .reshape (num_envs , K , T , state_dim ) # Reshape back to (num_envs, K, T, state_dim)
240+ self .state_rollouts = state_rollouts # Store for potential future use
243241
244242 # Call the cost function (defined in mppi_cfg.py)
245243 # Expected input shapes:
@@ -275,3 +273,28 @@ def _compute_weights(self, total_costs: torch.Tensor) -> torch.Tensor:
275273 sum_exp = torch .sum (exp_terms , dim = 1 , keepdim = True ) # Shape: (num_envs, 1)
276274 weights = exp_terms / (sum_exp + 1e-9 ) # Add epsilon for numerical stability
277275 return weights
276+
def _expand_shape(self, tensor: torch.Tensor) -> torch.Tensor:
    """
    Broadcast a scalar or per-action tensor to shape (num_envs, action_dim).

    Accepted inputs:
        * 0-D scalar: the value is used for every environment and action.
        * 1-D tensor of shape (action_dim,): the row is repeated per environment.
        * 2-D tensor of shape (num_envs, action_dim): returned as-is (no copy).

    Args:
        tensor (torch.Tensor): The input tensor to expand.
    Returns:
        torch.Tensor: Tensor of shape (num_envs, action_dim). For 0-D and 1-D
            inputs this is a newly allocated, contiguous tensor that does not
            share memory with ``tensor``.
    Raises:
        ValueError: If the input is 1-D with length != action_dim, 2-D with a
            shape other than (num_envs, action_dim), or has more than 2 dims.
    """
    target_shape = (self.num_envs, self.action_dim)
    n_dims = tensor.dim()

    if n_dims == 2:
        # Already per-environment: validate and pass through untouched.
        if tuple(tensor.shape) != target_shape:
            raise ValueError(
                f"Expected tensor shape ({self.num_envs}, {self.action_dim}), got {tensor.shape}"
            )
        return tensor

    if n_dims == 1:
        # Explicit check (not `assert`) so validation survives `python -O`.
        if tensor.shape[0] != self.action_dim:
            raise ValueError(
                f"Expected tensor shape ({self.action_dim},), got {tensor.shape}"
            )
    elif n_dims != 0:
        raise ValueError(f"Unsupported tensor shape: {tensor.shape}")

    # `expand` alone yields a stride-0 view in which many elements alias one
    # memory location; in-place writes to such a view fail and it shares storage
    # with the input. Materialize a real (num_envs, action_dim) tensor instead.
    return tensor.expand(target_shape).clone(memory_format=torch.contiguous_format)
0 commit comments