|
87 | 87 | # \cdots \cdot |
88 | 88 | # \frac{\partial \mathbf{f}_1}{\partial \mathbf{x}} |
89 | 89 | # |
90 | | -# .. figure:: /_static/img/understanding_leaf_vs_nonleaf/comp-graph-1.png |
91 | | -# :alt: Computational graph after forward pass |
92 | | -# |
93 | | -# Computational graph after forward pass |
94 | | -# |
| 90 | +# .. mermaid:: |
| 91 | +# |
| 92 | +# graph TD |
| 93 | +# |
| 94 | +# x["x<br/>is_leaf=True<br/>requires_grad=False<br/>retains_grad=False<br/>grad=None"] |
| 95 | +# W["W<br/>is_leaf=True<br/>requires_grad=True<br/>retains_grad=False<br/>grad=None"] |
| 96 | +# b["b<br/>is_leaf=True<br/>requires_grad=True<br/>retains_grad=False<br/>grad=None"] |
| 97 | +# matmul["x @ W"] |
| 98 | +# z["z = x @ W + b<br/>is_leaf=False<br/>requires_grad=True<br/>retains_grad=False<br/>grad=None"] |
| 99 | +# relu["y_pred = relu(z)<br/>is_leaf=False<br/>requires_grad=True<br/>retains_grad=False<br/>grad=None"] |
| 100 | +# y["y<br/>is_leaf=True<br/>requires_grad=False<br/>retains_grad=False<br/>grad=None"] |
| 101 | +# loss["loss = mse(y_pred, y)<br/>is_leaf=False<br/>requires_grad=True<br/>retains_grad=False<br/>grad=None"] |
| 102 | +# |
| 103 | +# x --> matmul |
| 104 | +# W --> matmul |
| 105 | +# matmul --> z |
| 106 | +# b --> z |
| 107 | +# z --> relu |
| 108 | +# relu --> loss |
| 109 | +# y --> loss |
| | +# |
95 | 110 | # PyTorch considers a node to be a *leaf* if it is not the result of a |
96 | 111 | # tensor operation with at least one input having ``requires_grad=True`` |
97 | 112 | # (e.g. ``x``, ``W``, ``b``, and ``y``), and everything else to be |
|
260 | 275 | # convention, this attribute will print ``False`` for any leaf node, even |
261 | 276 | # if it requires its gradient. |
262 | 277 | # |
263 | | -# .. figure:: /_static/img/understanding_leaf_vs_nonleaf/comp-graph-2.png |
264 | | -# :alt: Computational graph after backward pass |
265 | | -# |
266 | | -# Computational graph after backward pass |
267 | | -# |
| 278 | +# .. mermaid:: |
| 279 | +# |
| 280 | +# graph TD |
| 281 | +# |
| 282 | +# x["x<br/>is_leaf=True<br/>requires_grad=False<br/>retains_grad=False<br/>grad=None"] |
| 283 | +# W["W<br/>is_leaf=True<br/>requires_grad=True<br/>retains_grad=False<br/>grad=torch.Tensor"] |
| 284 | +# b["b<br/>is_leaf=True<br/>requires_grad=True<br/>retains_grad=False<br/>grad=torch.Tensor"] |
| 285 | +# matmul["x @ W"] |
| 286 | +# z["z = x @ W + b<br/>is_leaf=False<br/>requires_grad=True<br/>retains_grad=True<br/>grad=torch.Tensor"] |
| 287 | +# relu["y_pred = relu(z)<br/>is_leaf=False<br/>requires_grad=True<br/>retains_grad=True<br/>grad=torch.Tensor"] |
| 288 | +# y["y<br/>is_leaf=True<br/>requires_grad=False<br/>retains_grad=False<br/>grad=None"] |
| 289 | +# loss["loss = mse(y_pred, y)<br/>is_leaf=False<br/>requires_grad=True<br/>retains_grad=True<br/>grad=torch.Tensor"] |
| 290 | +# |
| 291 | +# x --> matmul |
| 292 | +# W --> matmul |
| 293 | +# matmul --> z |
| 294 | +# b --> z |
| 295 | +# z --> relu |
| 296 | +# relu --> loss |
| 297 | +# y --> loss |
| | +# |
268 | 298 | # If you call ``retain_grad()`` on a leaf tensor, it results in a no-op |
269 | 299 | # since leaf tensors already retain their gradients by default (when |
270 | 300 | # ``requires_grad=True``). |
|
0 commit comments