Skip to content

Commit d5a82b7

Browse files
committed
Fix type error in compute_gramian_with_autograd
1 parent 0274fef commit d5a82b7

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tests/utils/forward_backwards.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,10 @@ def get_vjp(grad_outputs: Tensor) -> list[Tensor]:
139139

140140
jacobians = vmap(get_vjp)(torch.diag(torch.ones_like(output)))
141141
jacobian_matrices = [jacobian.reshape([jacobian.shape[0], -1]) for jacobian in jacobians]
142-
gramian = sum([jacobian @ jacobian.T for jacobian in jacobian_matrices])
142+
products = [jacobian @ jacobian.T for jacobian in jacobian_matrices]
143+
gramian = torch.stack(products).sum(dim=0)
143144

144-
return gramian
145+
return PSDTensor(gramian)
145146

146147

147148
class CloneParams:

0 commit comments

Comments
 (0)