
Commit 2c84c45

Revert "Add bug"
This reverts commit 5b767b3.
1 parent: 5b767b3 · commit: 2c84c45

1 file changed: src/torchjd/aggregation/upgrad.py
2 additions & 25 deletions
@@ -98,32 +98,9 @@ def __init__(
         self.norm_eps = norm_eps
         self.reg_eps = reg_eps
         self.solver = solver
-        # Cache for storing computed weights
-        self._cache = {}

     def forward(self, matrix: Tensor) -> Tensor:
-        # Convert matrix to tuple for hashing
-        with torch.no_grad():  # No need to track gradients for caching
-            matrix_key = hash(matrix.cpu().numpy().tobytes())
-
-        # Check if we have cached result
-        if matrix_key in self._cache:
-            return self._cache[matrix_key]
-
-        # Compute weights once and reuse
-
-        # Original computation optimized
-
-        # Move computations to same device as input
-        U = torch.zeros([matrix.shape[0], matrix.shape[0]])
-
-        # Compute G and W in a single batch operation if possible
+        U = torch.diag(self.weighting(matrix))
         G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
         W = _project_weights(U, G, self.solver)
-
-        # Use more efficient sum
-        result = W.sum(dim=0)
-
-        # Cache the result
-        self._cache[matrix_key] = result
-        return result
+        return torch.sum(W, dim=0)
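
For context on why this was reverted: the "Add bug" change replaced the weighting-derived diagonal matrix U with torch.zeros, which discarded the computed weights, and it cached forward outputs keyed on the input tensor's raw bytes, so a cache hit returned a tensor tied to an earlier computation graph. Below is a minimal usage sketch of the restored behavior, assuming torchjd's public UPGrad aggregator; the example Jacobian values are illustrative and not taken from this commit.

# Minimal sketch (not part of the commit): aggregating a small Jacobian
# with the restored UpGrad forward. Assumes torchjd's public API
# (torchjd.aggregation.UPGrad); the input values are made up.
import torch
from torchjd.aggregation import UPGrad

aggregator = UPGrad()

# Each row is the gradient of one objective; the rows conflict in dim 0.
jacobian = torch.tensor([[-4.0, 1.0, 1.0],
                         [ 6.0, 1.0, 1.0]])

# forward() builds U = diag(weighting(matrix)), projects it against the
# regularized normalized Gramian, and sums the projected weight rows.
aggregated = aggregator(jacobian)
print(aggregated)  # a single update direction combining both rows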
