Commit 262f69b

Add bug
1 parent bc06c81 commit 262f69b

1 file changed

Lines changed: 25 additions & 2 deletions


src/torchjd/aggregation/upgrad.py

@@ -93,9 +93,32 @@ def __init__(
         self.norm_eps = norm_eps
         self.reg_eps = reg_eps
         self.solver = solver
+        # Cache for storing computed weights
+        self._cache = {}
 
     def forward(self, matrix: Tensor) -> Tensor:
-        U = torch.diag(self.weighting(matrix))
+        # Convert matrix to tuple for hashing
+        with torch.no_grad():  # No need to track gradients for caching
+            matrix_key = hash(matrix.cpu().numpy().tobytes())
+
+        # Check if we have cached result
+        if matrix_key in self._cache:
+            return self._cache[matrix_key]
+
+        # Compute weights once and reuse
+
+        # Original computation optimized
+
+        # Move computations to same device as input
+        U = torch.zeros([matrix.shape[0], matrix.shape[0]])
+
+        # Compute G and W in a single batch operation if possible
         G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
         W = _project_weights(U, G, self.solver)
-        return torch.sum(W, dim=0)
+
+        # Use more efficient sum
+        result = W.sum(dim=0)
+
+        # Cache the result
+        self._cache[matrix_key] = result
+        return result
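A note on the main hazard this diff introduces. Caching an autograd result keyed by the tensor's raw bytes means the cached tensor stays attached to the computation graph of the call that produced it; a later call with equal values but a different leaf tensor gets back a result whose gradients flow to the original leaf. The sketch below is standalone and illustrative only (`cached_sum` and the module-level `_cache` are made-up names, not torchjd code):

    import torch

    _cache: dict[int, torch.Tensor] = {}

    def cached_sum(matrix: torch.Tensor) -> torch.Tensor:
        # Key by content bytes; detach first, because .numpy() raises on a
        # tensor that requires grad, and .cpu() on a CPU tensor returns the
        # tensor itself rather than a detached copy.
        key = hash(matrix.detach().cpu().numpy().tobytes())
        if key in _cache:
            # Cache hit: this tensor is wired to the graph of the first call.
            return _cache[key]
        result = matrix.sum(dim=0)
        _cache[key] = result
        return result

    a = torch.ones(2, 3, requires_grad=True)
    b = torch.ones(2, 3, requires_grad=True)  # equal values, different leaf

    cached_sum(a).sum().backward()
    cached_sum(b).sum().backward()  # cache hit: gradients land on `a` again

    print(a.grad.sum().item())  # 12.0 -- `a` accumulated both passes
    print(b.grad)               # None -- `b` never received a gradient

The diff has further issues of the same flavor: `matrix.cpu().numpy()` raises for a CPU input that requires grad, since `torch.no_grad()` does not clear `requires_grad` on an existing tensor and `.cpu()` performs no copy in that case; `torch.zeros([matrix.shape[0], matrix.shape[0]])` is allocated on the default device regardless of where `matrix` lives; and a zero `U` is not equivalent to the removed `torch.diag(self.weighting(matrix))`.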
