|
2 | 2 | from torch import Tensor |
3 | 3 | from torch.nn import functional as F |
4 | 4 |
|
| 5 | +from ._gramian_utils import compute_gramian |
5 | 6 | from .bases import _WeightedAggregator, _Weighting |
6 | 7 |
|
7 | 8 |
|
@@ -80,16 +81,24 @@ def __init__(self, n_byzantine: int, n_selected: int): |
80 | 81 |
|
def forward(self, matrix: Tensor) -> Tensor:
    """
    Compute the weights associated with ``matrix``.

    The shape of ``matrix`` is validated first; the weights are then derived solely from the
    gramian returned by ``compute_gramian(matrix)``.

    :param matrix: The matrix for which weights are computed.
    :return: The weights returned by ``_compute_from_gramian`` for that gramian.
    """
    self._check_matrix_shape(matrix)
    return self._compute_from_gramian(compute_gramian(matrix))
def _compute_from_gramian(self, gramian: Tensor) -> Tensor:
    """
    Compute selection weights from a square gramian.

    Pairwise distances are recovered from the gramian through the identity
    ``||a - b||^2 = <a, a> + <b, b> - 2 <a, b>``. Each row is then scored by the sum of its
    ``n - n_byzantine - 2`` smallest distances to other rows, and the ``n_selected`` rows with
    the lowest scores share the weight uniformly.

    :param gramian: Square matrix of pairwise inner products (presumably
        ``matrix @ matrix.T`` as produced by ``compute_gramian`` -- confirm against that helper).
    :return: A 1-D tensor with one entry per row of ``gramian``, equal to ``1 / n_selected``
        for the selected rows and ``0`` elsewhere, in the dtype of ``gramian``.
    """
    gradient_norms_squared = torch.diagonal(gramian)
    distances_squared = (
        gradient_norms_squared.unsqueeze(0) + gradient_norms_squared.unsqueeze(1) - 2 * gramian
    )
    # Floating-point cancellation can make a squared distance slightly negative for
    # (nearly) identical rows; without the clamp, sqrt would yield NaN and corrupt the ranking.
    distances = torch.sqrt(torch.clamp(distances_squared, min=0.0))

    n_closest = gramian.shape[0] - self.n_byzantine - 2
    # Take one extra neighbour: the smallest distance of each row is the row's distance to
    # itself (zero), which is dropped right after.
    smallest_distances, _ = torch.topk(distances, k=n_closest + 1, largest=False)
    smallest_distances_excluding_self = smallest_distances[:, 1:]
    scores = smallest_distances_excluding_self.sum(dim=1)

    _, selected_indices = torch.topk(scores, k=self.n_selected, largest=False)
    # topk returns distinct indices, so summing the one-hot rows gives a 0/1 selection mask.
    one_hot_selected_indices = F.one_hot(selected_indices, num_classes=gramian.shape[0])
    weights = one_hot_selected_indices.sum(dim=0).to(dtype=gramian.dtype) / self.n_selected

    return weights
95 | 104 |
|
|
0 commit comments