@@ -28,7 +28,7 @@
 Data loading is often a critical bottleneck in deep learning pipelines. While
 GPUs can process batches extremely quickly, inefficient data loading can leave
 expensive hardware idle, waiting for the next batch of data. This tutorial
-covers best practises and some techniques for optimizing your data loading configuration to
+covers best practices and some techniques for optimizing your data loading configuration to
 maximize training throughput.
 
 We'll explore the key parameters of PyTorch's DataLoader and provide practical
@@ -44,7 +44,6 @@
 import torch.nn as nn
 from torch.utils.data import DataLoader, Dataset
 
-# Check if CUDA is available
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")
 
@@ -131,12 +130,10 @@ def __getitems__(self, indices):
 # many epochs — making persistent_workers' benefit visible across
 # epoch boundaries.
 
-# Dataset for progressive optimization benchmarks.
 benchmark_dataset = SyntheticDataset(size=512, feature_dim=224, transform_delay=0.005)
 
 
 class SmallTransformerModel(nn.Module):
-    """A model with conv + transformer layers for realistic GPU compute."""
 
     def __init__(self):
         super().__init__()
@@ -147,7 +144,6 @@ def __init__(self):
             nn.ReLU(),
             nn.AdaptiveAvgPool2d((7, 7)),
         )
-        # Transformer encoder on flattened spatial tokens
         encoder_layer = nn.TransformerEncoderLayer(
             d_model=64, nhead=4, dim_feedforward=128, batch_first=True
         )
@@ -286,7 +282,7 @@ def benchmark_batch_size(batch_size, num_batches=10):
     for i, (data, labels) in enumerate(loader):
         if i >= num_batches:
             break
-        data = data.to(device)
+        data = data.to(device, non_blocking=True)
         _ = data.sum()
     if torch.cuda.is_available():
         torch.cuda.synchronize()
@@ -603,7 +599,7 @@ def __next__(self):
 prev_time = batched_time
 
 ######################################################################
-# The ``in_order`` Parameter
+# ``in_order`` parameter
 # --------------------------
 #
 # By default (``in_order=True``), the DataLoader returns batches in
@@ -732,7 +728,7 @@ def __next__(self):
 #    import torch.multiprocessing as mp
 #    mp.set_sharing_strategy('file_system')
 #
-# **5. Clean up leaked shared memory:**
+# **4. Clean up leaked shared memory:**
 #
 # .. code-block:: bash
 #
@@ -759,52 +755,34 @@ def __next__(self):
 #
 # .. list-table::
 #    :header-rows: 1
-#    :widths: 50 20 15 15
+#    :widths: 55 20 20
 #
 #    * - Configuration
-#      - Time
 #      - vs Baseline
 #      - vs Previous
 #    * - Baseline (num_workers=0, no pinning)
-#      - ~32s
 #      - 1.00x
 #      - —
 #    * - \+ num_workers=4, prefetch_factor=2
-#      - ~12s
 #      - ~2.7x
 #      - ~2.7x
 #    * - \+ pin_memory=True
-#      - ~11.5s
 #      - ~2.8x
 #      - ~1.0x
 #    * - \+ persistent_workers=True
-#      - ~9s
 #      - ~3.7x
 #      - ~1.3x
 #    * - \+ DataPrefetcher (H2D overlap)
-#      - ~9s
 #      - ~3.6x
 #      - ~1.0x
 #    * - \+ __getitems__ (batched fetching)
-#      - ~3s
 #      - ~10x
 #      - ~2.9x
 #
 # .. note::
-#    These results are based on our benchmark dataset
+#    These results are based on our benchmark dataset.
 #    Actual speedups will vary depending on your specific
 #    workload, hardware, dataset size, and transform complexity.
-#
-# **Key takeaways:**
-#
-# - **Multiprocessing** (``num_workers > 0``) is often the biggest lever
-# - **pin_memory + non_blocking** enables faster CPU-to-GPU transfers
-# - **persistent_workers** eliminates epoch-boundary restart overhead
-# - **__getitems__** enables batched fetching at the dataset level — can provide
-#   the largest speedup when your dataset supports vectorized I/O or bulk queries
-# - **Prefetcing data** overlaps H2D transfer with compute (best when data
-#   loading is slow relative to GPU compute)
-# - Always benchmark your specific workload and hardware
 
 ######################################################################
 # Summary and Best Practices
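
Note on the table in the hunk above: the "vs Baseline" and "vs Previous" columns that remain can be cross-checked against the removed "Time" column. The following is a minimal sketch, assuming each ratio is an earlier time divided by the current time. The names `timings` and `speedups` are illustrative and not part of the tutorial, and the times are the approximate values from the removed lines.

```python
# Approximate wall-clock times (seconds) copied from the removed "Time" column.
timings = [
    ("Baseline (num_workers=0, no pinning)", 32.0),
    ("+ num_workers=4, prefetch_factor=2", 12.0),
    ("+ pin_memory=True", 11.5),
    ("+ persistent_workers=True", 9.0),
    ("+ DataPrefetcher (H2D overlap)", 9.0),
    ("+ __getitems__ (batched fetching)", 3.0),
]


def speedups(rows):
    """Return (name, vs_baseline, vs_previous) for each row.

    Assumption: speedup = earlier_time / current_time. The first row has
    no previous configuration, so its vs_previous is None.
    """
    baseline = rows[0][1]
    result = []
    previous = None
    for name, seconds in rows:
        vs_baseline = baseline / seconds
        vs_previous = None if previous is None else previous / seconds
        result.append((name, vs_baseline, vs_previous))
        previous = seconds
    return result


if __name__ == "__main__":
    for name, vs_base, vs_prev in speedups(timings):
        prev_text = "-" if vs_prev is None else f"{vs_prev:.2f}x"
        print(f"{name:<40} {vs_base:5.2f}x  {prev_text}")
```

Because the listed times are rounded, the recomputed ratios should be expected to match the table only approximately.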
@@ -830,6 +808,19 @@ def __next__(self):
 # 7. **Use ``file_system`` sharing strategy** when hitting file descriptor limits.
 #
 
+######################################################################
+# Conclusion
+# ----------
+#
+# In this tutorial, we learned how to progressively optimize a PyTorch
+# data loading pipeline — from a naive single-process baseline to a
+# fully optimized configuration using multiprocessing workers, pinned
+# memory, persistent workers, CUDA stream-based prefetching, and batched
+# dataset fetching with ``__getitems__``. Each optimization targets a
+# different bottleneck, and together they can yield an order-of-magnitude
+# improvement in throughput. These should be considered best practices
+# and performance is dependent on the specific workload.
+
 ######################################################################
 # Additional Resources
 # --------------------
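
The settings named in the added Conclusion can be gathered into one set of DataLoader keyword arguments. The following is a minimal sketch, assuming the values from the summary table (num_workers=4, prefetch_factor=2). `loader_kwargs` is a hypothetical helper, not something the tutorial defines. It omits `prefetch_factor` and `persistent_workers` when `num_workers` is 0, because DataLoader rejects those options when no worker processes are used.

```python
def loader_kwargs(num_workers=4, prefetch_factor=2, use_cuda=True):
    """Build keyword arguments for torch.utils.data.DataLoader.

    Hypothetical helper for illustration only; the defaults mirror the
    configuration rows in the summary table.
    """
    kwargs = {
        "num_workers": num_workers,
        # Pinned host memory only helps when batches are copied to a GPU.
        "pin_memory": use_cuda,
    }
    if num_workers > 0:
        # Both options are only valid when worker processes are used.
        kwargs["prefetch_factor"] = prefetch_factor
        kwargs["persistent_workers"] = True
    return kwargs


if __name__ == "__main__":
    print(loader_kwargs())
    print(loader_kwargs(num_workers=0, use_cuda=False))
```

The result would typically be unpacked into the constructor, as in `DataLoader(dataset, **loader_kwargs())`, with each batch then moved using `data.to(device, non_blocking=True)` as in the diff.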