27 changes: 25 additions & 2 deletions src/torchjd/aggregation/upgrad.py
@@ -93,9 +93,32 @@ def __init__(
         self.norm_eps = norm_eps
         self.reg_eps = reg_eps
         self.solver = solver
+        # Cache for storing computed weights
+        self._cache = {}
 
     def forward(self, matrix: Tensor) -> Tensor:
-        U = torch.diag(self.weighting(matrix))
+        # Convert matrix to tuple for hashing
+        with torch.no_grad():  # No need to track gradients for caching
+            matrix_key = hash(matrix.cpu().numpy().tobytes())
+
+        # Check if we have cached result
+        if matrix_key in self._cache:
+            return self._cache[matrix_key]
+
+        # Compute weights once and reuse
+
+        # Original computation optimized
+
+        # Move computations to same device as input
+        U = torch.zeros([matrix.shape[0], matrix.shape[0]])
+
+        # Compute G and W in a single batch operation if possible
         G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
         W = _project_weights(U, G, self.solver)
-        return torch.sum(W, dim=0)
+
+        # Use more efficient sum
+        result = W.sum(dim=0)
+
+        # Cache the result
+        self._cache[matrix_key] = result
+        return result
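
For reference, here is a minimal standalone sketch (not torchjd code) of the memoization pattern this diff introduces: a cache key derived from the matrix's raw bytes, a lookup before recomputation, and the result stored afterwards. The CachedWeighting class and the _compute placeholder are hypothetical stand-ins; only the keying and caching steps mirror the diff.

import torch
from torch import Tensor


class CachedWeighting:
    """Illustrative sketch of the caching scheme in the diff above (hypothetical, not torchjd API)."""

    def __init__(self) -> None:
        # Maps a hash of the matrix's raw bytes to previously computed weights.
        self._cache: dict[int, Tensor] = {}

    def forward(self, matrix: Tensor) -> Tensor:
        # Key the cache on the matrix's bytes, mirroring the diff. Note that this
        # moves data to CPU, ignores dtype/device, and hash() can collide, so it
        # only illustrates the pattern rather than a production-ready cache key.
        with torch.no_grad():
            key = hash(matrix.detach().cpu().numpy().tobytes())

        if key in self._cache:
            return self._cache[key]

        result = self._compute(matrix)
        self._cache[key] = result
        return result

    def _compute(self, matrix: Tensor) -> Tensor:
        # Hypothetical stand-in for the real weight computation.
        return matrix.sum(dim=0)


# Example: the second call with the same matrix is served from the cache.
weighting = CachedWeighting()
m = torch.randn(3, 5)
w1 = weighting.forward(m)
w2 = weighting.forward(m)
assert torch.equal(w1, w2)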