11import numpy as np
2+ from dataclasses import dataclass
3+ from scipy .spatial .transform import Rotation
24
35def ssa (angle : float ) -> float :
4- return (angle + np .pi ) % (2 * np .pi ) - np .pi
6+ return (angle + np .pi ) % (2 * np .pi ) - np .pi
7+
8+ def euler_to_quat (roll : float , pitch : float , yaw : float ) -> np .ndarray :
9+ """
10+ Converts euler angles to quaternion (x, y, z, w)
11+ """
12+ rotation_matrix = Rotation .from_euler ("XYZ" , [roll , pitch , yaw ]).as_matrix ()
13+ quaternion = Rotation .from_matrix (rotation_matrix ).as_quat ()
14+ return quaternion
15+
16+ def quat_to_euler (x : float , y : float , z : float , w : float ) -> np .ndarray :
17+ """
18+ Converts quaternion (x, y, z, w) to euler angles
19+ """
20+ rotation_matrix = Rotation .from_quat ([x , y , z , w ]).as_matrix ()
21+ euler_angles = Rotation .from_matrix (rotation_matrix ).as_euler ("ZYX" )
22+ return euler_angles
23+
24+ @dataclass (slots = True )
25+ class Pose :
26+ x : float = 0.0
27+ y : float = 0.0
28+ z : float = 0.0
29+ roll : float = 0.0
30+ pitch : float = 0.0
31+ yaw : float = 0.0
32+
33+ def __add__ (self , other : "Pose" ) -> "Pose" :
34+ return Pose (** {field .name : getattr (self , field .name ) + getattr (other , field .name ) for field in self .__dataclass_fields__ .values ()})
35+
36+ def __sub__ (self , other : "Pose" ) -> "Pose" :
37+ return Pose (** {field .name : getattr (self , field .name ) - getattr (other , field .name ) for field in self .__dataclass_fields__ .values ()})
38+
39+ def __mul__ (self , other : float ) -> "Pose" :
40+ if isinstance (other , Pose ):
41+ return Pose (** {field .name : getattr (self , field .name ) * getattr (other , field .name ) for field in self .__dataclass_fields__ .values ()})
42+ return Pose (** {field .name : getattr (self , field .name ) * other for field in self .__dataclass_fields__ .values ()})
43+
44+ def as_rotation_matrix (self ) -> np .ndarray :
45+ euler_angles = [self .roll , self .pitch , self .yaw ]
46+ rotation_matrix = Rotation .from_euler ("ZYX" , euler_angles ).as_matrix ()
47+ return rotation_matrix
48+
49+ @dataclass (slots = True )
50+ class Twist :
51+ linear_x : float = 0.0
52+ linear_y : float = 0.0
53+ linear_z : float = 0.0
54+ angular_x : float = 0.0
55+ angular_y : float = 0.0
56+ angular_z : float = 0.0
57+
58+ def __add__ (self , other : "Twist" ) -> "Twist" :
59+ return Twist (** {field .name : getattr (self , field .name ) + getattr (other , field .name ) for field in self .__dataclass_fields__ .values ()})
60+
61+ def __sub__ (self , other : "Twist" ) -> "Twist" :
62+ return Twist (** {field .name : getattr (self , field .name ) - getattr (other , field .name ) for field in self .__dataclass_fields__ .values ()})
63+
64+ def __mul__ (self , other : float ) -> "Twist" :
65+ if isinstance (other , Twist ):
66+ return Twist (** {field .name : getattr (self , field .name ) * getattr (other , field .name ) for field in self .__dataclass_fields__ .values ()})
67+ return Twist (** {field .name : getattr (self , field .name ) * other for field in self .__dataclass_fields__ .values ()})
68+
69+ @dataclass (slots = True )
70+ class State :
71+ pose : Pose
72+ twist : Twist
73+
74+ def __add__ (self , other : "State" ) -> "State" :
75+ return State (pose = self .pose + other .pose , twist = self .twist + other .twist )
76+
77+ def __sub__ (self , other : "State" ) -> "State" :
78+ return State (pose = self .pose - other .pose , twist = self .twist - other .twist )
79+
80+
81+
82+ test_pose = Pose (x = 1.0 , y = 2.0 , z = 3.0 , roll = 0.1 , pitch = 0.2 , yaw = 0.3 )
83+ test_twist = Twist (linear_x = 0.1 , linear_y = 0.2 , linear_z = 0.3 , angular_x = 0.01 , angular_y = 0.02 , angular_z = 0.03 )
84+ test_pose2 = Pose (x = 0.1 , y = 0.2 , z = 0.3 , roll = 0.01 , pitch = 0.02 , yaw = 0.03 )
85+ test_twist2 = Twist (linear_x = 0.01 , linear_y = 0.02 , linear_z = 0.03 , angular_x = 0.001 , angular_y = 0.002 , angular_z = 0.003 )
86+
87+ euler = [1.0 , 0.0 , 0.0 ]
88+ quat = euler_to_quat (* euler )
89+ print (quat )
0 commit comments