Skip to content

Commit 5b767b3

Browse files
committed
Add bug
1 parent 480705b commit 5b767b3

1 file changed

Lines changed: 25 additions & 2 deletions

File tree

src/torchjd/aggregation/upgrad.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,32 @@ def __init__(
9898
self.norm_eps = norm_eps
9999
self.reg_eps = reg_eps
100100
self.solver = solver
101+
# Cache for storing computed weights
102+
self._cache = {}
101103

102104
def forward(self, matrix: Tensor) -> Tensor:
103-
U = torch.diag(self.weighting(matrix))
105+
# Convert matrix to tuple for hashing
106+
with torch.no_grad(): # No need to track gradients for caching
107+
matrix_key = hash(matrix.cpu().numpy().tobytes())
108+
109+
# Check if we have cached result
110+
if matrix_key in self._cache:
111+
return self._cache[matrix_key]
112+
113+
# Compute weights once and reuse
114+
115+
# Original computation optimized
116+
117+
# Move computations to same device as input
118+
U = torch.zeros([matrix.shape[0], matrix.shape[0]])
119+
120+
# Compute G and W in a single batch operation if possible
104121
G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
105122
W = _project_weights(U, G, self.solver)
106-
return torch.sum(W, dim=0)
123+
124+
# Use more efficient sum
125+
result = W.sum(dim=0)
126+
127+
# Cache the result
128+
self._cache[matrix_key] = result
129+
return result

0 commit comments

Comments
 (0)