Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 31 additions & 14 deletions bergson/hessians/eigenvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tqdm import tqdm

from bergson.collector.collector import HookCollectorBase
from bergson.hessians.sharded_computation import ShardedMul
from bergson.hessians.sharded_computation import ShardedMul, shard_bounds
from bergson.utils.logger import get_logger
from bergson.utils.utils import (
assert_type,
Expand Down Expand Up @@ -115,6 +115,11 @@ def forward_hook(self, module: nn.Module, a: Tensor) -> None:
name = assert_type(str, module._name)
# a shape: [N, S, I]

# Augment with a ones column to match the [I+1, I+1] activation
# covariance eigenvectors computed when the bias gradient is collected.
if module._collect_bias:
a = torch.cat([a, a.new_ones(*a.shape[:-1], 1)], dim=-1) # [N, S, I+1]

# Transform: a @ eigen_a
transformed = self.shard_computer._matmul(
vector_nsa=a, matrix_cb=self.eigen_a[name]
Expand Down Expand Up @@ -147,9 +152,9 @@ def backward_hook(self, module: nn.Module, g: Tensor) -> None:
dist.all_reduce(transformed_grad_shard, op=dist.ReduceOp.SUM)

# Extract our shard
shard_size = transformed_grad_shard.shape[0] // self.world_size
start_row = self.rank * shard_size
end_row = (self.rank + 1) * shard_size
start_row, end_row = self.shard_computer.shard_bounds(
transformed_grad_shard.shape[0]
)

# Accumulate (with CPU offloading for memory efficiency)
if name not in self.eigenvalue_corrections:
Expand Down Expand Up @@ -236,7 +241,7 @@ def compute_eigendecomposition(
total_processed: Number of samples used to compute covariance.

Returns:
Per-key eigenvalue shards (each `[m/world_size]`) on CPU. The
Per-key eigenvalue shards (rows per shard_bounds) on CPU. The
eigenvectors are written to disk; the eigenvalues are returned so
callers (e.g. `save_uncorrected_eigenvalues`) can use them without
reloading.
Expand Down Expand Up @@ -371,14 +376,27 @@ def save_uncorrected_eigenvalues(
eigenvalue_a_shard = eigenvalues_a[key].to(device)

if world_size > 1:
# Shards may be uneven, so sum the shard sizes to get the full dimension
# then broadcast each rank's shard into place.
full_dim = torch.tensor(eigenvalue_a_shard.shape[0], device=device)
dist.all_reduce(full_dim, op=dist.ReduceOp.SUM)
m = int(full_dim.item())

eigenvalue_a_full = torch.empty(
eigenvalue_a_shard.shape[0] * world_size,
device=device,
dtype=eigenvalue_a_shard.dtype,
)
dist.all_gather_into_tensor(
eigenvalue_a_full, eigenvalue_a_shard.contiguous()
m, device=device, dtype=eigenvalue_a_shard.dtype
)
for rank_index in range(world_size):
start_row, end_row = shard_bounds(m, rank_index, world_size)
if rank_index == rank:
shard = eigenvalue_a_shard.contiguous()
else:
shard = torch.empty(
end_row - start_row,
device=device,
dtype=eigenvalue_a_shard.dtype,
)
dist.broadcast(shard, src=rank_index)
eigenvalue_a_full[start_row:end_row] = shard
else:
eigenvalue_a_full = eigenvalue_a_shard

Expand Down Expand Up @@ -418,9 +436,8 @@ def _gather_and_shard_along_dim_0(

dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

m = full_shape[0]
shard_size = m // world_size
shard = tensor[rank * shard_size : (rank + 1) * shard_size].contiguous()
start_row, end_row = shard_bounds(full_shape[0], rank, world_size)
shard = tensor[start_row:end_row].contiguous()
result_dict[key] = shard.to(device="cpu")

del tensor
Expand Down
2 changes: 2 additions & 0 deletions bergson/hessians/hessian_approximations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from bergson.config.config import AttentionConfig, HessianConfig, IndexConfig
from bergson.data import allocate_batches
from bergson.distributed import init_dist, launch_distributed_run
from bergson.gradients import GradientProcessor
from bergson.hessians.eigenvectors import (
LambdaCollector,
compute_eigendecomposition,
Expand Down Expand Up @@ -205,6 +206,7 @@ def collect_hessians(
"attention_cfgs": attention_cfgs or {},
"path": str(index_cfg.partial_run_path),
"filter_modules": index_cfg.filter_modules,
"processor": GradientProcessor(include_bias=index_cfg.include_bias),
}
desc = f"Approximating Hessians with {hessian_cfg.method}"
if ev_correction:
Expand Down
13 changes: 9 additions & 4 deletions bergson/hessians/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ def forward_hook(self, module: nn.Module, a: Tensor) -> None:
# a: [N, S, I], valid_masks: [N, S] -> select valid positions
a_bi = a[mask].to(self.dtype) # [num_valid, I]

# Augment with a ones column so A matches the [O, I+1] gradient layout
# produced when the bias gradient is collected.
if module._collect_bias:
a_bi = torch.cat(
[a_bi, a_bi.new_ones(a_bi.shape[0], 1)], dim=1
) # [num_valid, I+1]

# Compute local covariance
local_update_ii = a_bi.mT @ a_bi

Expand All @@ -58,8 +65,7 @@ def forward_hook(self, module: nn.Module, a: Tensor) -> None:
dist.all_reduce(local_update_ii, op=dist.ReduceOp.SUM)

# Extract our shard
start_row = self.rank * A_cov_ki.shape[0]
end_row = (self.rank + 1) * A_cov_ki.shape[0]
start_row, end_row = self.shard_computer.shard_bounds(local_update_ii.shape[0])
update_slice_ki = local_update_ii[start_row:end_row, :]

# Accumulate
Expand All @@ -82,8 +88,7 @@ def backward_hook(self, module: nn.Module, g: Tensor) -> None:
dist.all_reduce(local_update_oo, op=dist.ReduceOp.SUM)

# Extract our shard
start_row = self.rank * S_cov_po.shape[0]
end_row = (self.rank + 1) * S_cov_po.shape[0]
start_row, end_row = self.shard_computer.shard_bounds(local_update_oo.shape[0])
update_slice_po = local_update_oo[start_row:end_row, :]

# Accumulate
Expand Down
2 changes: 1 addition & 1 deletion bergson/hessians/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _validate(cfg: IndexConfig):
hessian_index_cfg.run_path = f"{hessian_path}/{method}"
_validate(hessian_index_cfg)

approximate_hessians(hessian_index_cfg, hessian_cfg)
approximate_hessians(hessian_index_cfg, hessian_cfg)

# ── Step 3: Apply inverse Hessian to the mean query gradient ──────────
print(f"Step 3/4: Applying {method} inverse Hessian to mean query gradient...")
Expand Down
27 changes: 19 additions & 8 deletions bergson/hessians/shampoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ def forward_hook(self, module: nn.Module, a: Tensor) -> None:
# a: [N, S, I], valid_masks: [N, S] -> select valid positions
a_bi = a[mask] # [num_valid, I]

# Augment with a ones column so the [O, I+1] per-batch gradient matches
# the layout produced when the bias gradient is collected.
if module._collect_bias:
a_bi = torch.cat(
[a_bi, a_bi.new_ones(a_bi.shape[0], 1)], dim=1
) # [num_valid, I+1]

module._inputs = a_bi

def backward_hook(self, module: nn.Module, g: Tensor) -> None:
Expand All @@ -74,12 +81,14 @@ def backward_hook(self, module: nn.Module, g: Tensor) -> None:
dist.all_reduce(local_update_ii, op=dist.ReduceOp.SUM)

# Extract our shard
start_row_grad = self.rank * S_shampoo_po.shape[0]
end_row_grad = (self.rank + 1) * S_shampoo_po.shape[0]
start_row_grad, end_row_grad = self.shard_computer.shard_bounds(
local_update_oo.shape[0]
)
update_slice_po = local_update_oo[start_row_grad:end_row_grad, :]

start_row_act = self.rank * A_shampoo_ki.shape[0]
end_row_act = (self.rank + 1) * A_shampoo_ki.shape[0]
start_row_act, end_row_act = self.shard_computer.shard_bounds(
local_update_ii.shape[0]
)
update_slice_ki = local_update_ii[start_row_act:end_row_act, :]

# Accumulate
Expand All @@ -100,11 +109,13 @@ def teardown(self) -> None:

# Normalize activation covariance by trace
for name, A_shampoo_ki in self.A_shampoo_dict.items():
rows_per_rank = A_shampoo_ki.shape[0]
# Extract diagonal elements from this shard
# For row i in shard, the resp. diagonal column is i + rank * rows_per_rank
diag_indices = torch.arange(rows_per_rank, device=A_shampoo_ki.device)
diag_col_indices = diag_indices + self.rank * rows_per_rank
# For row i in shard, the resp. diagonal column is i + shard start
start_row, _ = self.shard_computer.shard_bounds(A_shampoo_ki.shape[1])
diag_indices = torch.arange(
A_shampoo_ki.shape[0], device=A_shampoo_ki.device
)
diag_col_indices = diag_indices + start_row
local_trace = A_shampoo_ki[diag_indices, diag_col_indices].sum()

# All-reduce to get full trace
Expand Down
Loading
Loading