forked from modular/modular
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtorch-grayscale.py
More file actions
45 lines (37 loc) · 1.6 KB
/
torch-grayscale.py
File metadata and controls
45 lines (37 loc) · 1.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#!/usr/bin/env python3
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
# DOC: /max/api/python/torch.py
import max.torch
import numpy as np
import torch
from max.dtype import DType
from max.graph import ops
@max.torch.graph_op
def max_grayscale(pic: max.graph.TensorValue):  # noqa: ANN201
    """Reduce an image with a trailing RGB channel axis to a single luma channel.

    This op is invoked from PyTorch in destination-passing style (see
    ``grayscale``), so the caller supplies the output tensor.

    Args:
        pic: Image tensor whose last axis holds the colour channels. Assumes
            exactly 3 channels ordered (R, G, B): the weight vector below only
            broadcasts against a size-3 last axis.

    Returns:
        A tensor shaped like ``pic`` without its last axis, in ``pic.dtype``.
    """
    # ITU-R BT.709 luma weights. They sum to 1.0, so a white pixel stays white.
    # (0.21/0.71/0.07 would sum to 0.99 and map 255 to 252.) Using float32
    # keeps the constant in the same dtype as the scaled input, instead of
    # numpy's default float64.
    weights = np.array([0.2126, 0.7152, 0.0722], dtype=np.float32)
    scaled = pic.cast(DType.float32) * weights
    # MAX reductions keep the reduced axis (as size 1); it is squeezed below.
    luma = ops.sum(scaled, axis=-1)
    if pic.dtype.is_integral():
        # Round to nearest before narrowing to an integer dtype. A float->int
        # cast typically truncates, which would bias every pixel downward.
        luma = ops.round(luma)
    return ops.squeeze(luma.cast(pic.dtype), axis=-1)
@torch.compile
def grayscale(pic: torch.Tensor):  # noqa: ANN201
    """Convert an image tensor to grayscale by calling the MAX ``max_grayscale`` op.

    Args:
        pic: Image tensor whose last axis is the colour-channel axis.

    Returns:
        A tensor shaped like ``pic`` minus its last axis. It has the same
        dtype and device as ``pic``, because it is allocated with
        ``pic.new_empty``.
    """
    # Drop the trailing channel axis to get the per-pixel output shape.
    gray_shape = pic.shape[:-1]
    # ``new_empty`` inherits dtype and device from ``pic``. Its contents are
    # uninitialized until the MAX op below writes into it.
    gray = pic.new_empty(gray_shape)
    # Destination-passing style: the first argument is the output buffer.
    max_grayscale(gray, pic)
    return gray
# Pick the device before allocating any tensors: use CUDA when PyTorch can
# see a GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Random 64x64 image with 3 colour channels, converted to uint8.
img = (torch.rand(64, 64, 3, device=device) * 255).to(torch.uint8)
result = grayscale(img)

# Report shapes: the output should have lost the trailing channel axis.
print(f"Input shape: {img.shape}")
print(f"Output shape: {result.shape}")
print("Grayscale conversion completed successfully!")