Skip to content
This repository was archived by the owner on Feb 27, 2026. It is now read-only.

Commit 6485e5f

Browse files
update docs for .max() method
1 parent e3823f9 commit 6485e5f

1 file changed

Lines changed: 36 additions & 24 deletions

File tree

  • content/pytorch/concepts/tensor-operations/terms/max

content/pytorch/concepts/tensor-operations/terms/max/max.md

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,24 @@ CatalogContent:
1414
- 'paths/data-science'
1515
---
1616

17-
The **`.max()`** method in PyTorch returns the maximum value from a [tensor](https://www.codecademy.com/resources/docs/pytorch/tensors). It can find the maximum value across the entire tensor. This method is commonly used in data analysis, finding peak values, and various neural network operations.
17+
The **`.max()`** method in PyTorch returns the maximum value from a [tensor](https://www.codecademy.com/resources/docs/pytorch/tensors). It can find the maximum value across the entire tensor or along a specified dimension. This method is commonly used in data analysis, finding peak values, and various neural network operations.
1818

1919
## Syntax
2020

2121
```pseudo
22-
torch.max(input) → Tensor
22+
torch.max(input, dim=None, keepdim=False) → Tensor or (Tensor, LongTensor)
2323
```
2424

2525
**Parameters:**
2626

2727
- `input` (Tensor): The input tensor.
28+
- `dim` (int, optional): The dimension along which to find the maximum values. If not specified, returns the maximum value of the entire tensor.
29+
- `keepdim` (bool, optional): Whether the output tensor retains the reduced dimension. Defaults to `False`.
2830

2931
**Return value:**
3032

31-
Returns a tensor containing the maximum value from the `input`.
33+
- When `dim` is not specified: Returns a tensor containing the single maximum value from the entire tensor.
34+
- When `dim` is specified: Returns a named tuple `(values, indices)` where `values` contains the maximum values along the specified dimension, and `indices` contains the indices of those maximum values.
3235

3336
## Example
3437

@@ -37,47 +40,56 @@ The following example demonstrates how to use the `.max()` method to find the ma
3740
```py
3841
import torch
3942

40-
# Create a tensor with various values
41-
tensor = torch.tensor([1.5, -2.3, 0.0, 4.8, -1.2])
43+
# Create a 2D tensor
44+
tensor = torch.tensor([[1.5, -2.3, 0.0],
45+
[4.8, -1.2, 3.6]])
4246

43-
# Find the maximum value using the method form
44-
max_value = tensor.max()
47+
# Find the maximum value of the entire tensor
48+
max_value = torch.max(tensor)
4549

46-
# Alternative: use the functional form
47-
max_functional = torch.max(tensor)
50+
# Find maximum values along each column (dim=0)
51+
max_cols = torch.max(tensor, dim=0)
52+
53+
# Find maximum values along each row (dim=1)
54+
max_rows = torch.max(tensor, dim=1)
4855

4956
print("Original Tensor:")
5057
print(tensor)
5158

52-
print("\nMaximum Value (using .max()):")
59+
print("\nMaximum Value (entire tensor):")
5360
print(max_value)
5461

55-
print("\nMaximum Value (using torch.max()):")
56-
print(max_functional)
62+
print("\nMaximum Values (along columns, dim=0):")
63+
print("Values:", max_cols.values)
64+
print("Indices:", max_cols.indices)
5765

58-
print("\nMaximum as Python number (using .item()):")
59-
print(max_value.item())
66+
print("\nMaximum Values (along rows, dim=1):")
67+
print("Values:", max_rows.values)
68+
print("Indices:", max_rows.indices)
6069
```
6170

6271
This example results in the following output:
6372

6473
```shell
6574
Original Tensor:
66-
tensor([ 1.5000, -2.3000, 0.0000, 4.8000, -1.2000])
75+
tensor([[ 1.5000, -2.3000, 0.0000],
76+
[ 4.8000, -1.2000, 3.6000]])
6777

68-
Maximum Value (using .max()):
78+
Maximum Value (entire tensor):
6979
tensor(4.8000)
7080

71-
Maximum Value (using torch.max()):
72-
tensor(4.8000)
81+
Maximum Values (along columns, dim=0):
82+
Values: tensor([4.8000, -1.2000, 3.6000])
83+
Indices: tensor([1, 1, 1])
7384

74-
Maximum as Python number (using .item()):
75-
4.800000190734863
85+
Maximum Values (along rows, dim=1):
86+
Values: tensor([1.5000, 4.8000])
87+
Indices: tensor([0, 0])
7688
```
7789

7890
In this example:
7991

80-
- The tensor contains five values: `1.5`, `-2.3`, `0.0`, `4.8`, and `-1.2`
81-
- The `.max()` method identifies `4.8` as the maximum value in the tensor
82-
- Both `.max()` and `torch.max()` produce identical results: `tensor(4.8000)`
83-
- The `.item()` method converts the tensor result to a Python float: `4.800000190734863`
92+
- **Entire tensor**: The maximum value across all elements is `4.8000`.
93+
- **Along columns (`dim=0`)**: The maximum values in each column are `4.8000`, `-1.2000`, and `3.6000`, all found in row `1` (index `1`).
94+
- **Along rows (`dim=1`)**: The maximum values in each row are `1.5000` (at index `0`) and `4.8000` (at index `0`).
95+
- When `dim` is specified, the method returns both the maximum values and their indices as a named tuple.

0 commit comments

Comments
 (0)