Skip to content

Commit 6449862

Browse files
author
Longbin Tang
authored
Fix tensor shape check in MPPI controller to ensure proper batch size handling (#10)
1 parent f817390 commit 6449862

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

torchcontrol/controllers/mppi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def __init__(self, cfg: MPPICfg):
7575
# extend the parameters batch size if they are tensors
7676
if hasattr(rollout_plant_cfg, "params") and rollout_plant_cfg.params is not None:
7777
for k, v in rollout_plant_cfg.params.__dict__.items():
78-
if isinstance(v, torch.Tensor) and v.shape[0] == self.num_envs:
78+
if isinstance(v, torch.Tensor) and v.ndim > 0 and v.shape[0] == self.num_envs:
7979
setattr(rollout_plant_cfg.params, k, v.repeat_interleave(self.K, dim=0))
8080
self._rollout_plant = rollout_plant_cfg.class_type(rollout_plant_cfg) # Create a new plant instance for rollouts
8181

0 commit comments

Comments
 (0)