diff --git a/src/torchjd/aggregation/upgrad.py b/src/torchjd/aggregation/upgrad.py
index 31892c105..59bae9c0e 100644
--- a/src/torchjd/aggregation/upgrad.py
+++ b/src/torchjd/aggregation/upgrad.py
@@ -93,9 +93,30 @@ def __init__(
         self.norm_eps = norm_eps
         self.reg_eps = reg_eps
         self.solver = solver
+        # Cache of forward() results, keyed on the exact matrix content.
+        # NOTE(review): unbounded — acceptable only if the set of distinct
+        # matrices seen per instance is small; confirm with callers.
+        self._cache: dict = {}
 
     def forward(self, matrix: Tensor) -> Tensor:
-        U = torch.diag(self.weighting(matrix))
+        # A cached tensor is detached from the current autograd graph, so the
+        # cache may only be consulted when no gradient has to flow through
+        # the result.
+        use_cache = not (matrix.requires_grad and torch.is_grad_enabled())
+        if use_cache:
+            # Key on the raw bytes (not hash(bytes), which can collide)
+            # together with shape/dtype/device, so matrices with identical
+            # bytes but different metadata never share an entry.
+            key = (
+                tuple(matrix.shape),
+                str(matrix.dtype),
+                str(matrix.device),
+                matrix.detach().cpu().numpy().tobytes(),
+            )
+            if key in self._cache:
+                return self._cache[key]
+        # Original computation, unchanged: the weighting MUST be applied via
+        # U — replacing it with zeros would null the projection entirely.
+        U = torch.diag(self.weighting(matrix))
         G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
         W = _project_weights(U, G, self.solver)
-        return torch.sum(W, dim=0)
+        result = torch.sum(W, dim=0)
+        if use_cache:
+            self._cache[key] = result
+        return result