1+ import os
2+ import numpy as np
3+ from tqdm import tqdm
4+ import utils .motion_modules as motion_modules
5+ from utils .bvh_motion import Motion
6+ import multiprocessing
7+
8+
9+ def calculate_metrics (args ):
10+ bvh_path , preset_path , metric_names , save_dic = args
11+ motion = Motion .load_bvh (bvh_path )
12+
13+ all_frame_idx = np .arange (motion .frame_num )
14+ if 'LeftCollar' in motion .names :
15+ forward_angle = motion_modules .extract_forward (motion , all_frame_idx , 'LeftCollar' , 'RightCollar' , 'LeftHip' , 'RightHip' )
16+ elif 'RightUpLeg' in motion .names :
17+ forward_angle = motion_modules .extract_forward (motion , all_frame_idx , 'LeftShoulder' , 'RightShoulder' , 'LeftUpLeg' , 'RightUpLeg' )
18+
19+ forward_angle = np .rad2deg (forward_angle )
20+ preset_traj , preset_orien = [], []
21+ with open (preset_path , 'r' ) as f :
22+ for line in f :
23+ values = line .strip ().split (',' )
24+ traj_x , traj_z = float (values [0 ]), float (values [2 ])
25+ traj_angle = np .arctan2 (traj_z , traj_x )
26+ dirc_x , dirc_z = float (values [3 ]), float (values [5 ])
27+ dirc_angle = np .arctan2 (dirc_z , - dirc_x )
28+ preset_traj .append (traj_angle )
29+ preset_orien .append (dirc_angle )
30+ preset_traj , preset_orien = np .rad2deg (np .array (preset_traj )), np .rad2deg (np .array (preset_orien ))
31+
32+ if motion .frame_num > len (preset_traj ):
33+ motion = motion_modules .temporal_scale (motion , 2 )
34+ forward_angle = forward_angle [::2 ]
35+
36+ preset_traj = preset_traj [:motion .frame_num - 1 ]
37+ preset_orien = preset_orien [:motion .frame_num ]
38+
39+ traj_pos = motion .positions [:, 0 , [0 , 2 ]]
40+ traj_pos [:, 0 ] *= - 1
41+ traj_pos_diff = np .diff (traj_pos , axis = 0 )
42+ traj_angle = np .rad2deg (np .arctan2 (traj_pos_diff [:, 1 ], traj_pos_diff [:, 0 ]))
43+
44+ metric_value = []
45+ for metric_name in metric_names :
46+ if metric_name == 'traj_error' :
47+ value = np .abs (np .mean (preset_traj - traj_angle ))
48+ # value = np.min([value, 90 - value, 180 - value])
49+ metric_value .append (np .abs (value ))
50+ elif metric_name == 'orien_error' :
51+ value = np .abs (np .mean (preset_orien - forward_angle ))
52+ # value = np.min([value, 90 - value, 180 - value])
53+ metric_value .append (np .abs (value ))
54+
55+ save_dic [bvh_path ] = metric_value
56+
57+
58+ if __name__ == '__main__' :
59+
60+ test_folders = [
61+ ["data/exp1_mann+dp" , "data/preset.txt" ],
62+ ["data/exp1_mann+lp" , "data/preset.txt" ],
63+ ["data/exp1_matching" , "data/preset.txt" ],
64+ ["data/exp1_moglow" , "data/preset.txt" ],
65+ ["data/exp1_ours" , "data/preset.txt" ],
66+ ["data/exp1_ours_mlp" , "data/preset.txt" ],
67+ ]
68+
69+ select_metric = ['traj_error' , 'orien_error' ]
70+ result_path = os .path .join ('./result_recording' , 'trajectory_alignment.txt' )
71+
72+ if not os .path .exists (result_path ):
73+ with open (result_path , 'w' ) as f :
74+ f .write ('Metrics: \t \t ' )
75+ for metric_name in select_metric :
76+ f .write ('%s\t \t ' % metric_name )
77+ f .write ('\n ' )
78+ f .write ('---------------------------------\n ' )
79+
80+ for test_folder , preset_path in test_folders :
81+ test_file_list = [os .path .join (test_folder , f ) for f in os .listdir (test_folder ) if f .endswith ('.bvh' )]
82+
83+ calculate_metrics ((test_file_list [0 ], preset_path , select_metric , {}))
84+
85+ metric_dic = multiprocessing .Manager ().dict ()
86+ args = [(test_file , preset_path , select_metric , metric_dic ) for test_file in test_file_list ]
87+
88+ num_processes = multiprocessing .cpu_count ()
89+ with multiprocessing .Pool (processes = num_processes ) as pool :
90+ pool .map (calculate_metrics , args )
91+
92+ metric_dic = dict (metric_dic )
93+
94+ avg_list = []
95+ for metric_name in select_metric :
96+ metric_value = []
97+ for test_file in test_file_list :
98+ metric_value .append (metric_dic [test_file ][select_metric .index (metric_name )])
99+ # remove nan in metric_value
100+ metric_value = [value for value in metric_value if not np .isnan (value )]
101+ avg_list .append (np .mean (metric_value ))
102+
103+ with open (result_path , 'a' ) as f :
104+ f .write (test_folder + '\t ' )
105+ for metric_name in select_metric :
106+ f .write ('%.4f\t \t ' % avg_list [select_metric .index (metric_name )])
107+ f .write ('\n ' )
108+
109+ print ('Finish %s, metrics: %s' % (test_folder , avg_list ))
110+
111+
112+
0 commit comments