|
| 1 | +--- |
| 2 | +Title: '.argmin()' |
| 3 | +Description: 'Returns the index of the minimum value in a PyTorch tensor, or along a specified dimension.' |
| 4 | +Subjects: |
| 5 | + - 'Computer Science' |
| 6 | + - 'Data Science' |
| 7 | +Tags: |
| 8 | + - 'Deep Learning' |
| 9 | + - 'Methods' |
| 10 | + - 'PyTorch' |
| 11 | + - 'Tensor' |
| 12 | +CatalogContent: |
| 13 | + - 'intro-to-py-torch-and-neural-networks' |
| 14 | + - 'paths/data-science' |
| 15 | +--- |
| 16 | + |
| 17 | +The **`.argmin()`** method in PyTorch returns the index of the minimum value in a flattened [tensor](https://www.codecademy.com/resources/docs/pytorch/tensors) tensor by default, or along a specified dimension. This method is commonly used in tasks such as finding the closest data point, selecting the best prediction, or identifying the least likely class in machine learning workflows. |
| 18 | + |
| 19 | +## Syntax |
| 20 | + |
| 21 | +```pseudo |
| 22 | +torch.argmin(input, dim=None, keepdim=False) |
| 23 | +``` |
| 24 | + |
| 25 | +**Parameters:** |
| 26 | + |
| 27 | +- `input` (Tensor): The input tensor to search for the minimum value. |
| 28 | +- `dim` (int, optional): The dimension to reduce. If not specified, the index of the minimum value in the flattened tensor is returned. |
| 29 | +- `keepdim` (bool, optional): Whether the output tensor retains the reduced dimension. Defaults to `False`. |
| 30 | + |
| 31 | +**Return value:** |
| 32 | + |
| 33 | +The `.argmin()` method returns a `LongTensor` containing the index or indices of the minimum value(s). If `dim` is not specified, a scalar tensor is returned. |
| 34 | + |
| 35 | +## Example |
| 36 | + |
| 37 | +This example shows how to use the `.argmin()` method to find the index of the minimum value in a 2D tensor: |
| 38 | + |
| 39 | +```py |
| 40 | +import torch |
| 41 | + |
| 42 | +# Define a 2D tensor |
| 43 | +tensor = torch.tensor([[8, 3, 5], |
| 44 | + [2, 7, 4]]) |
| 45 | + |
| 46 | +# Index of minimum in flattened tensor |
| 47 | +print(torch.argmin(tensor)) |
| 48 | + |
| 49 | +# Index of minimum along each column (dim=0) |
| 50 | +print(torch.argmin(tensor, dim=0)) |
| 51 | + |
| 52 | +# Index of minimum along each row (dim=1) |
| 53 | +print(torch.argmin(tensor, dim=1)) |
| 54 | +``` |
| 55 | + |
| 56 | +This example results in the following output: |
| 57 | + |
| 58 | +```shell |
| 59 | +tensor(3) |
| 60 | +tensor([1, 0, 1]) |
| 61 | +tensor([1, 0]) |
| 62 | +``` |
| 63 | + |
| 64 | +In this example: |
| 65 | + |
| 66 | +- **Flattened tensor**: The tensor is treated as `[8, 3, 5, 2, 7, 4]`, and the minimum value `2` is at index `3`. |
| 67 | +- **Along columns (`dim=0`)**: The minimum values in each column are `2`, `3`, and `4`, found in rows `1`, `0`, and `1`. |
| 68 | +- **Along rows (`dim=1`)**: The minimum values in each row are `3` (at index `1`) and `2` (at index `0`). |
0 commit comments