Skip to content

Commit 76e139a

Browse files
committed
fixed acceleration
1 parent cd1fd29 commit 76e139a

4 files changed

Lines changed: 21 additions & 6 deletions

File tree

amadeusgpt/analysis_objects/animal.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def get_keypoint_names(self):
122122
def query_states(self, query: str) -> ndarray:
123123
assert query in [
124124
"speed",
125-
"acceleration",
125+
"acceleration_mag",
126126
"bodypart_pairwise_distance",
127127
], f"{query} is not supported"
128128

@@ -172,8 +172,9 @@ def get_acceleration_mag(self) -> ndarray:
172172
Returns the magnitude of the acceleration vector
173173
"""
174174
accelerations = self.get_acceleration()
175-
acceleration_mag = np.linalg.norm(accelerations, axis=-1)
176-
assert len(acceleration_mag.shape) == 2
175+
acceleration_mag = np.linalg.norm(accelerations, axis=-1)
176+
acceleration_mag = np.expand_dims(acceleration_mag, axis=-1)
177+
assert len(acceleration_mag.shape) == 3
177178
return acceleration_mag
178179

179180
def get_bodypart_wise_relation(self):
@@ -230,5 +231,18 @@ def calc_head_cs(self):
230231
# unit testing the shape of kinematics data
231232
# acceleration, acceleration_mag, velocity, speed, and keypoints
232233

234+
from amadeusgpt.config import Config
235+
from amadeusgpt.main import AMADEUS
236+
config = Config("/Users/shaokaiye/AmadeusGPT-dev/amadeusgpt/configs/MausHaus_template.yaml")
237+
amadeus = AMADEUS(config)
238+
analysis = amadeus.get_analysis()
239+
# get an instance of animal
240+
animal = analysis.animal_manager.get_animals()[0]
233241

242+
print ("velocity shape", animal.get_velocity().shape)
243+
print ("speed shape", animal.get_speed().shape)
244+
print ("acceleration shape", animal.get_acceleration().shape)
245+
print ("acceleration_mag shape", animal.get_acceleration_mag().shape)
246+
247+
print(animal.query_states("acceleration_mag").shape)
234248

amadeusgpt/analysis_objects/relationship.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from scipy.spatial.distance import cdist
77

88
from .base import AnalysisObject
9-
from .object import AnimalSeq, Object
9+
from .object import Object
10+
from .animal import AnimalSeq
1011

1112

1213
class Orientation(IntEnum):

amadeusgpt/managers/animal_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pandas as pd
88
from numpy import ndarray
99

10-
from amadeusgpt.analysis_objects.object import AnimalSeq
10+
from amadeusgpt.analysis_objects.animal import AnimalSeq
1111
from amadeusgpt.programs.api_registry import (register_class_methods,
1212
register_core_api)
1313

amadeusgpt/managers/visual_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from matplotlib.patches import Wedge
99

1010
from amadeusgpt.analysis_objects.event import BaseEvent
11-
from amadeusgpt.analysis_objects.object import AnimalSeq
11+
from amadeusgpt.analysis_objects.animal import AnimalSeq
1212
from amadeusgpt.analysis_objects.visualization import (EventVisualization,
1313
GraphVisualization,
1414
KeypointVisualization,

0 commit comments

Comments
 (0)