File: kd_squeezenet_template.py
An important aspect of machine learning is hyperparameter tuning. AML's HyperDrive is a hyperparameter tuning service, offering:
- Random, Grid and Bayesian parameter sampling
- Management of job creation and monitoring on the user's behalf
- Early termination of poorly performing runs
Copy the code snippets below into the template file, then rename the file by removing `_template` from its filename (kd_squeezenet_template.py becomes kd_squeezenet.py).
# Command-line hyperparameters. Every argument is optional and falls back
# to the default given here.

# Optimiser settings.
parser.add_argument(
    '--learning_rate',
    type=float,
    default=0.01,
    required=False,
    help='learning rate',
)
parser.add_argument(
    '--weight_decay',
    type=float,
    default=0.01,
    required=False,
    help='weight_decay',
)

# Distillation settings.
parser.add_argument(
    '--temperature',
    type=float,
    default=5.0,
    required=False,
    help='temperature',
)
parser.add_argument(
    '--lambda_const',
    type=float,
    default=0.2,
    required=False,
    help='lambda_const',
)

parser.add_argument(
    '--momentum',
    type=float,
    default=0.9,
    required=False,
    help='momentum',
)
parser.add_argument(
    '--batch_size',
    dest="batch_size",
    type=int,
    default=64,
    required=False,
    help='Batch size',
)

# Parsed as a string (default is the string "False"), not as a bool.
parser.add_argument(
    '--transfer_learning',
    dest="transfer_learning",
    type=str,
    default="False",
    required=False,
    help='use the benchmark model and perform transfer learning',
)
# Copy the parsed command-line values into plain variables.

# The transfer-learning flag is converted to a bool with str2bool.
transfer_learning = str2bool(args.transfer_learning)

# Optimiser hyperparameters.
learning_rate, weight_decay, momentum = (
    args.learning_rate,
    args.weight_decay,
    args.momentum,
)

# Distillation hyperparameters.
temperature, lambda_const = args.temperature, args.lambda_const

batch_size = args.batch_size
# NOTE(review): indentation was reconstructed; every statement below is
# assumed to belong to the `if remote_execution:` body -- confirm.
if remote_execution:
    print("Running on remote compute target:", remote_execution)

    # The azureml imports live inside this branch, so they are executed
    # only when remote_execution is truthy.
    from azureml.core import VERSION
    print("azureml.core.VERSION", VERSION)
    from azureml.core import Run

    # Get the Azure ML run context for this script.
    run = Run.get_context()

    # Record the hyperparameter values used by this run.
    logged_hyperparameters = (
        ('learning_rate', learning_rate),
        ('weight_decay', weight_decay),
        ('temperature', temperature),
        ('lambda_const', lambda_const),
        ('momentum', momentum),
        ('batch_size', batch_size),
        ('transfer_learning', transfer_learning),
    )
    for hp_name, hp_value in logged_hyperparameters:
        run.log(hp_name, hp_value)
# Local runs write the current figure to disk; remote runs attach it to
# the Azure ML run instead.
if not remote_execution:
    plt.savefig('soft_target_dist.png')
else:
    run.log_image('soft target dist', plot=plt)
# NOTE(review): assumes the figure is closed on both branches -- confirm.
plt.close()
# Report training progress to the AML workspace (remote runs only).
if remote_execution:
    class LogRunMetrics(Callback):
        """Callback that forwards per-epoch losses to the Azure ML run."""

        def on_epoch_end(self, epoch, log):
            """Log 'val_loss' and 'loss' from ``log`` at the end of an epoch."""
            # Logging the same metric name repeatedly builds up a list of
            # values on the run, one entry per epoch.
            for metric_key in ('val_loss', 'loss'):
                run.log(metric_key, log[metric_key])

    # Register the callback so it is invoked during training.
    callbacks.append(LogRunMetrics())
# For remote runs, record the last-epoch validation metrics on the run.
if remote_execution:
    training_history = model.history.history
    for final_metric in ('val_loss', 'val_accuracy'):
        # [-1] selects the value from the final recorded epoch.
        run.log('final_' + final_metric, training_history[final_metric][-1])
# Local runs write the current figure to disk; remote runs attach it to
# the Azure ML run instead.
if not remote_execution:
    plt.savefig('crossentropy.png')
else:
    run.log_image('crossentropy', plot=plt)
# NOTE(review): assumes the figure is closed on both branches -- confirm.
plt.close()
Back to main tutorial page