forked from modular/modular
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgraph.py
More file actions
67 lines (55 loc) · 2.53 KB
/
graph.py
File metadata and controls
67 lines (55 loc) · 2.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
import torch
from max.graph import TensorValue
from max.torch import graph_op
@graph_op
def max_matmul(a: TensorValue, b: TensorValue):  # noqa: ANN201
    """Matrix-multiply two MAX graph tensor values.

    ``graph_op`` turns this MAX graph function into a custom PyTorch
    operation. On the PyTorch side it is invoked destination-passing style,
    with a preallocated output tensor as the first argument (see
    ``matmul_max``).

    Args:
        a: Left-hand matrix operand.
        b: Right-hand matrix operand.

    Returns:
        The graph value for ``a @ b``.
    """
    # The ``@`` operator on graph values is equivalent to ``ops.matmul(a, b)``.
    product = a @ b
    return product
@torch.compile
def matmul_max(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Multiply two 2-D tensors with the MAX-graph-backed ``max_matmul`` op.

    Args:
        a: Left operand of shape ``(M, K)``.
        b: Right operand of shape ``(K, N)``.

    Returns:
        A new ``(M, N)`` tensor holding ``a @ b``. It is allocated with
        ``a.new_empty``, so it has ``a``'s dtype and device.

    Raises:
        ValueError: If either input is not 2-D, or if the inner dimensions of
            ``a`` and ``b`` differ.
    """
    # The output buffer below is sized as (a.shape[0], b.shape[1]), which is
    # only meaningful for 2-D operands with matching inner dimensions. Fail
    # loudly instead of handing a wrongly shaped buffer to the MAX op.
    if a.dim() != 2 or b.dim() != 2:
        raise ValueError(
            f"matmul_max expects 2-D tensors, got {a.dim()}-D and {b.dim()}-D"
        )
    if a.shape[1] != b.shape[0]:
        raise ValueError(
            "matmul_max inner dimensions must match, got "
            f"{tuple(a.shape)} @ {tuple(b.shape)}"
        )
    # ``max_matmul`` writes into a caller-provided tensor (destination-passing
    # style), so allocate the (M, N) result first and pass it as the first
    # argument.
    output = a.new_empty(a.shape[0], b.shape[1])
    max_matmul(output, a, b)
    return output
def main() -> int:
    """Compare ``matmul_max`` against ``torch.matmul`` on every available device.

    Runs on CPU, and also on CUDA when ``torch.cuda.is_available()`` is true.
    For each device it prints either a success summary or statistics about
    the element-wise error.

    Returns:
        0 if every device's results match within tolerance, 1 otherwise.
    """
    # Test on both CPU and GPU if available
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")

    # Problem size is the same for every device, so define it once.
    m, k, n = 128, 256, 512
    failures = 0

    for device in devices:
        print(f"\n{'=' * 50}")
        print(f"Testing on device: {device}")
        print("=" * 50)

        # Create random input tensors
        a = torch.randn(m, k, device=device, dtype=torch.float32)
        b = torch.randn(k, n, device=device, dtype=torch.float32)

        result_max = matmul_max(a, b)
        result_torch = torch.matmul(a, b)

        # NOTE(review): 1e-1 tolerances are loose for float32; presumably
        # chosen to absorb differences between the MAX and PyTorch kernels.
        # Confirm before tightening.
        if torch.allclose(result_max, result_torch, rtol=1e-1, atol=1e-1):
            print("✓ MAX matmul matches PyTorch matmul!")
            print(f" Input shapes: A={a.shape}, B={b.shape}")
            print(f" Output shape: {result_max.shape}")
        else:
            failures += 1
            # Compute the element-wise error once; all statistics derive from it.
            diff = torch.abs(result_max - result_torch)
            print("✗ Results do not match!")
            print(f" Max absolute difference: {diff.max().item():.2e}")
            print(f" Mean absolute difference: {diff.mean().item():.2e}")
            print(f" Std of differences: {diff.std().item():.2e}")

    # A non-zero status lets scripts/CI detect a mismatch instead of it only
    # being printed.
    return 1 if failures else 0


if __name__ == "__main__":
    raise SystemExit(main())