Skip to content

Commit 4fbd495

Browse files
Merge commit '06531e0c034be730c06dd09f9b306c7dc75e6517'
2 parents 7efceae + 06531e0 commit 4fbd495

4 files changed

Lines changed: 15 additions & 22 deletions

File tree

tinyml-tinyverse/tinyml_tinyverse/common/datasets/audio_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def _extract_lpc(self, raw_audio):
337337
frame_step=frame_step,
338338
)
339339

340-
lpc_features = self._compute_lpc_features(frames)
340+
lpc_features = self._compute_lpc_features(frames) #TODO implement LPC
341341

342342
return torch.tensor(lpc_features, dtype=torch.float32).unsqueeze(0)
343343

tinyml-tinyverse/tinyml_tinyverse/common/datasets/image_dataset.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
from torch.utils.data import Dataset
1212
from torchvision import transforms
1313

14-
from ast import literal_eval
15-
1614
def _normalize_transform_list(value):
1715
if value is None:
1816
return []
@@ -38,7 +36,7 @@ def _normalize_transform_list(value):
3836
return list(value)
3937

4038
return [value]
41-
39+
4240
def _to_bool(val):
4341
if isinstance(val, bool):
4442
return val

tinyml-tinyverse/tinyml_tinyverse/references/audio_classification/train.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@
117117

118118
dataset_loader_dict = {'GoogleSpeechCommandsDataset':GoogleSpeechCommandsDataset}
119119
dataset_load_state = {'dataset': None, 'dataset_test': None, 'train_sampler': None, 'test_sampler': None}
120+
_float_best_metric = None # best float accuracy; set on float run, read on QAT run
120121

121122

122123
def get_args_parser():
@@ -237,10 +238,7 @@ def generate_golden_vectors(output_dir, dataset, output_int, generic_model=False
237238
generate_test_vector(output_dir, header_file_info)
238239
generate_model_aux(output_dir, dataset)
239240

240-
def set_dataset_augmentation_enabled(dataset, enabled):
241-
if hasattr(dataset, "set_augmentation_enabled"):
242-
dataset.set_augmentation_enabled(enabled)
243-
241+
244242
def main(gpu, args):
245243
"""Main training function for classification."""
246244
logger, device = setup_training_environment(args, gpu, 'classification', __file__)
@@ -312,7 +310,6 @@ def main(gpu, args):
312310

313311
move_model_to_device(model, device, logger)
314312
criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
315-
# logger.info(f"args.transforms = {args.transforms}"
316313
model, model_without_ddp, model_ema = setup_distributed_model(model, args, device)
317314
optimizer, lr_scheduler = setup_optimizer_and_scheduler(model, args)
318315
resume_from_checkpoint(model_without_ddp, optimizer, lr_scheduler, model_ema, args)
@@ -367,21 +364,14 @@ def main(gpu, args):
367364

368365
for epoch in range(args.start_epoch, args.epochs):
369366
if args.distributed:
370-
train_sampler.set_epoch(epoch)
371-
372-
set_dataset_augmentation_enabled(dataset, True)
373-
367+
train_sampler.set_epoch(epoch)
374368
utils.train_one_epoch_classification(
375369
model, criterion, optimizer, data_loader, device, epoch, None, args.apex, model_ema,
376370
print_freq=args.print_freq, phase=phase, num_classes=num_classes, dual_op=args.dual_op,
377371
is_ptq=True if (args.quantization_method in ['PTQ'] and args.quantization) else False,
378372
nn_for_feature_extraction=args.nn_for_feature_extraction)
379-
380-
set_dataset_augmentation_enabled(dataset, False)
381373
if not (args.quantization_method in ['PTQ'] and args.quantization):
382374
lr_scheduler.step()
383-
set_dataset_augmentation_enabled(dataset, False)
384-
set_dataset_augmentation_enabled(dataset_test, False)
385375
avg_accuracy, avg_f1, auc, avg_conf_matrix, predictions, ground_truth = utils.evaluate_classification(
386376
model, criterion, data_loader_test, device=device, transform=None, phase=phase,
387377
num_classes=num_classes, dual_op=args.dual_op, nn_for_feature_extraction=args.nn_for_feature_extraction)
@@ -397,9 +387,11 @@ def main(gpu, args):
397387
checkpoint = save_checkpoint(model_without_ddp, optimizer, lr_scheduler, epoch, args, model_ema)
398388
utils.save_on_master(checkpoint, os.path.join(args.output_dir, 'checkpoint.pth'))
399389

390+
if not args.quantization and args.auto_quantization:
391+
_float_best_metric = best['accuracy'] / 100.0
392+
logger.info(f"Stored float best accuracy for binary search: {_float_best_metric:.4f}")
393+
400394
# Log best epoch results
401-
set_dataset_augmentation_enabled(dataset, False)
402-
set_dataset_augmentation_enabled(dataset_test, False)
403395
logger = getLogger(f"root.main.{phase}.BestEpoch")
404396
logger.info("")
405397
logger.info("Printing statistics of best epoch:")
@@ -438,8 +430,6 @@ def main(gpu, args):
438430

439431
if args.gen_golden_vectors:
440432

441-
set_dataset_augmentation_enabled(dataset, False)
442-
set_dataset_augmentation_enabled(dataset_test, False)
443433
generate_golden_vector_dir(args.output_dir)
444434
output_int = get_output_int_flag(args)
445435
generate_golden_vectors(args.output_dir, dataset, output_int, args.generic_model, args.nn_for_feature_extraction)

tinyml-tinyverse/tinyml_tinyverse/references/image_classification/train.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@
117117

118118
dataset_loader_dict = {'GenericImageDataset':GenericImageDataset}
119119
dataset_load_state = {'dataset': None, 'dataset_test': None, 'train_sampler': None, 'test_sampler': None}
120+
_float_best_metric = None # best float accuracy; set on float run, read on QAT run
120121

121122

122123
def get_args_parser():
@@ -325,7 +326,7 @@ def main(gpu, args):
325326

326327
move_model_to_device(model, device, logger)
327328
criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
328-
# logger.info(f"args.transforms = {args.transforms}"
329+
329330
model, model_without_ddp, model_ema = setup_distributed_model(model, args, device)
330331
optimizer, lr_scheduler = setup_optimizer_and_scheduler(model, args)
331332
resume_from_checkpoint(model_without_ddp, optimizer, lr_scheduler, model_ema, args)
@@ -410,6 +411,10 @@ def main(gpu, args):
410411
checkpoint = save_checkpoint(model_without_ddp, optimizer, lr_scheduler, epoch, args, model_ema)
411412
utils.save_on_master(checkpoint, os.path.join(args.output_dir, 'checkpoint.pth'))
412413

414+
if not args.quantization and args.auto_quantization:
415+
_float_best_metric = best['accuracy'] / 100.0
416+
logger.info(f"Stored float best accuracy for binary search: {_float_best_metric:.4f}")
417+
413418
# Log best epoch results
414419
set_dataset_augmentation_enabled(dataset, False)
415420
set_dataset_augmentation_enabled(dataset_test, False)

0 commit comments

Comments
 (0)