Skip to content

Commit 6c2d34f

Browse files
committed
[BugFix] Casting cached vals to device after device change
1 parent 2b061e8 commit 6c2d34f

File tree

3 files changed

+16
-10
lines changed

3 files changed

+16
-10
lines changed

vmas/scenarios/debug/circle_trajectory.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import torch
77
from torch import Tensor
8-
98
from vmas import render_interactively
109
from vmas.simulator.core import Agent, Sphere, World
1110
from vmas.simulator.scenario import BaseScenario
@@ -127,13 +126,21 @@ def get_tangent_to_circle(self, agent: Agent, closest_point=None):
127126
torch.linalg.vector_norm(agent.state.pos, dim=1) < self.desired_radius,
128127
)
129128

130-
rotated_vector = TorchUtils.rotate_vector(
131-
distance_to_circle, torch.tensor(torch.pi / 2, device=self.world.device)
129+
angle_90 = torch.tensor(torch.pi / 2, device=self.world.device).expand(
130+
self.world.batch_dim
131+
)
132+
133+
rotated_vector_90 = TorchUtils.rotate_vector(
134+
distance_to_circle,
135+
angle_90,
132136
)
133-
rotated_vector[inside_circle] = TorchUtils.rotate_vector(
134-
distance_to_circle[inside_circle],
135-
torch.tensor(-torch.pi / 2, device=self.world.device),
137+
rotated_vector_neg_90 = TorchUtils.rotate_vector(
138+
distance_to_circle,
139+
-angle_90,
136140
)
141+
rotated_vector = rotated_vector_90
142+
rotated_vector[inside_circle] = rotated_vector_neg_90[inside_circle]
143+
137144
angle = rotated_vector / torch.linalg.vector_norm(
138145
rotated_vector, dim=1
139146
).unsqueeze(-1)

vmas/simulator/core.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
import torch
1313
from torch import Tensor
14-
1514
from vmas.simulator.joints import Joint
1615
from vmas.simulator.physics import (
1716
_get_closest_point_line,
@@ -1047,9 +1046,6 @@ def __init__(
10471046
]
10481047
# Map to save entity indexes
10491048
self.entity_index_map = {}
1050-
self._normal_vector = torch.tensor(
1051-
[1.0, 0.0], dtype=torch.float32, device=self.device
1052-
).repeat(self._batch_dim, 1)
10531049

10541050
def add_agent(self, agent: Agent):
10551051
"""Only way to add agents to the world"""

vmas/simulator/joints.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,9 @@ def _delta_anchor_tensor(self, entity):
158158
.expand(entity.state.pos.shape)
159159
)
160160
self._delta_anchor_tensor_map[entity] = delta_anchor_tensor
161+
self._delta_anchor_tensor_map[entity] = self._delta_anchor_tensor_map[
162+
entity
163+
].to(entity.state.pos.device)
161164
return self._delta_anchor_tensor_map[entity]
162165

163166
def get_delta_anchor(self, entity: vmas.simulator.core.Entity):

0 commit comments

Comments
 (0)