|
11 | 11 | from tqdm import tqdm |
12 | 12 |
|
13 | 13 | from bergson.collector.collector import HookCollectorBase |
14 | | -from bergson.hessians.sharded_computation import ShardedMul |
| 14 | +from bergson.hessians.sharded_computation import ShardedMul, shard_bounds |
15 | 15 | from bergson.utils.logger import get_logger |
16 | 16 | from bergson.utils.utils import ( |
17 | 17 | assert_type, |
@@ -115,6 +115,11 @@ def forward_hook(self, module: nn.Module, a: Tensor) -> None: |
115 | 115 | name = assert_type(str, module._name) |
116 | 116 | # a shape: [N, S, I] |
117 | 117 |
|
| 118 | + # Augment with a ones column to match the [I+1, I+1] activation |
| 119 | + # covariance eigenvectors computed when the bias gradient is collected. |
| 120 | + if module._collect_bias: |
| 121 | + a = torch.cat([a, a.new_ones(*a.shape[:-1], 1)], dim=-1) # [N, S, I+1] |
| 122 | + |
118 | 123 | # Transform: a @ eigen_a |
119 | 124 | transformed = self.shard_computer._matmul( |
120 | 125 | vector_nsa=a, matrix_cb=self.eigen_a[name] |
@@ -147,9 +152,9 @@ def backward_hook(self, module: nn.Module, g: Tensor) -> None: |
147 | 152 | dist.all_reduce(transformed_grad_shard, op=dist.ReduceOp.SUM) |
148 | 153 |
|
149 | 154 | # Extract our shard |
150 | | - shard_size = transformed_grad_shard.shape[0] // self.world_size |
151 | | - start_row = self.rank * shard_size |
152 | | - end_row = (self.rank + 1) * shard_size |
| 155 | + start_row, end_row = self.shard_computer.shard_bounds( |
| 156 | + transformed_grad_shard.shape[0] |
| 157 | + ) |
153 | 158 |
|
154 | 159 | # Accumulate (with CPU offloading for memory efficiency) |
155 | 160 | if name not in self.eigenvalue_corrections: |
@@ -236,7 +241,7 @@ def compute_eigendecomposition( |
236 | 241 | total_processed: Number of samples used to compute covariance. |
237 | 242 |
|
238 | 243 | Returns: |
239 | | - Per-key eigenvalue shards (each `[m/world_size]`) on CPU. The |
| 244 | + Per-key eigenvalue shards (rows per shard_bounds) on CPU. The |
240 | 245 | eigenvectors are written to disk; the eigenvalues are returned so |
241 | 246 | callers (e.g. `save_uncorrected_eigenvalues`) can use them without |
242 | 247 | reloading. |
@@ -371,14 +376,27 @@ def save_uncorrected_eigenvalues( |
371 | 376 | eigenvalue_a_shard = eigenvalues_a[key].to(device) |
372 | 377 |
|
373 | 378 | if world_size > 1: |
| 379 | + # Shards may be uneven, so sum the shard sizes to get the full dimension |
| 380 | + # then broadcast each rank's shard into place. |
| 381 | + full_dim = torch.tensor(eigenvalue_a_shard.shape[0], device=device) |
| 382 | + dist.all_reduce(full_dim, op=dist.ReduceOp.SUM) |
| 383 | + m = int(full_dim.item()) |
| 384 | + |
374 | 385 | eigenvalue_a_full = torch.empty( |
375 | | - eigenvalue_a_shard.shape[0] * world_size, |
376 | | - device=device, |
377 | | - dtype=eigenvalue_a_shard.dtype, |
378 | | - ) |
379 | | - dist.all_gather_into_tensor( |
380 | | - eigenvalue_a_full, eigenvalue_a_shard.contiguous() |
| 386 | + m, device=device, dtype=eigenvalue_a_shard.dtype |
381 | 387 | ) |
| 388 | + for rank_index in range(world_size): |
| 389 | + start_row, end_row = shard_bounds(m, rank_index, world_size) |
| 390 | + if rank_index == rank: |
| 391 | + shard = eigenvalue_a_shard.contiguous() |
| 392 | + else: |
| 393 | + shard = torch.empty( |
| 394 | + end_row - start_row, |
| 395 | + device=device, |
| 396 | + dtype=eigenvalue_a_shard.dtype, |
| 397 | + ) |
| 398 | + dist.broadcast(shard, src=rank_index) |
| 399 | + eigenvalue_a_full[start_row:end_row] = shard |
382 | 400 | else: |
383 | 401 | eigenvalue_a_full = eigenvalue_a_shard |
384 | 402 |
|
@@ -418,9 +436,8 @@ def _gather_and_shard_along_dim_0( |
418 | 436 |
|
419 | 437 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) |
420 | 438 |
|
421 | | - m = full_shape[0] |
422 | | - shard_size = m // world_size |
423 | | - shard = tensor[rank * shard_size : (rank + 1) * shard_size].contiguous() |
| 439 | + start_row, end_row = shard_bounds(full_shape[0], rank, world_size) |
| 440 | + shard = tensor[start_row:end_row].contiguous() |
424 | 441 | result_dict[key] = shard.to(device="cpu") |
425 | 442 |
|
426 | 443 | del tensor |
|
0 commit comments