Skip to content

Commit bcb7e29

Browse files
jvz37jalilAlvasekyondaMeta
authored
Modernize polynomial_custom_function: torch.accelerator and setup_context (#3885)
Fixes #3880 ## Description Modernizes the custom autograd function tutorial by: - Replacing hardcoded `torch.device("cpu")` with `torch.accelerator` for accelerator-agnostic device selection - Splitting legacy `forward(ctx, input)` into `forward(input)` + `setup_context(ctx, inputs, output)` per PyTorch 2.0+ recommended pattern Tested locally: 2000 iterations, loss converged, output identical to original. ## Checklist - [x] The issue that is being fixed is referred in the description (see above "Fixes #ISSUE_NUMBER") - [x] Only one issue is addressed in this pull request - [] Labels from the issue that this PR is fixing are added to this pull request - [x] No unnecessary issues are included into this pull request. --------- Co-authored-by: jalilAlva <jalil.alva@welocalize.com> Co-authored-by: sekyondaMeta <127536312+sekyondaMeta@users.noreply.github.com>
1 parent 43859b1 commit bcb7e29

1 file changed

Lines changed: 17 additions & 9 deletions

File tree

beginner_source/examples_autograd/polynomial_custom_function.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,24 @@ class LegendrePolynomial3(torch.autograd.Function):
2929
"""
3030

3131
@staticmethod
32-
def forward(ctx, input):
32+
def forward(input):
3333
"""
3434
In the forward pass we receive a Tensor containing the input and return
35-
a Tensor containing the output. ctx is a context object that can be used
36-
to stash information for backward computation. You can cache tensors for
37-
use in the backward pass using the ``ctx.save_for_backward`` method. Other
38-
objects can be stored directly as attributes on the ctx object, such as
39-
``ctx.my_object = my_object``. Check out `Extending torch.autograd <https://docs.pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd>`_
35+
a Tensor containing the output. Check out `Extending torch.autograd <https://docs.pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd>`_
4036
for further details.
4137
"""
42-
ctx.save_for_backward(input)
4338
return 0.5 * (5 * input ** 3 - 3 * input)
4439

40+
@staticmethod
41+
def setup_context(ctx, inputs, output):
42+
"""
43+
Store input for use in the backward pass using ``ctx.save_for_backward``.
44+
Other objects can be stored directly as attributes on the ctx object,
45+
such as ``ctx.my_object = my_object``.
46+
"""
47+
input, = inputs
48+
ctx.save_for_backward(input)
49+
4550
@staticmethod
4651
def backward(ctx, grad_output):
4752
"""
@@ -54,8 +59,11 @@ def backward(ctx, grad_output):
5459

5560

5661
dtype = torch.float
57-
device = torch.device("cpu")
58-
# device = torch.device("cuda:0") # Uncomment this to run on GPU
62+
device = (
63+
torch.accelerator.current_accelerator().type
64+
if torch.accelerator.is_available()
65+
else "cpu"
66+
)
5967

6068
# Create Tensors to hold input and outputs.
6169
# By default, requires_grad=False, which indicates that we do not need to

0 commit comments

Comments
 (0)