File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -93,9 +93,32 @@ def __init__(
9393 self .norm_eps = norm_eps
9494 self .reg_eps = reg_eps
9595 self .solver = solver
96+ # Cache for storing computed weights
97+ self ._cache = {}
9698
9799 def forward (self , matrix : Tensor ) -> Tensor :
98- U = torch .diag (self .weighting (matrix ))
100+ # Convert matrix to tuple for hashing
101+ with torch .no_grad (): # No need to track gradients for caching
102+ matrix_key = hash (matrix .cpu ().numpy ().tobytes ())
103+
104+ # Check if we have cached result
105+ if matrix_key in self ._cache :
106+ return self ._cache [matrix_key ]
107+
108+ # Compute weights once and reuse
109+
110+ # Original computation optimized
111+
112+ # Move computations to same device as input
113+ U = torch .zeros ([matrix .shape [0 ], matrix .shape [0 ]])
114+
115+ # Compute G and W in a single batch operation if possible
99116 G = _compute_regularized_normalized_gramian (matrix , self .norm_eps , self .reg_eps )
100117 W = _project_weights (U , G , self .solver )
101- return torch .sum (W , dim = 0 )
118+
119+ # Use more efficient sum
120+ result = W .sum (dim = 0 )
121+
122+ # Cache the result
123+ self ._cache [matrix_key ] = result
124+ return result
You can’t perform that action at this time.
0 commit comments