117117
118118dataset_loader_dict = {'GoogleSpeechCommandsDataset' :GoogleSpeechCommandsDataset }
119119dataset_load_state = {'dataset' : None , 'dataset_test' : None , 'train_sampler' : None , 'test_sampler' : None }
120+ _float_best_metric = None # best float accuracy; set on float run, read on QAT run
120121
121122
122123def get_args_parser ():
@@ -237,10 +238,7 @@ def generate_golden_vectors(output_dir, dataset, output_int, generic_model=False
237238 generate_test_vector (output_dir , header_file_info )
238239 generate_model_aux (output_dir , dataset )
239240
240- def set_dataset_augmentation_enabled (dataset , enabled ):
241- if hasattr (dataset , "set_augmentation_enabled" ):
242- dataset .set_augmentation_enabled (enabled )
243-
241+
244242def main (gpu , args ):
245243 """Main training function for classification."""
246244 logger , device = setup_training_environment (args , gpu , 'classification' , __file__ )
@@ -312,7 +310,6 @@ def main(gpu, args):
312310
313311 move_model_to_device (model , device , logger )
314312 criterion = nn .CrossEntropyLoss (label_smoothing = args .label_smoothing )
315- # logger.info(f"args.transforms = {args.transforms}"
316313 model , model_without_ddp , model_ema = setup_distributed_model (model , args , device )
317314 optimizer , lr_scheduler = setup_optimizer_and_scheduler (model , args )
318315 resume_from_checkpoint (model_without_ddp , optimizer , lr_scheduler , model_ema , args )
@@ -367,21 +364,14 @@ def main(gpu, args):
367364
368365 for epoch in range (args .start_epoch , args .epochs ):
369366 if args .distributed :
370- train_sampler .set_epoch (epoch )
371-
372- set_dataset_augmentation_enabled (dataset , True )
373-
367+ train_sampler .set_epoch (epoch )
374368 utils .train_one_epoch_classification (
375369 model , criterion , optimizer , data_loader , device , epoch , None , args .apex , model_ema ,
376370 print_freq = args .print_freq , phase = phase , num_classes = num_classes , dual_op = args .dual_op ,
377371 is_ptq = True if (args .quantization_method in ['PTQ' ] and args .quantization ) else False ,
378372 nn_for_feature_extraction = args .nn_for_feature_extraction )
379-
380- set_dataset_augmentation_enabled (dataset , False )
381373 if not (args .quantization_method in ['PTQ' ] and args .quantization ):
382374 lr_scheduler .step ()
383- set_dataset_augmentation_enabled (dataset , False )
384- set_dataset_augmentation_enabled (dataset_test , False )
385375 avg_accuracy , avg_f1 , auc , avg_conf_matrix , predictions , ground_truth = utils .evaluate_classification (
386376 model , criterion , data_loader_test , device = device , transform = None , phase = phase ,
387377 num_classes = num_classes , dual_op = args .dual_op , nn_for_feature_extraction = args .nn_for_feature_extraction )
@@ -397,9 +387,11 @@ def main(gpu, args):
397387 checkpoint = save_checkpoint (model_without_ddp , optimizer , lr_scheduler , epoch , args , model_ema )
398388 utils .save_on_master (checkpoint , os .path .join (args .output_dir , 'checkpoint.pth' ))
399389
390+ if not args .quantization and args .auto_quantization :
391+ _float_best_metric = best ['accuracy' ] / 100.0
392+ logger .info (f"Stored float best accuracy for binary search: { _float_best_metric :.4f} " )
393+
400394 # Log best epoch results
401- set_dataset_augmentation_enabled (dataset , False )
402- set_dataset_augmentation_enabled (dataset_test , False )
403395 logger = getLogger (f"root.main.{ phase } .BestEpoch" )
404396 logger .info ("" )
405397 logger .info ("Printing statistics of best epoch:" )
@@ -438,8 +430,6 @@ def main(gpu, args):
438430
439431 if args .gen_golden_vectors :
440432
441- set_dataset_augmentation_enabled (dataset , False )
442- set_dataset_augmentation_enabled (dataset_test , False )
443433 generate_golden_vector_dir (args .output_dir )
444434 output_int = get_output_int_flag (args )
445435 generate_golden_vectors (args .output_dir , dataset , output_int , args .generic_model , args .nn_for_feature_extraction )
0 commit comments