-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathgrad_probe_patterns.py
More file actions
116 lines (93 loc) · 3.53 KB
/
grad_probe_patterns.py
File metadata and controls
116 lines (93 loc) · 3.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from __future__ import annotations

import contextlib

import torch

from torch_cudagraph_debug.tensor_debug import CudaGraphTensorProbe, TensorRecord
class DebugMLP(torch.nn.Module):
    """Tiny bias-free MLP (4 -> 3 -> 2) whose hidden activation is probed.

    The output of ``fc1`` is passed through two caller-supplied probes before
    the ReLU: the first observes the forward value, the second is attached so
    it observes the gradient flowing back into that activation. ``forward``
    returns the sum of all ``fc2`` outputs (a 0-dim tensor), so the result can
    be fed straight to ``backward()``.
    """

    def __init__(
        self,
        *,
        activation_value_probe: CudaGraphTensorProbe,
        activation_grad_probe: CudaGraphTensorProbe,
    ) -> None:
        """Build the two linear layers and keep references to both probes.

        Args:
            activation_value_probe: called on fc1's output; its return value is
                what continues through the network.
            activation_grad_probe: its ``attach_grad`` is applied to the hidden
                activation; the returned tensor continues through the network.
        """
        super().__init__()
        # fc1 is created before fc2 so weight initialisation draws from the
        # RNG in a fixed order (main() seeds it before building the model).
        self.fc1 = torch.nn.Linear(4, 3, bias=False)
        self.fc2 = torch.nn.Linear(3, 2, bias=False)
        self.activation_value_probe = activation_value_probe
        self.activation_grad_probe = activation_grad_probe

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run fc1 -> probes -> ReLU -> fc2 and return the summed output."""
        # Pattern 1: probe the forward value of fc1's output.
        observed = self.activation_value_probe(self.fc1(x))
        # Pattern 2: probe the gradient that flows back into that activation.
        observed = self.activation_grad_probe.attach_grad(observed)
        return self.fc2(torch.relu(observed)).sum()
def _open_probe(cleanup: contextlib.ExitStack, name: str) -> CudaGraphTensorProbe:
    """Create a probe with a single ``TensorRecord`` and schedule its ``close()``.

    ``close()`` is registered on *cleanup* immediately after construction, so
    the probe is released when *cleanup* unwinds, even if a later step (including
    constructing another probe) raises.
    """
    probe = CudaGraphTensorProbe(name, [TensorRecord()])
    cleanup.callback(probe.close)
    return probe


def main() -> None:
    """Capture one forward/backward step in a CUDA graph and check the probes.

    Demonstrates four probing patterns on ``DebugMLP``:

    1. the forward value of the hidden activation,
    2. the gradient flowing back into the hidden activation,
    3. the gradient autograd produces for ``fc1.weight`` (via ``attach_grad``),
    4. the final ``fc1.weight.grad`` buffer, read after ``backward()``.

    The step is first run eagerly three times on a side stream and the probes
    are checked to be empty. The step is then captured once and replayed
    twice. Afterwards every probe must hold at least one record, and the last
    record of each is printed.

    All probes are closed on exit, including when a check fails or an
    exception is raised part-way through.

    Raises:
        RuntimeError: if CUDA is unavailable, if a probe records during the
            eager warmup, if ``fc1.weight.grad`` is missing after the captured
            backward, or if a probe has no records after replay.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("This example requires CUDA.")
    torch.manual_seed(1234)
    device = torch.device("cuda")
    static_x = torch.randn(2, 4, device=device)

    with contextlib.ExitStack() as cleanup:
        activation_value_probe = _open_probe(cleanup, "activation.value")
        activation_grad_probe = _open_probe(cleanup, "activation.grad")
        weight_grad_probe = _open_probe(cleanup, "fc1.weight.grad.hook")
        final_weight_grad_probe = _open_probe(cleanup, "fc1.weight.grad.final")
        probes = [
            activation_value_probe,
            activation_grad_probe,
            weight_grad_probe,
            final_weight_grad_probe,
        ]

        model = DebugMLP(
            activation_value_probe=activation_value_probe,
            activation_grad_probe=activation_grad_probe,
        ).to(device)
        # Pattern 3: probe a parameter gradient when autograd produces it.
        weight_grad_probe.attach_grad(model.fc1.weight)

        # Warm up eagerly on a side stream before capture (the PyTorch
        # CUDA-graph recipe). These eager runs must not produce any records.
        capture_stream = torch.cuda.Stream()
        capture_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(capture_stream):
            for _ in range(3):
                model.zero_grad(set_to_none=True)
                loss = model(static_x)
                loss.backward()
                del loss
        torch.cuda.current_stream().wait_stream(capture_stream)
        torch.cuda.synchronize()
        # Explicit raises (not assert) so these checks survive `python -O`.
        for probe in probes:
            if probe.records():
                raise RuntimeError(
                    f"{probe.name} recorded snapshots during eager warmup"
                )

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.stream(capture_stream):
            # With .grad set to None, the captured backward creates a fresh
            # grad tensor instead of accumulating into a warmup buffer.
            model.zero_grad(set_to_none=True)
            with torch.cuda.graph(graph):
                loss = model(static_x)
                loss.backward()
                # Pattern 4: probe the final parameter .grad buffer after backward.
                weight_grad = model.fc1.weight.grad
                if weight_grad is None:
                    raise RuntimeError("fc1.weight.grad should exist after backward")
                final_weight_grad_probe(weight_grad)
        torch.cuda.current_stream().wait_stream(capture_stream)

        for _ in range(2):
            graph.replay()
        torch.cuda.synchronize()

        for probe in probes:
            records = probe.records()
            if not records:
                raise RuntimeError(
                    f"{probe.name} did not record any graph replay snapshots"
                )
            last = records[-1]
            print(f"{last.probe_name}: replay={last.replay_index} shape={last.shape}")
    # Leaving the ExitStack closes every probe (in reverse creation order).
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()