Skip to content

Commit 0d5cdfa

Browse files
author
Longbin Tang
authored
Bump version to 0.1.5 and refactor MPPI controller for tensor shape handling (#13)
1 parent a5d7f88 commit 0d5cdfa

3 files changed

Lines changed: 46 additions & 23 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "torchcontrol"
7-
version = "0.1.4"
7+
version = "0.1.5"
88
description = "A parallel control system simulation and control library based on PyTorch."
99
authors = [
1010
{ name = "Tang Longbin" }

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name="torchcontrol",
5-
version="0.1.4",
5+
version="0.1.5",
66
description="A parallel control system simulation and control library based on PyTorch.",
77
author="Tang Longbin",
88
packages=find_packages(),

torchcontrol/controllers/mppi.py

Lines changed: 44 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -54,19 +54,16 @@ def __init__(self, cfg: MPPICfg):
5454
self.u_min = cfg.u_min.to(self.device)
5555
self.u_max = cfg.u_max.to(self.device)
5656
self.cost_function = cfg.cost_function # Callable for computing costs
57-
57+
5858
# Initialize nominal control sequence (mean), shape: (num_envs, T, action_dim)
5959
self.u_nominal = torch.zeros(self.num_envs, self.T, self.action_dim, device=self.device)
6060

61-
# If sigma is a scalar, convert to tensor of shape (action_dim,)
62-
if self.sigma.dim() == 0:
63-
self.sigma = torch.full((self.action_dim,), float(self.sigma), device=self.device)
64-
65-
# If u_min and u_max are scalars, convert to tensors of shape (action_dim,)
66-
if self.u_min.dim() == 0:
67-
self.u_min = torch.full((self.action_dim,), float(self.u_min), device=self.device)
68-
if self.u_max.dim() == 0:
69-
self.u_max = torch.full((self.action_dim,), float(self.u_max), device=self.device)
61+
# Convert self.sigma to tensor of shape (num_envs, action_dim)
62+
self.sigma = self._expand_shape(self.sigma)
63+
64+
# Convert u_min and u_max to tensors of shape (num_envs, action_dim)
65+
self.u_min = self._expand_shape(self.u_min)
66+
self.u_max = self._expand_shape(self.u_max)
7067

7168
# For rollout simulation, create a new plant with batch_size = num_envs * K using plant's cfg
7269
rollout_plant_cfg: PlantCfg = copy.deepcopy(self.plant.cfg) # Use deepcopy to avoid modifying the original cfg
@@ -115,8 +112,8 @@ def forward(self, current_state: torch.Tensor, reference: torch.Tensor | None =
115112

116113
# Apply control limits
117114
u_perturbed = torch.clamp(u_perturbed,
118-
self.u_min.view(1, 1, 1, -1),
119-
self.u_max.view(1, 1, 1, -1))
115+
self.u_min.view(self.num_envs, 1, 1, self.action_dim),
116+
self.u_max.view(self.num_envs, 1, 1, self.action_dim))
120117

121118
# 3. Simulate rollouts and 4. Compute costs
122119
total_costs = self._compute_rollout_cost(current_state, u_perturbed, reference) # Shape: (num_envs, K)
@@ -132,8 +129,7 @@ def forward(self, current_state: torch.Tensor, reference: torch.Tensor | None =
132129
action = self.u_nominal[:, 0, :].clone() # Shape: (num_envs, action_dim)
133130

134131
# 8. Receding horizon: shift nominal control sequence for next time step
135-
self.u_nominal[:, :-1, :] = self.u_nominal[:, 1:, :].clone()
136-
self.u_nominal[:, -1, :] = 0 # Re-initialize the last action
132+
self.u_nominal = torch.roll(self.u_nominal, shifts=-1, dims=1) # Shift left by 1 time step
137133

138134
return action
139135

@@ -151,10 +147,11 @@ def update(self, *args, **kwargs):
151147
val = kwargs[key]
152148
# For tensors, check shape if possible
153149
if key in ['sigma', 'u_min', 'u_max']:
150+
shape = (self.num_envs, self.action_dim)
154151
val = torch.as_tensor(val, dtype=torch.float32, device=self.device)
155152
if val.dim() == 0:
156-
val = torch.full((self.action_dim,), float(val), device=self.device)
157-
assert val.shape == (self.action_dim,), \
153+
val = torch.full(shape, float(val), device=self.device)
154+
assert val.shape == shape, \
158155
f"Shape mismatch for {key}: {getattr(self, key).shape} != {val.shape}"
159156
setattr(self, key, val)
160157
elif key == 'alpha':
@@ -176,11 +173,11 @@ def reset(self, env_ids: list[int] | None = None):
176173
"""
177174
# Resetting the plant is handled by the ControllerBase's reset method if called.
178175
super().reset(env_ids) # Call base class reset, which handles plant.reset if plant is set.
179-
176+
180177
# This reset focuses on MPPI's internal state (u_nominal).
181178
if env_ids is None or len(env_ids) == self.num_envs:
182179
env_ids = self._ALL_INDICES # Reset all environments
183-
180+
184181
self.u_nominal[env_ids] = torch.zeros(len(env_ids), self.T, self.action_dim, device=self.device)
185182

186183
def _sample_control_noise(self) -> torch.Tensor:
@@ -195,9 +192,9 @@ def _sample_control_noise(self) -> torch.Tensor:
195192
"""
196193
# Sample noise: (num_envs, K, T, action_dim)
197194
# Each environment gets K independent noise sequences for its T-step horizon.
198-
# self.sigma is (action_dim,), needs to be broadcast.
195+
# self.sigma is (num_envs, action_dim), needs to be broadcast.
199196
noise = torch.randn(self.num_envs, self.K, self.T, self.action_dim, device=self.device) * \
200-
self.sigma.view(1, 1, 1, -1) # Broadcast sigma
197+
self.sigma.view(self.num_envs, 1, 1, self.action_dim)
201198
return noise
202199

203200
def _compute_rollout_cost(self, current_state: torch.Tensor, u_perturbed: torch.Tensor, reference: torch.Tensor = None) -> torch.Tensor:
@@ -239,7 +236,8 @@ def _compute_rollout_cost(self, current_state: torch.Tensor, u_perturbed: torch.
239236

240237
# Simulate the rollouts using the perturbed control sequences
241238
state_rollouts = plant.rollout(u=u_for_sim) # Shape: (num_envs * K, T, state_dim)
242-
state_rollouts = state_rollouts.reshape(num_envs, K, T, state_dim) # Reshape back to (num_envs, K, T, state_dim)
239+
state_rollouts = state_rollouts.reshape(num_envs, K, T, state_dim) # Reshape back to (num_envs, K, T, state_dim)
240+
self.state_rollouts = state_rollouts # Store for potential future use
243241

244242
# Call the cost function (defined in mppi_cfg.py)
245243
# Expected input shapes:
@@ -275,3 +273,28 @@ def _compute_weights(self, total_costs: torch.Tensor) -> torch.Tensor:
275273
sum_exp = torch.sum(exp_terms, dim=1, keepdim=True) # Shape: (num_envs, 1)
276274
weights = exp_terms / (sum_exp + 1e-9) # Add epsilon for numerical stability
277275
return weights
276+
277+
def _expand_shape(self, tensor: torch.Tensor) -> torch.Tensor:
    """
    Expand a tensor to the shape (num_envs, action_dim).

    Accepted inputs:
        - 0-D scalar: the value is repeated for every environment and action dimension.
        - 1-D tensor of shape (action_dim,): the row is repeated for every environment.
        - 2-D tensor of shape (num_envs, action_dim): returned as-is (same object).

    Args:
        tensor (torch.Tensor): The input tensor to expand.
    Returns:
        torch.Tensor: Tensor of shape (num_envs, action_dim). For 0-D and 1-D inputs this is
            an independent, contiguous copy that does not share memory with the input.
    Raises:
        ValueError: If the tensor has more than 2 dimensions, or if a 1-D / 2-D tensor does
            not match (action_dim,) / (num_envs, action_dim).
    """
    n_dims = tensor.dim()

    if n_dims == 0:
        # expand() yields a stride-0 view in which every element aliases one memory
        # location; clone() materializes it so later in-place writes are safe.
        return tensor.expand(self.num_envs, self.action_dim).clone()

    if n_dims == 1:
        # Explicit raise instead of assert: asserts are stripped under `python -O`.
        if tensor.shape[0] != self.action_dim:
            raise ValueError(f"Expected tensor shape ({self.action_dim},), got {tensor.shape}")
        # Same aliasing concern as the scalar case: copy the broadcast view.
        return tensor.unsqueeze(0).expand(self.num_envs, -1).clone()

    if n_dims == 2:
        if tuple(tensor.shape) != (self.num_envs, self.action_dim):
            raise ValueError(
                f"Expected tensor shape ({self.num_envs}, {self.action_dim}), got {tensor.shape}"
            )
        return tensor

    raise ValueError(f"Unsupported tensor shape: {tensor.shape}")

0 commit comments

Comments
 (0)