@@ -80,16 +80,24 @@ def __init__(self, n_byzantine: int, n_selected: int):
8080
8181 def forward (self , matrix : Tensor ) -> Tensor :
8282 self ._check_matrix_shape (matrix )
83+ gramian = matrix @ matrix .T
84+ return self ._compute_from_gramian (gramian )
8385
84- distances = torch .cdist (matrix , matrix , compute_mode = "donot_use_mm_for_euclid_dist" )
85- n_closest = matrix .shape [0 ] - self .n_byzantine - 2
86+ def _compute_from_gramian (self , gramian : Tensor ) -> Tensor :
87+ gradient_norms_squared = torch .diagonal (gramian )
88+ distances_squared = (
89+ gradient_norms_squared .unsqueeze (0 ) + gradient_norms_squared .unsqueeze (1 ) - 2 * gramian
90+ )
91+ distances = torch .sqrt (distances_squared )
92+
93+ n_closest = gramian .shape [0 ] - self .n_byzantine - 2
8694 smallest_distances , _ = torch .topk (distances , k = n_closest + 1 , largest = False )
8795 smallest_distances_excluding_self = smallest_distances [:, 1 :]
8896 scores = smallest_distances_excluding_self .sum (dim = 1 )
8997
9098 _ , selected_indices = torch .topk (scores , k = self .n_selected , largest = False )
91- one_hot_selected_indices = F .one_hot (selected_indices , num_classes = matrix .shape [0 ])
92- weights = one_hot_selected_indices .sum (dim = 0 ).to (dtype = matrix .dtype ) / self .n_selected
99+ one_hot_selected_indices = F .one_hot (selected_indices , num_classes = gramian .shape [0 ])
100+ weights = one_hot_selected_indices .sum (dim = 0 ).to (dtype = gramian .dtype ) / self .n_selected
93101
94102 return weights
95103
0 commit comments