From 5b767b32f71ead82420053c6538c2b322c1f2a54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 25 Mar 2025 11:44:45 +0100 Subject: [PATCH 1/2] Add bug --- src/torchjd/aggregation/upgrad.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/src/torchjd/aggregation/upgrad.py b/src/torchjd/aggregation/upgrad.py index fdc83945f..c6c8d62ef 100644 --- a/src/torchjd/aggregation/upgrad.py +++ b/src/torchjd/aggregation/upgrad.py @@ -98,9 +98,32 @@ def __init__( self.norm_eps = norm_eps self.reg_eps = reg_eps self.solver = solver + # Cache for storing computed weights + self._cache = {} def forward(self, matrix: Tensor) -> Tensor: - U = torch.diag(self.weighting(matrix)) + # Convert matrix to tuple for hashing + with torch.no_grad(): # No need to track gradients for caching + matrix_key = hash(matrix.cpu().numpy().tobytes()) + + # Check if we have cached result + if matrix_key in self._cache: + return self._cache[matrix_key] + + # Compute weights once and reuse + + # Original computation optimized + + # Move computations to same device as input + U = torch.zeros([matrix.shape[0], matrix.shape[0]]) + + # Compute G and W in a single batch operation if possible G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps) W = _project_weights(U, G, self.solver) - return torch.sum(W, dim=0) + + # Use more efficient sum + result = W.sum(dim=0) + + # Cache the result + self._cache[matrix_key] = result + return result From 2c84c459c6985cbce971f85550d7f6a320270e4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 25 Mar 2025 11:55:09 +0100 Subject: [PATCH 2/2] Revert "Add bug" This reverts commit 5b767b32f71ead82420053c6538c2b322c1f2a54. --- src/torchjd/aggregation/upgrad.py | 27 ++------------------------- 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/src/torchjd/aggregation/upgrad.py b/src/torchjd/aggregation/upgrad.py index c6c8d62ef..fdc83945f 100644 --- a/src/torchjd/aggregation/upgrad.py +++ b/src/torchjd/aggregation/upgrad.py @@ -98,32 +98,9 @@ def __init__( self.norm_eps = norm_eps self.reg_eps = reg_eps self.solver = solver - # Cache for storing computed weights - self._cache = {} def forward(self, matrix: Tensor) -> Tensor: - # Convert matrix to tuple for hashing - with torch.no_grad(): # No need to track gradients for caching - matrix_key = hash(matrix.cpu().numpy().tobytes()) - - # Check if we have cached result - if matrix_key in self._cache: - return self._cache[matrix_key] - - # Compute weights once and reuse - - # Original computation optimized - - # Move computations to same device as input - U = torch.zeros([matrix.shape[0], matrix.shape[0]]) - - # Compute G and W in a single batch operation if possible + U = torch.diag(self.weighting(matrix)) G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps) W = _project_weights(U, G, self.solver) - - # Use more efficient sum - result = W.sum(dim=0) - - # Cache the result - self._cache[matrix_key] = result - return result + return torch.sum(W, dim=0)