We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 0274fef commit d5a82b7Copy full SHA for d5a82b7
tests/utils/forward_backwards.py
@@ -139,9 +139,10 @@ def get_vjp(grad_outputs: Tensor) -> list[Tensor]:
139
140
jacobians = vmap(get_vjp)(torch.diag(torch.ones_like(output)))
141
jacobian_matrices = [jacobian.reshape([jacobian.shape[0], -1]) for jacobian in jacobians]
142
- gramian = sum([jacobian @ jacobian.T for jacobian in jacobian_matrices])
+ products = [jacobian @ jacobian.T for jacobian in jacobian_matrices]
143
+ gramian = torch.stack(products).sum(dim=0)
144
- return gramian
145
+ return PSDTensor(gramian)
146
147
148
class CloneParams:
0 commit comments