Skip to content

Commit 46e4ab7

Browse files
authored
Merge branch 'main' into data-loading-guide
2 parents d74e331 + 51bdefb commit 46e4ab7

3 files changed

Lines changed: 28 additions & 22 deletions

File tree

beginner_source/basics/data_tutorial.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,22 +48,22 @@
4848
import torch
4949
from torch.utils.data import Dataset
5050
from torchvision import datasets
51-
from torchvision.transforms import ToTensor
51+
from torchvision.transforms import v2
5252
import matplotlib.pyplot as plt
5353

5454

5555
training_data = datasets.FashionMNIST(
5656
root="data",
5757
train=True,
5858
download=True,
59-
transform=ToTensor()
59+
transform=v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
6060
)
6161

6262
test_data = datasets.FashionMNIST(
6363
root="data",
6464
train=False,
6565
download=True,
66-
transform=ToTensor()
66+
transform=v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
6767
)
6868

6969

beginner_source/basics/optimization_tutorial.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,20 @@
2828
from torch import nn
2929
from torch.utils.data import DataLoader
3030
from torchvision import datasets
31-
from torchvision.transforms import ToTensor
31+
from torchvision.transforms import v2
3232

3333
training_data = datasets.FashionMNIST(
3434
root="data",
3535
train=True,
3636
download=True,
37-
transform=ToTensor()
37+
transform=v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
3838
)
3939

4040
test_data = datasets.FashionMNIST(
4141
root="data",
4242
train=False,
4343
download=True,
44-
transform=ToTensor()
44+
transform=v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
4545
)
4646

4747
train_dataloader = DataLoader(training_data, batch_size=64)

beginner_source/basics/transforms_tutorial.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,42 +23,47 @@
2323
2424
The FashionMNIST features are in PIL Image format, and the labels are integers.
2525
For training, we need the features as normalized tensors, and the labels as one-hot encoded tensors.
26-
To make these transformations, we use ``ToTensor`` and ``Lambda``.
26+
To make these transformations, we use the ``torchvision.transforms.v2`` API along with ``torch.nn.functional.one_hot``.
2727
"""
2828

2929
import torch
30+
import torch.nn.functional as F
3031
from torchvision import datasets
31-
from torchvision.transforms import ToTensor, Lambda
32+
from torchvision.transforms import v2
3233

3334
ds = datasets.FashionMNIST(
3435
root="data",
3536
train=True,
3637
download=True,
37-
transform=ToTensor(),
38-
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
38+
transform=v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]),
39+
target_transform=v2.Lambda(
40+
lambda y: F.one_hot(torch.tensor(y), num_classes=10).float()
41+
),
3942
)
4043

4144
#################################################
42-
# ToTensor()
45+
# ToImage() and ToDtype()
4346
# -------------------------------
4447
#
45-
# `ToTensor <https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.ToTensor>`_
46-
# converts a PIL image or NumPy ``ndarray`` into a ``FloatTensor``. and scales
47-
# the image's pixel intensity values in the range [0., 1.]
48+
# The ``torchvision.transforms.v2`` API replaces the legacy ``ToTensor`` transform with a two-step pipeline.
49+
# `v2.ToImage <https://pytorch.org/vision/stable/generated/torchvision.transforms.v2.ToImage.html>`_
50+
# converts a PIL image or NumPy ``ndarray`` into a ``torchvision.tv_tensors.Image`` tensor, and
51+
# `v2.ToDtype <https://pytorch.org/vision/stable/generated/torchvision.transforms.v2.ToDtype.html>`_
52+
# with ``scale=True`` casts it to ``float32`` and scales the pixel intensity values to the range [0., 1.].
4853
#
4954

5055
##############################################
5156
# Lambda Transforms
5257
# -------------------------------
5358
#
54-
# Lambda transforms apply any user-defined lambda function. Here, we define a function
55-
# to turn the integer into a one-hot encoded tensor.
56-
# It first creates a zero tensor of size 10 (the number of labels in our dataset) and calls
57-
# `scatter_ <https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html>`_ which assigns a
58-
# ``value=1`` on the index as given by the label ``y``.
59+
# Lambda transforms apply any user-defined lambda function. Here, we use
60+
# `torch.nn.functional.one_hot <https://pytorch.org/docs/stable/generated/torch.nn.functional.one_hot.html>`_
61+
# to turn the integer label into a one-hot encoded tensor of size 10 (the number of labels in our dataset),
62+
# then cast it to ``float`` to match the expected dtype.
5963

60-
target_transform = Lambda(lambda y: torch.zeros(
61-
10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
64+
target_transform = v2.Lambda(
65+
lambda y: F.one_hot(torch.tensor(y), num_classes=10).float()
66+
)
6267

6368
######################################################################
6469
# --------------
@@ -67,4 +72,5 @@
6772
#################################################################
6873
# Further Reading
6974
# ~~~~~~~~~~~~~~~~~
70-
# - `torchvision.transforms API <https://pytorch.org/vision/stable/transforms.html>`_
75+
# - `Getting started with transforms v2 <https://pytorch.org/vision/stable/auto_examples/transforms/plot_transforms_getting_started.html>`_
76+
# - `torchvision.transforms.v2 API <https://pytorch.org/vision/stable/transforms.html#v2-api-reference-recommended>`_

0 commit comments

Comments
 (0)