This guide demonstrates how to use SageMaker's distributed data parallel library (smdistributed.dataparallel) to train a TensorFlow 2 model on the MNIST dataset.
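What Is the SageMaker Distributed Data Parallel Library?
Amazon SageMaker's distributed training libraries help you train deep learning models faster and more efficiently. The data parallel library, smdistributed.dataparallel, provides a distributed data parallel training framework for PyTorch and TensorFlow 2.x. It replicates the model on every GPU, gives each replica a different slice of every batch, and averages gradients across replicas with an optimized AllReduce operation.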
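To make the core idea concrete, here is a small pure-Python simulation of one synchronous data-parallel step. It does not use smdistributed.dataparallel and is not how the library is implemented; local_gradient, all_reduce_mean, and data_parallel_step are hypothetical names for illustration. Each simulated worker keeps a copy of the weight and computes a gradient on its own slice of the batch, and an averaging step stands in for the collective AllReduce, so every replica applies the same update.

```python
# Pure-Python simulation of synchronous data-parallel training (illustrative only;
# this is NOT the smdistributed.dataparallel API). Every worker holds a copy of the
# weight, computes a gradient on its own shard, and an "all-reduce" averages the
# gradients so all replicas apply the identical update.
from typing import List, Tuple

Sample = Tuple[float, float]  # (x, y) pair for a 1-D linear model y ~= w * x


def local_gradient(w: float, shard: List[Sample]) -> float:
    """Gradient of the mean of 0.5 * (w*x - y)^2 over one worker's shard."""
    return sum((w * x - y) * x for x, y in shard) / len(shard)


def all_reduce_mean(values: List[float]) -> float:
    """Stand-in for the collective AllReduce: average the values from all workers."""
    return sum(values) / len(values)


def data_parallel_step(weights: List[float], batch: List[Sample], lr: float) -> List[float]:
    """Split the batch across workers, compute local gradients, all-reduce, update."""
    world_size = len(weights)
    shards = [batch[rank::world_size] for rank in range(world_size)]
    grads = [local_gradient(w, shard) for w, shard in zip(weights, shards)]
    avg_grad = all_reduce_mean(grads)
    # Every replica applies the same averaged gradient, so the copies stay in sync.
    return [w - lr * avg_grad for w in weights]


if __name__ == "__main__":
    batch = [(float(x), 2.0 * x) for x in range(1, 9)]  # true weight is 2.0
    replicas = [0.0] * 4  # four simulated workers with identical initial weights
    for _ in range(50):
        replicas = data_parallel_step(replicas, batch, lr=0.01)
    print(replicas)  # four identical values, all close to 2.0
```

Because the shards here are equal in size, the averaged gradient equals the full-batch gradient, which is why data-parallel training with gradient averaging behaves like training on the combined batch.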
What this guide covers:
- Introduction: Briefly explains SageMaker's distributed data parallel library and its benefits.
- Dataset: Introduces the MNIST dataset, a popular benchmark for handwritten digit classification.
- SageMaker Role: Defines the IAM role required to create and run SageMaker training and hosting jobs.
- Model Training Script: Explains the Python script (train_tensorflow_smdataparallel_mnist.py) used to train the TensorFlow model with smdistributed.dataparallel.
- SageMaker TensorFlow Estimator: Demonstrates how to configure a SageMaker TensorFlow Estimator object, including:
  - Specifying the training script, instance type, instance count, and SageMaker session.
  - Setting the distribution strategy to use smdistributed.dataparallel.
- Training the Model: Trains the TensorFlow model using the SageMaker Estimator.
- Model Deployment: Deploys the trained model as an endpoint for real-time predictions.
- Inference: Shows how to use the deployed endpoint for making predictions on new data.
- Cleanup: Shows how to delete the endpoint when you no longer need it, so it stops incurring charges.
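The training script is an ordinary TensorFlow 2 script with a few distributed additions. Based on the library's TensorFlow 2 API (see the API specification under Further Resources), these typically are: importing smdistributed.dataparallel.tensorflow as sdp, calling sdp.init(), pinning each process to the GPU given by sdp.local_rank(), wrapping tf.GradientTape with sdp.DistributedGradientTape, broadcasting the initial variables from rank 0 with sdp.broadcast_variables, and saving checkpoints only on rank 0. Those calls only run inside a SageMaker training container, so the sketch below isolates just the rank-dependent bookkeeping in plain Python. WorkerPlan and plan_worker are hypothetical names, and the base learning rate and step budget are illustrative defaults, not values taken from the actual script.

```python
# Hypothetical helper: the rank-dependent bookkeeping a data-parallel training
# script performs. In the real script, rank, local_rank, and size would come from
# sdp.rank(), sdp.local_rank(), and sdp.size() after sdp.init(); here they are
# plain arguments so the logic runs anywhere.
from dataclasses import dataclass


@dataclass(frozen=True)
class WorkerPlan:
    gpu_index: int        # GPU this process pins itself to on its host
    learning_rate: float  # base learning rate scaled by the number of workers
    steps: int            # this worker's share of the total step budget
    is_chief: bool        # only rank 0 logs and saves checkpoints


def plan_worker(rank: int, local_rank: int, size: int,
                base_lr: float = 0.000125, total_steps: int = 10_000) -> WorkerPlan:
    """Derive per-process settings from the global rank, local rank, and world size."""
    if not 0 <= rank < size:
        raise ValueError(f"rank {rank} out of range for size {size}")
    return WorkerPlan(
        gpu_index=local_rank,            # one process per GPU on each host
        learning_rate=base_lr * size,    # linear scaling with the global batch size
        steps=total_steps // size,       # keeps total samples processed roughly constant
        is_chief=(rank == 0),
    )


if __name__ == "__main__":
    # Two hosts x 8 GPUs = 16 processes; show the plan for a few ranks.
    for rank in (0, 1, 8):
        print(rank, plan_worker(rank=rank, local_rank=rank % 8, size=16))
```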
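The Estimator step comes down to passing the training script, instance settings, session, and a distribution argument that enables the library. The sketch below assumes the SageMaker Python SDK v2 (sagemaker.tensorflow.TensorFlow, sagemaker.Session, sagemaker.get_execution_role). The framework_version, py_version, instance_type, instance_count, and endpoint instance type are assumed values, and estimator_kwargs and train_and_deploy are hypothetical helper names. Keeping the arguments in a plain dictionary lets you inspect the configuration without AWS access; the SDK is imported only when train_and_deploy runs.

```python
# Sketch: build the SageMaker TensorFlow Estimator arguments separately from
# launching the job. Version strings and instance settings are assumptions;
# use values that your account and the library version support.
from typing import Any, Dict

# Distribution setting that turns on smdistributed.dataparallel.
SMDDP_DISTRIBUTION: Dict[str, Any] = {
    "smdistributed": {"dataparallel": {"enabled": True}}
}


def estimator_kwargs(role: str,
                     instance_type: str = "ml.p3dn.24xlarge",
                     instance_count: int = 2) -> Dict[str, Any]:
    """Keyword arguments for sagemaker.tensorflow.TensorFlow (assumed values)."""
    return {
        "entry_point": "train_tensorflow_smdataparallel_mnist.py",
        "role": role,
        "instance_type": instance_type,
        "instance_count": instance_count,
        "framework_version": "2.3.1",  # assumption: a TF version the library supports
        "py_version": "py37",          # assumption: must match framework_version
        "distribution": SMDDP_DISTRIBUTION,
    }


def train_and_deploy(endpoint_instance_type: str = "ml.m5.xlarge"):
    """Run training, then deploy an endpoint. Needs the SageMaker SDK and AWS credentials."""
    import sagemaker  # imported lazily so this module loads without the SDK
    from sagemaker.tensorflow import TensorFlow

    session = sagemaker.Session()
    role = sagemaker.get_execution_role()  # resolves inside SageMaker notebooks/Studio
    estimator = TensorFlow(sagemaker_session=session, **estimator_kwargs(role))
    estimator.fit()  # assumption: the training script downloads MNIST itself
    return estimator.deploy(initial_instance_count=1,
                            instance_type=endpoint_instance_type)
```

Calling train_and_deploy() starts billable training instances and a hosted endpoint.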
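For inference and cleanup, the predictor returned by deploy sends JSON to a TensorFlow Serving endpoint, which responds with a "predictions" list. The sketch below assumes the model takes 28x28x1 inputs scaled to [0, 1] and outputs 10 class scores per image; predict_digits and predict_then_clean_up are hypothetical helpers, and any object with predict and delete_endpoint methods can be passed in. Deleting the endpoint in a finally block removes it even if the prediction request fails.

```python
# Hypothetical helpers: send 28x28 images to a deployed TensorFlow Serving
# endpoint, convert the returned scores to digit labels, then delete the endpoint.
from typing import Any, List, Sequence

Image = List[List[float]]  # 28 rows x 28 pixel values, scaled to [0, 1]


def argmax(scores: Sequence[float]) -> int:
    """Index of the highest score (the predicted digit)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def predict_digits(predictor: Any, images: Sequence[Image]) -> List[int]:
    """Assumes (28, 28, 1) model inputs and 10 output scores per image."""
    # Add the trailing channel dimension the model is assumed to expect.
    instances = [[[[px] for px in row] for row in img] for img in images]
    response = predictor.predict({"instances": instances})
    # TensorFlow Serving's REST predict API returns {"predictions": [...]}.
    return [argmax(scores) for scores in response["predictions"]]


def predict_then_clean_up(predictor: Any, images: Sequence[Image]) -> List[int]:
    """Predict once, then tear down the endpoint whether or not the call succeeded."""
    try:
        return predict_digits(predictor, images)
    finally:
        predictor.delete_endpoint()
```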
Additional Notes
- Refer to mnist_final_update.ipynb for a detailed code example of the distribution strategy.
- For best performance, use instance types that support Amazon Elastic Fabric Adapter (EFA), such as ml.p3dn.24xlarge or ml.p4d.24xlarge, when training with smdistributed.dataparallel.
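A lightweight way to act on the instance-type recommendation is a pre-flight check before launching a job. In the sketch below, check_instance_type and EFA_EXAMPLES_FROM_GUIDE are hypothetical names, and the set contains only the two examples this guide gives; it is not an exhaustive list of EFA-capable instance types.

```python
# Hypothetical pre-flight check: warn when the training instance type is not one
# of the EFA-capable examples named in this guide. The set is NOT exhaustive.
import warnings

EFA_EXAMPLES_FROM_GUIDE = frozenset({"ml.p3dn.24xlarge", "ml.p4d.24xlarge"})


def check_instance_type(instance_type: str) -> bool:
    """Return True for the guide's EFA examples; otherwise emit a warning and return False."""
    if instance_type in EFA_EXAMPLES_FROM_GUIDE:
        return True
    warnings.warn(
        f"{instance_type} is not one of the EFA-capable examples in this guide; "
        "inter-node communication may be slower.",
        stacklevel=2,
    )
    return False
```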
Further Resources
- TensorFlow in SageMaker: https://docs.aws.amazon.com/sagemaker/latest/dg/tf.html
- SageMaker distributed data parallel API Specification: https://docs.aws.amazon.com/sagemaker/latest/dg/data-parallel-use-api.html
- SageMaker's Distributed Data Parallel Library: https://docs.aws.amazon.com/sagemaker/latest/dg/data-parallel.html
- Modify a TensorFlow 2.x Training Script Using SMD Data Parallel: https://aws.amazon.com/tensorflow/