Skip to content

Commit 0a9fd42

Browse files
committed
Make dependence on gramian explicit in Krum
1 parent 6afcc92 commit 0a9fd42

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

src/torchjd/aggregation/krum.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,16 +80,24 @@ def __init__(self, n_byzantine: int, n_selected: int):
8080

8181
def forward(self, matrix: Tensor) -> Tensor:
    """
    Compute the aggregation weights for ``matrix``.

    The input is first validated with ``_check_matrix_shape``. The weights are then derived
    exclusively from the Gramian ``matrix @ matrix.T``, which makes the dependence on the
    Gramian explicit.

    :param matrix: The matrix to compute weights for (presumably 2-D, one gradient per row;
        the exact requirements are enforced by ``_check_matrix_shape``).
    :return: Whatever ``_compute_from_gramian`` returns for the Gramian of ``matrix``.
    """
    self._check_matrix_shape(matrix)
    # Only pairwise inner products of the rows are needed downstream.
    return self._compute_from_gramian(torch.matmul(matrix, matrix.T))
8385

84-
distances = torch.cdist(matrix, matrix, compute_mode="donot_use_mm_for_euclid_dist")
85-
n_closest = matrix.shape[0] - self.n_byzantine - 2
86+
def _compute_from_gramian(self, gramian: Tensor) -> Tensor:
    """
    Compute the Krum weights from the Gramian of the matrix to aggregate.

    Each row gets a score equal to the sum of its distances to its
    ``n - n_byzantine - 2`` closest other rows, where ``n`` is the number of rows. The
    ``n_selected`` rows with the lowest scores share the weight uniformly.

    :param gramian: Square matrix of pairwise inner products between the rows, i.e.
        ``matrix @ matrix.T``.
    :return: Vector of length ``n`` holding ``1 / n_selected`` at the selected indices and
        ``0`` elsewhere, in the dtype of ``gramian``.
    """
    n_rows = gramian.shape[0]

    # ||g_i - g_j||^2 = ||g_i||^2 + ||g_j||^2 - 2 <g_i, g_j>; the squared norms are the
    # diagonal of the Gramian.
    gradient_norms_squared = torch.diagonal(gramian)
    distances_squared = (
        gradient_norms_squared.unsqueeze(0) + gradient_norms_squared.unsqueeze(1) - 2 * gramian
    )
    # Round-off can make this expansion slightly negative for (nearly) identical rows, and
    # sqrt of a negative value is NaN, which would corrupt the scores below. Clamp at 0.
    distances = torch.sqrt(torch.clamp(distances_squared, min=0))

    n_closest = n_rows - self.n_byzantine - 2
    # Take one extra neighbour: the smallest distance of each row is the one to itself (0),
    # which is discarded right after.
    smallest_distances, _ = torch.topk(distances, k=n_closest + 1, largest=False)
    smallest_distances_excluding_self = smallest_distances[:, 1:]
    scores = smallest_distances_excluding_self.sum(dim=1)

    _, selected_indices = torch.topk(scores, k=self.n_selected, largest=False)
    one_hot_selected_indices = F.one_hot(selected_indices, num_classes=n_rows)
    weights = one_hot_selected_indices.sum(dim=0).to(dtype=gramian.dtype) / self.n_selected

    return weights
95103

0 commit comments

Comments
 (0)