Skip to content

Commit 9444e59

Browse files
luciaquirke and claude committed
Fix K-FAC covariance shapes when include_bias=True (#277)
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 23fc6e1 commit 9444e59

9 files changed

Lines changed: 385 additions & 66 deletions

File tree

bergson/hessians/eigenvectors.py

Lines changed: 31 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
from tqdm import tqdm
1212

1313
from bergson.collector.collector import HookCollectorBase
14-
from bergson.hessians.sharded_computation import ShardedMul
14+
from bergson.hessians.sharded_computation import ShardedMul, shard_bounds
1515
from bergson.utils.logger import get_logger
1616
from bergson.utils.utils import (
1717
assert_type,
@@ -115,6 +115,11 @@ def forward_hook(self, module: nn.Module, a: Tensor) -> None:
115115
name = assert_type(str, module._name)
116116
# a shape: [N, S, I]
117117

118+
# Augment with a ones column to match the [I+1, I+1] activation
119+
# covariance eigenvectors computed when the bias gradient is collected.
120+
if module._collect_bias:
121+
a = torch.cat([a, a.new_ones(*a.shape[:-1], 1)], dim=-1) # [N, S, I+1]
122+
118123
# Transform: a @ eigen_a
119124
transformed = self.shard_computer._matmul(
120125
vector_nsa=a, matrix_cb=self.eigen_a[name]
@@ -147,9 +152,9 @@ def backward_hook(self, module: nn.Module, g: Tensor) -> None:
147152
dist.all_reduce(transformed_grad_shard, op=dist.ReduceOp.SUM)
148153

149154
# Extract our shard
150-
shard_size = transformed_grad_shard.shape[0] // self.world_size
151-
start_row = self.rank * shard_size
152-
end_row = (self.rank + 1) * shard_size
155+
start_row, end_row = self.shard_computer.shard_bounds(
156+
transformed_grad_shard.shape[0]
157+
)
153158

154159
# Accumulate (with CPU offloading for memory efficiency)
155160
if name not in self.eigenvalue_corrections:
@@ -236,7 +241,7 @@ def compute_eigendecomposition(
236241
total_processed: Number of samples used to compute covariance.
237242
238243
Returns:
239-
Per-key eigenvalue shards (each `[m/world_size]`) on CPU. The
244+
Per-key eigenvalue shards (rows per shard_bounds) on CPU. The
240245
eigenvectors are written to disk; the eigenvalues are returned so
241246
callers (e.g. `save_uncorrected_eigenvalues`) can use them without
242247
reloading.
@@ -371,14 +376,27 @@ def save_uncorrected_eigenvalues(
371376
eigenvalue_a_shard = eigenvalues_a[key].to(device)
372377

373378
if world_size > 1:
379+
# Shards may be uneven, so sum the shard sizes to get the full dimension
380+
# then broadcast each rank's shard into place.
381+
full_dim = torch.tensor(eigenvalue_a_shard.shape[0], device=device)
382+
dist.all_reduce(full_dim, op=dist.ReduceOp.SUM)
383+
m = int(full_dim.item())
384+
374385
eigenvalue_a_full = torch.empty(
375-
eigenvalue_a_shard.shape[0] * world_size,
376-
device=device,
377-
dtype=eigenvalue_a_shard.dtype,
378-
)
379-
dist.all_gather_into_tensor(
380-
eigenvalue_a_full, eigenvalue_a_shard.contiguous()
386+
m, device=device, dtype=eigenvalue_a_shard.dtype
381387
)
388+
for rank_index in range(world_size):
389+
start_row, end_row = shard_bounds(m, rank_index, world_size)
390+
if rank_index == rank:
391+
shard = eigenvalue_a_shard.contiguous()
392+
else:
393+
shard = torch.empty(
394+
end_row - start_row,
395+
device=device,
396+
dtype=eigenvalue_a_shard.dtype,
397+
)
398+
dist.broadcast(shard, src=rank_index)
399+
eigenvalue_a_full[start_row:end_row] = shard
382400
else:
383401
eigenvalue_a_full = eigenvalue_a_shard
384402

@@ -418,9 +436,8 @@ def _gather_and_shard_along_dim_0(
418436

419437
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
420438

421-
m = full_shape[0]
422-
shard_size = m // world_size
423-
shard = tensor[rank * shard_size : (rank + 1) * shard_size].contiguous()
439+
start_row, end_row = shard_bounds(full_shape[0], rank, world_size)
440+
shard = tensor[start_row:end_row].contiguous()
424441
result_dict[key] = shard.to(device="cpu")
425442

426443
del tensor

bergson/hessians/hessian_approximations.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
from bergson.config.config import AttentionConfig, HessianConfig, IndexConfig
1414
from bergson.data import allocate_batches
1515
from bergson.distributed import init_dist, launch_distributed_run
16+
from bergson.gradients import GradientProcessor
1617
from bergson.hessians.eigenvectors import (
1718
LambdaCollector,
1819
compute_eigendecomposition,
@@ -205,6 +206,7 @@ def collect_hessians(
205206
"attention_cfgs": attention_cfgs or {},
206207
"path": str(index_cfg.partial_run_path),
207208
"filter_modules": index_cfg.filter_modules,
209+
"processor": GradientProcessor(include_bias=index_cfg.include_bias),
208210
}
209211
desc = f"Approximating Hessians with {hessian_cfg.method}"
210212
if ev_correction:

bergson/hessians/kfac.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,13 @@ def forward_hook(self, module: nn.Module, a: Tensor) -> None:
5050
# a: [N, S, I], valid_masks: [N, S] -> select valid positions
5151
a_bi = a[mask].to(self.dtype) # [num_valid, I]
5252

53+
# Augment with a ones column so A matches the [O, I+1] gradient layout
54+
# produced when the bias gradient is collected.
55+
if module._collect_bias:
56+
a_bi = torch.cat(
57+
[a_bi, a_bi.new_ones(a_bi.shape[0], 1)], dim=1
58+
) # [num_valid, I+1]
59+
5360
# Compute local covariance
5461
local_update_ii = a_bi.mT @ a_bi
5562

@@ -58,8 +65,7 @@ def forward_hook(self, module: nn.Module, a: Tensor) -> None:
5865
dist.all_reduce(local_update_ii, op=dist.ReduceOp.SUM)
5966

6067
# Extract our shard
61-
start_row = self.rank * A_cov_ki.shape[0]
62-
end_row = (self.rank + 1) * A_cov_ki.shape[0]
68+
start_row, end_row = self.shard_computer.shard_bounds(local_update_ii.shape[0])
6369
update_slice_ki = local_update_ii[start_row:end_row, :]
6470

6571
# Accumulate
@@ -82,8 +88,7 @@ def backward_hook(self, module: nn.Module, g: Tensor) -> None:
8288
dist.all_reduce(local_update_oo, op=dist.ReduceOp.SUM)
8389

8490
# Extract our shard
85-
start_row = self.rank * S_cov_po.shape[0]
86-
end_row = (self.rank + 1) * S_cov_po.shape[0]
91+
start_row, end_row = self.shard_computer.shard_bounds(local_update_oo.shape[0])
8792
update_slice_po = local_update_oo[start_row:end_row, :]
8893

8994
# Accumulate

bergson/hessians/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -98,7 +98,7 @@ def _validate(cfg: IndexConfig):
9898
hessian_index_cfg.run_path = f"{hessian_path}/{method}"
9999
_validate(hessian_index_cfg)
100100

101-
approximate_hessians(hessian_index_cfg, hessian_cfg)
101+
approximate_hessians(hessian_index_cfg, hessian_cfg)
102102

103103
# ── Step 3: Apply inverse Hessian to the mean query gradient ──────────
104104
print(f"Step 3/4: Applying {method} inverse Hessian to mean query gradient...")

bergson/hessians/shampoo.py

Lines changed: 19 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,13 @@ def forward_hook(self, module: nn.Module, a: Tensor) -> None:
5050
# a: [N, S, I], valid_masks: [N, S] -> select valid positions
5151
a_bi = a[mask] # [num_valid, I]
5252

53+
# Augment with a ones column so the [O, I+1] per-batch gradient matches
54+
# the layout produced when the bias gradient is collected.
55+
if module._collect_bias:
56+
a_bi = torch.cat(
57+
[a_bi, a_bi.new_ones(a_bi.shape[0], 1)], dim=1
58+
) # [num_valid, I+1]
59+
5360
module._inputs = a_bi
5461

5562
def backward_hook(self, module: nn.Module, g: Tensor) -> None:
@@ -74,12 +81,14 @@ def backward_hook(self, module: nn.Module, g: Tensor) -> None:
7481
dist.all_reduce(local_update_ii, op=dist.ReduceOp.SUM)
7582

7683
# Extract our shard
77-
start_row_grad = self.rank * S_shampoo_po.shape[0]
78-
end_row_grad = (self.rank + 1) * S_shampoo_po.shape[0]
84+
start_row_grad, end_row_grad = self.shard_computer.shard_bounds(
85+
local_update_oo.shape[0]
86+
)
7987
update_slice_po = local_update_oo[start_row_grad:end_row_grad, :]
8088

81-
start_row_act = self.rank * A_shampoo_ki.shape[0]
82-
end_row_act = (self.rank + 1) * A_shampoo_ki.shape[0]
89+
start_row_act, end_row_act = self.shard_computer.shard_bounds(
90+
local_update_ii.shape[0]
91+
)
8392
update_slice_ki = local_update_ii[start_row_act:end_row_act, :]
8493

8594
# Accumulate
@@ -100,11 +109,13 @@ def teardown(self) -> None:
100109

101110
# Normalize activation covariance by trace
102111
for name, A_shampoo_ki in self.A_shampoo_dict.items():
103-
rows_per_rank = A_shampoo_ki.shape[0]
104112
# Extract diagonal elements from this shard
105-
# For row i in shard, the resp. diagonal column is i + rank * rows_per_rank
106-
diag_indices = torch.arange(rows_per_rank, device=A_shampoo_ki.device)
107-
diag_col_indices = diag_indices + self.rank * rows_per_rank
113+
# For row i in shard, the resp. diagonal column is i + shard start
114+
start_row, _ = self.shard_computer.shard_bounds(A_shampoo_ki.shape[1])
115+
diag_indices = torch.arange(
116+
A_shampoo_ki.shape[0], device=A_shampoo_ki.device
117+
)
118+
diag_col_indices = diag_indices + start_row
108119
local_trace = A_shampoo_ki[diag_indices, diag_col_indices].sum()
109120

110121
# All-reduce to get full trace

0 commit comments

Comments
 (0)