File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -98,32 +98,9 @@ def __init__(
9898 self .norm_eps = norm_eps
9999 self .reg_eps = reg_eps
100100 self .solver = solver
101- # Cache for storing computed weights
102- self ._cache = {}
103101
104102 def forward (self , matrix : Tensor ) -> Tensor :
105- # Convert matrix to tuple for hashing
106- with torch .no_grad (): # No need to track gradients for caching
107- matrix_key = hash (matrix .cpu ().numpy ().tobytes ())
108-
109- # Check if we have cached result
110- if matrix_key in self ._cache :
111- return self ._cache [matrix_key ]
112-
113- # Compute weights once and reuse
114-
115- # Original computation optimized
116-
117- # Move computations to same device as input
118- U = torch .zeros ([matrix .shape [0 ], matrix .shape [0 ]])
119-
120- # Compute G and W in a single batch operation if possible
103+ U = torch .diag (self .weighting (matrix ))
121104 G = _compute_regularized_normalized_gramian (matrix , self .norm_eps , self .reg_eps )
122105 W = _project_weights (U , G , self .solver )
123-
124- # Use more efficient sum
125- result = W .sum (dim = 0 )
126-
127- # Cache the result
128- self ._cache [matrix_key ] = result
129- return result
106+ return torch .sum (W , dim = 0 )
You can’t perform that action at this time.
0 commit comments