Skip to content

Commit 1183ff6

Browse files
committed
feat: add pose, twist and state dataclasses
1 parent 6d41d29 commit 1183ff6

1 file changed

Lines changed: 86 additions & 1 deletion

File tree

vortex_utils/python_utils.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,89 @@
11
import numpy as np
2+
from dataclasses import dataclass
3+
from scipy.spatial.transform import Rotation
24

35
def ssa(angle: float) -> float:
4-
return (angle + np.pi) % (2 * np.pi) - np.pi
6+
return (angle + np.pi) % (2 * np.pi) - np.pi
7+
8+
def euler_to_quat(roll: float, pitch: float, yaw: float) -> np.ndarray:
9+
"""
10+
Converts euler angles to quaternion (x, y, z, w)
11+
"""
12+
rotation_matrix = Rotation.from_euler("XYZ", [roll, pitch, yaw]).as_matrix()
13+
quaternion = Rotation.from_matrix(rotation_matrix).as_quat()
14+
return quaternion
15+
16+
def quat_to_euler(x: float, y: float, z: float, w: float) -> np.ndarray:
17+
"""
18+
Converts quaternion (x, y, z, w) to euler angles
19+
"""
20+
rotation_matrix = Rotation.from_quat([x, y, z, w]).as_matrix()
21+
euler_angles = Rotation.from_matrix(rotation_matrix).as_euler("ZYX")
22+
return euler_angles
23+
24+
@dataclass(slots=True)
25+
class Pose:
26+
x: float = 0.0
27+
y: float = 0.0
28+
z: float = 0.0
29+
roll: float = 0.0
30+
pitch: float = 0.0
31+
yaw: float = 0.0
32+
33+
def __add__(self, other: "Pose") -> "Pose":
34+
return Pose(**{field.name: getattr(self, field.name) + getattr(other, field.name) for field in self.__dataclass_fields__.values()})
35+
36+
def __sub__(self, other: "Pose") -> "Pose":
37+
return Pose(**{field.name: getattr(self, field.name) - getattr(other, field.name) for field in self.__dataclass_fields__.values()})
38+
39+
def __mul__(self, other: float) -> "Pose":
40+
if isinstance(other, Pose):
41+
return Pose(**{field.name: getattr(self, field.name) * getattr(other, field.name) for field in self.__dataclass_fields__.values()})
42+
return Pose(**{field.name: getattr(self, field.name) * other for field in self.__dataclass_fields__.values()})
43+
44+
def as_rotation_matrix(self) -> np.ndarray:
45+
euler_angles = [self.roll, self.pitch, self.yaw]
46+
rotation_matrix = Rotation.from_euler("ZYX", euler_angles).as_matrix()
47+
return rotation_matrix
48+
49+
@dataclass(slots=True)
50+
class Twist:
51+
linear_x: float = 0.0
52+
linear_y: float = 0.0
53+
linear_z: float = 0.0
54+
angular_x: float = 0.0
55+
angular_y: float = 0.0
56+
angular_z: float = 0.0
57+
58+
def __add__(self, other: "Twist") -> "Twist":
59+
return Twist(**{field.name: getattr(self, field.name) + getattr(other, field.name) for field in self.__dataclass_fields__.values()})
60+
61+
def __sub__(self, other: "Twist") -> "Twist":
62+
return Twist(**{field.name: getattr(self, field.name) - getattr(other, field.name) for field in self.__dataclass_fields__.values()})
63+
64+
def __mul__(self, other: float) -> "Twist":
65+
if isinstance(other, Twist):
66+
return Twist(**{field.name: getattr(self, field.name) * getattr(other, field.name) for field in self.__dataclass_fields__.values()})
67+
return Twist(**{field.name: getattr(self, field.name) * other for field in self.__dataclass_fields__.values()})
68+
69+
@dataclass(slots=True)
70+
class State:
71+
pose: Pose
72+
twist: Twist
73+
74+
def __add__(self, other: "State") -> "State":
75+
return State(pose=self.pose + other.pose, twist=self.twist + other.twist)
76+
77+
def __sub__(self, other: "State") -> "State":
78+
return State(pose=self.pose - other.pose, twist=self.twist - other.twist)
79+
80+
81+
82+
test_pose = Pose(x=1.0, y=2.0, z=3.0, roll=0.1, pitch=0.2, yaw=0.3)
83+
test_twist = Twist(linear_x=0.1, linear_y=0.2, linear_z=0.3, angular_x=0.01, angular_y=0.02, angular_z=0.03)
84+
test_pose2 = Pose(x=0.1, y=0.2, z=0.3, roll=0.01, pitch=0.02, yaw=0.03)
85+
test_twist2 = Twist(linear_x=0.01, linear_y=0.02, linear_z=0.03, angular_x=0.001, angular_y=0.002, angular_z=0.003)
86+
87+
euler = [1.0, 0.0, 0.0]
88+
quat = euler_to_quat(*euler)
89+
print(quat)

0 commit comments

Comments
 (0)