We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent f817390 commit 6449862Copy full SHA for 6449862
1 file changed
torchcontrol/controllers/mppi.py
@@ -75,7 +75,7 @@ def __init__(self, cfg: MPPICfg):
75
# extend the parameters batch size if they are tensors
76
if hasattr(rollout_plant_cfg, "params") and rollout_plant_cfg.params is not None:
77
for k, v in rollout_plant_cfg.params.__dict__.items():
78
- if isinstance(v, torch.Tensor) and v.shape[0] == self.num_envs:
+ if isinstance(v, torch.Tensor) and v.ndim > 0 and v.shape[0] == self.num_envs:
79
setattr(rollout_plant_cfg.params, k, v.repeat_interleave(self.K, dim=0))
80
self._rollout_plant = rollout_plant_cfg.class_type(rollout_plant_cfg) # Create a new plant instance for rollouts
81
0 commit comments