@@ -122,7 +122,7 @@ def get_keypoint_names(self):
122122 def query_states (self , query : str ) -> ndarray :
123123 assert query in [
124124 "speed" ,
125- "acceleration " ,
125+ "acceleration_mag " ,
126126 "bodypart_pairwise_distance" ,
127127 ], f"{ query } is not supported"
128128
@@ -172,8 +172,9 @@ def get_acceleration_mag(self) -> ndarray:
172172 Returns the magnitude of the acceleration vector
173173 """
174174 accelerations = self .get_acceleration ()
175- acceleration_mag = np .linalg .norm (accelerations , axis = - 1 )
176- assert len (acceleration_mag .shape ) == 2
175+ acceleration_mag = np .linalg .norm (accelerations , axis = - 1 )
176+ acceleration_mag = np .expand_dims (acceleration_mag , axis = - 1 )
177+ assert len (acceleration_mag .shape ) == 3
177178 return acceleration_mag
178179
179180 def get_bodypart_wise_relation (self ):
@@ -230,5 +231,18 @@ def calc_head_cs(self):
230231 # unit testing the shape of kinematics data
231232 # acceleration, acceleration_mag, velocity, speed, and keypoints
232233
234+ from amadeusgpt .config import Config
235+ from amadeusgpt .main import AMADEUS
236+ config = Config ("/Users/shaokaiye/AmadeusGPT-dev/amadeusgpt/configs/MausHaus_template.yaml" )
237+ amadeus = AMADEUS (config )
238+ analysis = amadeus .get_analysis ()
239+ # get an instance of animal
240+ animal = analysis .animal_manager .get_animals ()[0 ]
233241
242+ print ("velocity shape" , animal .get_velocity ().shape )
243+ print ("speed shape" , animal .get_speed ().shape )
244+ print ("acceleration shape" , animal .get_acceleration ().shape )
245+ print ("acceleration_mag shape" , animal .get_acceleration_mag ().shape )
246+
247+ print (animal .query_states ("acceleration_mag" ).shape )
234248
0 commit comments