@@ -98,9 +98,32 @@ def __init__(
         self.norm_eps = norm_eps
         self.reg_eps = reg_eps
         self.solver = solver
+        # Cache for storing computed weights
+        self._cache = {}
 
     def forward(self, matrix: Tensor) -> Tensor:
-        U = torch.diag(self.weighting(matrix))
+        # Convert the matrix to a hashable cache key; detach so .numpy() also works when the input requires grad
+        with torch.no_grad():  # no gradient tracking is needed for the cache lookup
+            matrix_key = hash(matrix.detach().cpu().numpy().tobytes())
+
+        # Return the cached result if this matrix has been seen before
+        if matrix_key in self._cache:
+            return self._cache[matrix_key]
+
+        # Cache miss: compute the weights once and store the result below
+
+        # Allocate the weighting matrix U on the same device and dtype as the input
+        U = torch.zeros([matrix.shape[0], matrix.shape[0]], device=matrix.device, dtype=matrix.dtype)
+
+        # Compute the regularized, normalized Gramian and project the weights onto it
         G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
         W = _project_weights(U, G, self.solver)
-        return torch.sum(W, dim=0)
+
+        result = W.sum(dim=0)
+
+        # Cache the result so repeated calls with the same matrix skip the recomputation
+        self._cache[matrix_key] = result
+        return result
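For context, the change above keys the cache on a hash of the input tensor's raw bytes, so repeated calls with numerically identical matrices reuse the stored output. The snippet below is a minimal standalone sketch of that pattern under the same assumptions; `cached_forward` and the toy `expensive_solve` are illustrative names, not part of this repository's API.

```python
import torch

_cache: dict[int, torch.Tensor] = {}

def expensive_solve(matrix: torch.Tensor) -> torch.Tensor:
    # Stand-in for the Gramian computation and weight projection in the diff above.
    return (matrix @ matrix.T).sum(dim=0)

def cached_forward(matrix: torch.Tensor) -> torch.Tensor:
    # Key the cache on the tensor's raw bytes, as in the diff above.
    with torch.no_grad():
        key = hash(matrix.detach().cpu().numpy().tobytes())
    if key in _cache:
        return _cache[key]
    result = expensive_solve(matrix)
    _cache[key] = result
    return result

m = torch.randn(4, 3)
out1 = cached_forward(m)
out2 = cached_forward(m.clone())   # same bytes -> cache hit, same stored tensor
out3 = cached_forward(m + 1.0)     # different values -> recomputed, new entry
print(torch.equal(out1, out2), len(_cache))  # True 2
```

One consequence of this design is that a cache hit returns the previously stored tensor, so it is detached from the current call's autograd graph; if gradients through `forward` are needed, the cached path would have to be skipped or handled separately.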