88import json
99import itertools
1010
11+ import yaml
1112import pandas
1213import numpy
1314import structlog
@@ -285,76 +286,23 @@ def export_model(path, out):
285286 cmodel .save (name = 'harmodel' , format = 'csv' , file = out )
286287
287288
289+ def load_config (file_path ):
290+
291+ with open (file_path , 'r' ) as f :
292+ data = yaml .safe_load (f )
293+ return data
294+
288295def run_pipeline (run , hyperparameters , dataset ,
296+ config ,
289297 data_dir ,
290298 out_dir ,
291299 model_settings = dict (),
292300 n_splits = 5 ,
293301 features = 'timebased' ,
294302 ):
295303
296- dataset_config = {
297- 'uci_har' : dict (
298- groups = ['subject' , 'experiment' ],
299- data_columns = ['acc_x' , 'acc_y' , 'acc_z' ],
300- classes = [
301- #'STAND_TO_LIE',
302- #'SIT_TO_LIE',
303- #'LIE_TO_SIT',
304- #'STAND_TO_SIT',
305- #'LIE_TO_STAND',
306- #'SIT_TO_STAND',
307- 'STANDING' , 'LAYING' , 'SITTING' ,
308- 'WALKING' , 'WALKING_UPSTAIRS' , 'WALKING_DOWNSTAIRS' ,
309- ],
310- ),
311- 'pamap2' : dict (
312- groups = ['subject' ],
313- data_columns = ['hand_acceleration_16g_x' , 'hand_acceleration_16g_y' , 'hand_acceleration_16g_z' ],
314- classes = [
315- #'transient',
316- 'walking' , 'ironing' , 'lying' , 'standing' ,
317- 'Nordic_walking' , 'sitting' , 'vacuum_cleaning' ,
318- 'cycling' , 'ascending_stairs' , 'descending_stairs' ,
319- 'running' , 'rope_jumping' ,
320- ],
321- ),
322- 'har_exercise_1' : dict (
323- groups = ['file' ],
324- data_columns = ['x' , 'y' , 'z' ],
325- classes = [
326- #'mixed',
327- 'squat' , 'jumpingjack' , 'lunge' , 'other' ,
328- ],
329- ),
330- 'toothbrush_hussain2021' : dict (
331- groups = ['subject' ],
332- label_column = 'is_brushing' ,
333- time_column = 'elapsed' ,
334- data_columns = ['acc_x' , 'acc_y' , 'acc_z' ],
335- #data_columns = ['gravity_x', 'gravity_y', 'gravity_z'],
336- #data_columns = ['motion_x', 'motion_y', 'motion_z'],
337- classes = [
338- #'mixed',
339- 'True' , 'False' ,
340- ],
341- ),
342- 'toothbrush_jonnor' : dict (
343- groups = ['session' ],
344- label_column = 'is_brushing' ,
345- time_column = 'time' ,
346- data_columns = ['x' , 'y' , 'z' ],
347- #data_columns = ['gravity_x', 'gravity_y', 'gravity_z'],
348- #data_columns = ['motion_x', 'motion_y', 'motion_z'],
349- classes = [
350- #'mixed',
351- 'True' , 'False' ,
352- ],
353- ),
354- }
304+ dataset_config = load_config (config )
355305
356- if not dataset in dataset_config .keys ():
357- raise ValueError (f"Unknown dataset { dataset } " )
358306
359307 if not os .path .exists (out_dir ):
360308 os .makedirs (out_dir )
@@ -368,12 +316,12 @@ def run_pipeline(run, hyperparameters, dataset,
368316 #print(data.index.names)
369317 #print(data.columns)
370318
371- groups = dataset_config [dataset ][ 'groups' ]
372- data_columns = dataset_config [dataset ][ 'data_columns' ]
373- enabled_classes = dataset_config [dataset ][ 'classes' ]
374- label_column = dataset_config [ dataset ] .get ('label_column' , 'activity' )
375- time_column = dataset_config [ dataset ] .get ('time_column' , 'time' )
376- sensitivity = dataset_config [ dataset ] .get ('sensitivity' , 4.0 )
319+ groups = dataset_config ['groups' ]
320+ data_columns = dataset_config ['data_columns' ]
321+ enabled_classes = dataset_config ['classes' ]
322+ label_column = dataset_config .get ('label_column' , 'activity' )
323+ time_column = dataset_config .get ('time_column' , 'time' )
324+ sensitivity = dataset_config .get ('sensitivity' , 4.0 )
377325
378326 data [label_column ] = data [label_column ].astype (str )
379327
@@ -486,6 +434,8 @@ def parse():
486434
487435 parser .add_argument ('--dataset' , type = str , default = 'uci_har' ,
488436 help = 'Which dataset to use' )
437+ parser .add_argument ('--config' , type = str , default = 'data/configurations/uci_har.yaml' ,
438+ help = 'Which dataset/training config to use' )
489439 parser .add_argument ('--data-dir' , metavar = 'DIRECTORY' , type = str , default = './data/processed' ,
490440 help = 'Where the input data is stored' )
491441 parser .add_argument ('--out-dir' , metavar = 'DIRECTORY' , type = str , default = './' ,
@@ -506,9 +456,6 @@ def parse():
506456def main ():
507457
508458 args = parse ()
509- dataset = args .dataset
510- out_dir = args .out_dir
511- data_dir = args .data_dir
512459
513460 run_id = uuid .uuid4 ().hex .upper ()[0 :6 ]
514461
@@ -524,6 +471,7 @@ def main():
524471 }
525472
526473 results = run_pipeline (dataset = args .dataset ,
474+ config = args .config ,
527475 out_dir = args .out_dir ,
528476 data_dir = args .data_dir ,
529477 run = run_id ,
0 commit comments