diff --git a/bergson/hessians/eigenvectors.py b/bergson/hessians/eigenvectors.py index 6531be5c..e0fab999 100644 --- a/bergson/hessians/eigenvectors.py +++ b/bergson/hessians/eigenvectors.py @@ -11,7 +11,7 @@ from tqdm import tqdm from bergson.collector.collector import HookCollectorBase -from bergson.hessians.sharded_computation import ShardedMul +from bergson.hessians.sharded_computation import ShardedMul, shard_bounds from bergson.utils.logger import get_logger from bergson.utils.utils import ( assert_type, @@ -115,6 +115,11 @@ def forward_hook(self, module: nn.Module, a: Tensor) -> None: name = assert_type(str, module._name) # a shape: [N, S, I] + # Augment with a ones column to match the [I+1, I+1] activation + # covariance eigenvectors computed when the bias gradient is collected. + if module._collect_bias: + a = torch.cat([a, a.new_ones(*a.shape[:-1], 1)], dim=-1) # [N, S, I+1] + # Transform: a @ eigen_a transformed = self.shard_computer._matmul( vector_nsa=a, matrix_cb=self.eigen_a[name] @@ -147,9 +152,9 @@ def backward_hook(self, module: nn.Module, g: Tensor) -> None: dist.all_reduce(transformed_grad_shard, op=dist.ReduceOp.SUM) # Extract our shard - shard_size = transformed_grad_shard.shape[0] // self.world_size - start_row = self.rank * shard_size - end_row = (self.rank + 1) * shard_size + start_row, end_row = self.shard_computer.shard_bounds( + transformed_grad_shard.shape[0] + ) # Accumulate (with CPU offloading for memory efficiency) if name not in self.eigenvalue_corrections: @@ -236,7 +241,7 @@ def compute_eigendecomposition( total_processed: Number of samples used to compute covariance. Returns: - Per-key eigenvalue shards (each `[m/world_size]`) on CPU. The + Per-key eigenvalue shards (rows per shard_bounds) on CPU. The eigenvectors are written to disk; the eigenvalues are returned so callers (e.g. `save_uncorrected_eigenvalues`) can use them without reloading. @@ -371,14 +376,27 @@ def save_uncorrected_eigenvalues( eigenvalue_a_shard = eigenvalues_a[key].to(device) if world_size > 1: + # Shards may be uneven, so sum the shard sizes to get the full dimension + # then broadcast each rank's shard into place. + full_dim = torch.tensor(eigenvalue_a_shard.shape[0], device=device) + dist.all_reduce(full_dim, op=dist.ReduceOp.SUM) + m = int(full_dim.item()) + eigenvalue_a_full = torch.empty( - eigenvalue_a_shard.shape[0] * world_size, - device=device, - dtype=eigenvalue_a_shard.dtype, - ) - dist.all_gather_into_tensor( - eigenvalue_a_full, eigenvalue_a_shard.contiguous() + m, device=device, dtype=eigenvalue_a_shard.dtype ) + for rank_index in range(world_size): + start_row, end_row = shard_bounds(m, rank_index, world_size) + if rank_index == rank: + shard = eigenvalue_a_shard.contiguous() + else: + shard = torch.empty( + end_row - start_row, + device=device, + dtype=eigenvalue_a_shard.dtype, + ) + dist.broadcast(shard, src=rank_index) + eigenvalue_a_full[start_row:end_row] = shard else: eigenvalue_a_full = eigenvalue_a_shard @@ -418,9 +436,8 @@ def _gather_and_shard_along_dim_0( dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - m = full_shape[0] - shard_size = m // world_size - shard = tensor[rank * shard_size : (rank + 1) * shard_size].contiguous() + start_row, end_row = shard_bounds(full_shape[0], rank, world_size) + shard = tensor[start_row:end_row].contiguous() result_dict[key] = shard.to(device="cpu") del tensor diff --git a/bergson/hessians/hessian_approximations.py b/bergson/hessians/hessian_approximations.py index 230c2e0d..c3bb364e 100644 --- a/bergson/hessians/hessian_approximations.py +++ b/bergson/hessians/hessian_approximations.py @@ -13,6 +13,7 @@ from bergson.config.config import AttentionConfig, HessianConfig, IndexConfig from bergson.data import allocate_batches from bergson.distributed import init_dist, launch_distributed_run +from bergson.gradients import GradientProcessor from bergson.hessians.eigenvectors import ( LambdaCollector, compute_eigendecomposition, @@ -205,6 +206,7 @@ def collect_hessians( "attention_cfgs": attention_cfgs or {}, "path": str(index_cfg.partial_run_path), "filter_modules": index_cfg.filter_modules, + "processor": GradientProcessor(include_bias=index_cfg.include_bias), } desc = f"Approximating Hessians with {hessian_cfg.method}" if ev_correction: diff --git a/bergson/hessians/kfac.py b/bergson/hessians/kfac.py index 9348d715..71d45ce2 100644 --- a/bergson/hessians/kfac.py +++ b/bergson/hessians/kfac.py @@ -50,6 +50,13 @@ def forward_hook(self, module: nn.Module, a: Tensor) -> None: # a: [N, S, I], valid_masks: [N, S] -> select valid positions a_bi = a[mask].to(self.dtype) # [num_valid, I] + # Augment with a ones column so A matches the [O, I+1] gradient layout + # produced when the bias gradient is collected. + if module._collect_bias: + a_bi = torch.cat( + [a_bi, a_bi.new_ones(a_bi.shape[0], 1)], dim=1 + ) # [num_valid, I+1] + # Compute local covariance local_update_ii = a_bi.mT @ a_bi @@ -58,8 +65,7 @@ def forward_hook(self, module: nn.Module, a: Tensor) -> None: dist.all_reduce(local_update_ii, op=dist.ReduceOp.SUM) # Extract our shard - start_row = self.rank * A_cov_ki.shape[0] - end_row = (self.rank + 1) * A_cov_ki.shape[0] + start_row, end_row = self.shard_computer.shard_bounds(local_update_ii.shape[0]) update_slice_ki = local_update_ii[start_row:end_row, :] # Accumulate @@ -82,8 +88,7 @@ def backward_hook(self, module: nn.Module, g: Tensor) -> None: dist.all_reduce(local_update_oo, op=dist.ReduceOp.SUM) # Extract our shard - start_row = self.rank * S_cov_po.shape[0] - end_row = (self.rank + 1) * S_cov_po.shape[0] + start_row, end_row = self.shard_computer.shard_bounds(local_update_oo.shape[0]) update_slice_po = local_update_oo[start_row:end_row, :] # Accumulate diff --git a/bergson/hessians/pipeline.py b/bergson/hessians/pipeline.py index cf80167a..da7e6f8d 100644 --- a/bergson/hessians/pipeline.py +++ b/bergson/hessians/pipeline.py @@ -98,7 +98,7 @@ def _validate(cfg: IndexConfig): hessian_index_cfg.run_path = f"{hessian_path}/{method}" _validate(hessian_index_cfg) - approximate_hessians(hessian_index_cfg, hessian_cfg) + approximate_hessians(hessian_index_cfg, hessian_cfg) # ── Step 3: Apply inverse Hessian to the mean query gradient ────────── print(f"Step 3/4: Applying {method} inverse Hessian to mean query gradient...") diff --git a/bergson/hessians/shampoo.py b/bergson/hessians/shampoo.py index 5174fe09..49d0ab11 100644 --- a/bergson/hessians/shampoo.py +++ b/bergson/hessians/shampoo.py @@ -50,6 +50,13 @@ def forward_hook(self, module: nn.Module, a: Tensor) -> None: # a: [N, S, I], valid_masks: [N, S] -> select valid positions a_bi = a[mask] # [num_valid, I] + # Augment with a ones column so the [O, I+1] per-batch gradient matches + # the layout produced when the bias gradient is collected. + if module._collect_bias: + a_bi = torch.cat( + [a_bi, a_bi.new_ones(a_bi.shape[0], 1)], dim=1 + ) # [num_valid, I+1] + module._inputs = a_bi def backward_hook(self, module: nn.Module, g: Tensor) -> None: @@ -74,12 +81,14 @@ def backward_hook(self, module: nn.Module, g: Tensor) -> None: dist.all_reduce(local_update_ii, op=dist.ReduceOp.SUM) # Extract our shard - start_row_grad = self.rank * S_shampoo_po.shape[0] - end_row_grad = (self.rank + 1) * S_shampoo_po.shape[0] + start_row_grad, end_row_grad = self.shard_computer.shard_bounds( + local_update_oo.shape[0] + ) update_slice_po = local_update_oo[start_row_grad:end_row_grad, :] - start_row_act = self.rank * A_shampoo_ki.shape[0] - end_row_act = (self.rank + 1) * A_shampoo_ki.shape[0] + start_row_act, end_row_act = self.shard_computer.shard_bounds( + local_update_ii.shape[0] + ) update_slice_ki = local_update_ii[start_row_act:end_row_act, :] # Accumulate @@ -100,11 +109,13 @@ def teardown(self) -> None: # Normalize activation covariance by trace for name, A_shampoo_ki in self.A_shampoo_dict.items(): - rows_per_rank = A_shampoo_ki.shape[0] # Extract diagonal elements from this shard - # For row i in shard, the resp. diagonal column is i + rank * rows_per_rank - diag_indices = torch.arange(rows_per_rank, device=A_shampoo_ki.device) - diag_col_indices = diag_indices + self.rank * rows_per_rank + # For row i in shard, the resp. diagonal column is i + shard start + start_row, _ = self.shard_computer.shard_bounds(A_shampoo_ki.shape[1]) + diag_indices = torch.arange( + A_shampoo_ki.shape[0], device=A_shampoo_ki.device + ) + diag_col_indices = diag_indices + start_row local_trace = A_shampoo_ki[diag_indices, diag_col_indices].sum() # All-reduce to get full trace diff --git a/bergson/hessians/sharded_computation.py b/bergson/hessians/sharded_computation.py index 710ea725..7062c39a 100644 --- a/bergson/hessians/sharded_computation.py +++ b/bergson/hessians/sharded_computation.py @@ -6,6 +6,21 @@ from bergson.utils.utils import get_device +def shard_bounds(dim: int, rank: int, world_size: int) -> tuple[int, int]: + """Range [start, end) of ``rank``'s shard of a dimension of size + ``dim`` split across ``world_size`` ranks. + + Rank 0 takes the remainder rows when ``dim`` is not evenly divisible, + so the shard sizes are [base + remainder, base, ..., base]. + """ + base, remainder = divmod(dim, world_size) + if rank == 0: + return 0, base + remainder + + start = remainder + rank * base + return start, start + base + + class ShardedMul: def __init__( self, @@ -16,6 +31,10 @@ def __init__( self.world_size = dist.get_world_size() if self.dist else 1 self.device = torch.device(get_device(self.rank)) + def shard_bounds(self, dim: int, rank: int | None = None) -> tuple[int, int]: + """Row range [start, end) of ``rank``'s shard (default: this rank).""" + return shard_bounds(dim, self.rank if rank is None else rank, self.world_size) + def _init_covariance_dict( self, activation_covariance_dict: dict, @@ -26,18 +45,19 @@ def _init_covariance_dict( """Initialize the covariance matrices for activations and gradients.""" for name, (device, weight_shape, collect_bias) in target_info.items(): - # Activation covariance A^T A has shape [in_dim, in_dim] - in_dim = weight_shape[1] - shard_in_dim = in_dim if not self.dist else in_dim // self.world_size + # Activation covariance A^T A has shape [in_dim, in_dim], or + # [in + 1, in + 1] when the bias is collected. + in_dim = weight_shape[1] + (1 if collect_bias else 0) + in_start, in_end = self.shard_bounds(in_dim) activation_covariance_dict[name] = torch.zeros( - (shard_in_dim, in_dim), device=self.device, dtype=dtype + (in_end - in_start, in_dim), device=self.device, dtype=dtype ) # Gradient covariance G^T G has shape [out_dim, out_dim] out_dim = weight_shape[0] - shard_out_dim = out_dim if not self.dist else out_dim // self.world_size + out_start, out_end = self.shard_bounds(out_dim) gradient_covariance_dict[name] = torch.zeros( - (shard_out_dim, out_dim), device=self.device, dtype=dtype + (out_end - out_start, out_dim), device=self.device, dtype=dtype ) def _matmul( @@ -47,10 +67,12 @@ def _matmul( ) -> Float[Tensor, "n s b"]: """Vector-matrix multiplication. - If not distributed, this does usual multiplication with a=c. - - If distributed, assumes that c=a/world_size and does sharded multiplication. + - If distributed, assumes that c is this rank's shard of a and + does sharded multiplication. """ - assert vector_nsa.shape[2] == matrix_cb.shape[0] * self.world_size, ( + start, end = self.shard_bounds(vector_nsa.shape[2]) + assert matrix_cb.shape[0] == end - start, ( f"Vector shape {vector_nsa.shape} not compatible with matrix shape " f"{matrix_cb.shape} and world_size {self.world_size}" ) @@ -112,16 +134,20 @@ def _sharded_apply_eigfn( fn, ): """Sharded in-place ``matrix_noi *= fn(λ)`` (function-aware hadamard).""" + o = matrix_noi.shape[1] for rank_index in range(self.world_size): + start_row, end_row = self.shard_bounds(o, rank_index) if rank_index == self.rank: shard_ci = lambda_ci else: - shard_ci = torch.zeros_like(lambda_ci) + shard_ci = torch.zeros( + (end_row - start_row, lambda_ci.shape[1]), + device=lambda_ci.device, + dtype=lambda_ci.dtype, + ) dist.broadcast(shard_ci, src=rank_index) - start_row = rank_index * shard_ci.shape[0] - end_row = (rank_index + 1) * shard_ci.shape[0] matrix_noi[:, start_row:end_row, :].mul_(fn(shard_ci)) if self.rank != rank_index: @@ -134,16 +160,13 @@ def _sharded_matmul( ) -> Float[Tensor, "n s b"]: """ Sharded matrix multiplication for distributed training. - Assumes that c=a/world_size. + Assumes that c is this rank's shard of a (see shard_bounds). vector: [n, s, a] - matrix_shard: [a/world_size, b] + matrix_shard: [c, b] Returns: [n, s, b] """ - # Split the vector into shards - vector_shards_wnsc = torch.chunk( - vector_nsa, self.world_size, dim=-1 - ) # (w, n, s, a/w) - n, s, b = vector_nsa.shape[0], vector_nsa.shape[1], matrix_cb.shape[1] + n, s, a = vector_nsa.shape + b = matrix_cb.shape[1] result_nsb = torch.zeros( (n, s, b), @@ -152,15 +175,20 @@ def _sharded_matmul( ) for rank_index in range(self.world_size): + start_row, end_row = self.shard_bounds(a, rank_index) if rank_index == self.rank: shard_cb = matrix_cb else: - shard_cb = torch.zeros_like(matrix_cb) + shard_cb = torch.zeros( + (end_row - start_row, b), + device=matrix_cb.device, + dtype=matrix_cb.dtype, + ) dist.broadcast(shard_cb, src=rank_index) result_nsb += torch.einsum( "n s c, c b-> n s b", - vector_shards_wnsc[rank_index].to(shard_cb.dtype), + vector_nsa[..., start_row:end_row].to(shard_cb.dtype), shard_cb, ) # [B, c] if self.rank != rank_index: @@ -177,25 +205,30 @@ def _sharded_hadamard( """ Sharded in-place element-wise multiplication for distributed training. gradients: [n, o, i] - matrix_shard: [c, i] where c=o/world_size + matrix_shard: [c, i] where c is this rank's shard of o (see + shard_bounds) """ + o = matrix_noi.shape[1] - global_lambda_mean = lambda_ci.mean() - + # Shards may be uneven, so compute the global mean from the global sum + global_lambda_mean = lambda_ci.sum() dist.all_reduce(global_lambda_mean, op=dist.ReduceOp.SUM) - global_lambda_mean /= self.world_size + global_lambda_mean /= o * lambda_ci.shape[1] for rank_index in range(self.world_size): + start_row, end_row = self.shard_bounds(o, rank_index) if rank_index == self.rank: shard_ci = lambda_ci else: - shard_ci = torch.zeros_like(lambda_ci) + shard_ci = torch.zeros( + (end_row - start_row, lambda_ci.shape[1]), + device=lambda_ci.device, + dtype=lambda_ci.dtype, + ) dist.broadcast(shard_ci, src=rank_index) - start_row = rank_index * shard_ci.shape[0] - end_row = (rank_index + 1) * shard_ci.shape[0] inverse_lambda = ( shard_ci + lambda_damp_factor * global_lambda_mean ).reciprocal() @@ -212,29 +245,31 @@ def _sharded_transpose_matmul( ): """ Sharded matrix multiplication for distributed training. - Assumes that c=i/world_size if left or o/world_size if right. + Assumes that c is this rank's shard of i if left or of o if right + (see shard_bounds). gradients: [n, o, i] matrix_shard: [c, b] where b=i if left or b=o if right - Returns: [n, o, c*w] if left or [n, c*w, i] if right + Returns: [n, o, b] if left or [n, b, i] if right """ - x, y = (matrix_noi.shape[1], matrix_bc.shape[0] * self.world_size) + x, y = (matrix_noi.shape[1], matrix_bc.shape[1]) result_nxy = torch.zeros( matrix_noi.shape[0], x, y, device=matrix_noi.device, dtype=matrix_bc.dtype ) for rank_index in range(self.world_size): + start_row, end_row = self.shard_bounds(y, rank_index) if rank_index == self.rank: shard_bc = matrix_bc else: - shard_bc = torch.zeros_like(matrix_bc) + shard_bc = torch.zeros( + (end_row - start_row, matrix_bc.shape[1]), + device=matrix_bc.device, + dtype=matrix_bc.dtype, + ) dist.broadcast(shard_bc, src=rank_index) - shard_size = shard_bc.shape[0] - start_row = rank_index * shard_size - end_row = (rank_index + 1) * shard_size - result_nxy[:, :, start_row:end_row].copy_( torch.einsum( "n o i, c i -> n o c", matrix_noi.to(shard_bc.dtype), shard_bc diff --git a/bergson/hessians/tkfac.py b/bergson/hessians/tkfac.py index 5a4f6750..18c0d568 100644 --- a/bergson/hessians/tkfac.py +++ b/bergson/hessians/tkfac.py @@ -51,6 +51,13 @@ def forward_hook(self, module: nn.Module, a: Tensor) -> None: # a: [N, S, I], valid_masks: [N, S] -> select valid positions a_bi = a[mask] # [num_valid, I] + # Augment with a ones column so A matches the [O, I+1] gradient layout + # produced when the bias gradient is collected. + if module._collect_bias: + a_bi = torch.cat( + [a_bi, a_bi.new_ones(a_bi.shape[0], 1)], dim=1 + ) # [num_valid, I+1] + module._inputs = a_bi def backward_hook(self, module: nn.Module, g: Tensor) -> None: @@ -82,12 +89,14 @@ def backward_hook(self, module: nn.Module, g: Tensor) -> None: dist.all_reduce(local_update_ii, op=dist.ReduceOp.SUM) # Extract our shard - start_row_grad = self.rank * S_tcov_po.shape[0] - end_row_grad = (self.rank + 1) * S_tcov_po.shape[0] + start_row_grad, end_row_grad = self.shard_computer.shard_bounds( + local_update_oo.shape[0] + ) update_slice_po = local_update_oo[start_row_grad:end_row_grad, :] - start_row_act = self.rank * A_tcov_ki.shape[0] - end_row_act = (self.rank + 1) * A_tcov_ki.shape[0] + start_row_act, end_row_act = self.shard_computer.shard_bounds( + local_update_ii.shape[0] + ) update_slice_ki = local_update_ii[start_row_act:end_row_act, :] # Accumulate diff --git a/tests/ekfac_tests/test_kfac_include_bias.py b/tests/ekfac_tests/test_kfac_include_bias.py new file mode 100644 index 00000000..1b8dadfb --- /dev/null +++ b/tests/ekfac_tests/test_kfac_include_bias.py @@ -0,0 +1,139 @@ +"""Regression tests for K-FAC + include_bias shape compatibility. + +See https://github.com/EleutherAI/bergson/issues/277: with include_bias=True, +`bergson build` stores per-layer gradients of shape [O, I+1] (the bias gradient +is an extra "activation" column), but the K-FAC covariance collectors used to +compute A^T A on the raw activation, giving an [I, I] activation covariance +that `apply_hessian`'s `.view(-1, O, I)` could not reconcile with the stored +flat size N*O*(I+1). +""" + +import math +import os + +import pytest +import torch +import torch.nn as nn +from safetensors.torch import save_file + +from bergson.gradients import GradientProcessor +from bergson.hessians.eigenvectors import LambdaCollector +from bergson.hessians.kfac import CovarianceCollector +from bergson.hessians.sharded_computation import shard_bounds +from bergson.utils.utils import get_device + +IN_DIM = 4 +OUT_DIM = 6 + + +class TinyBiasModel(nn.Module): + """Minimal model mixing a biased and an unbiased linear layer.""" + + def __init__(self): + super().__init__() + self.biased = nn.Linear(IN_DIM, OUT_DIM, bias=True) + self.unbiased = nn.Linear(OUT_DIM, IN_DIM, bias=False) + + def forward(self, x): + return self.unbiased(self.biased(x)) + + +def test_covariance_collector_include_bias(tmp_path): + """A_cov must use the augmented activation [a; 1] when bias is collected.""" + device = get_device(0) + model = TinyBiasModel().to(device) + collector = CovarianceCollector( + model=model, + dtype=torch.float32, + path=str(tmp_path), + processor=GradientProcessor(include_bias=True), + ) + + # Augmented [I+1, I+1] for the biased layer, raw [O, O] for the unbiased one + assert collector.A_cov_dict["biased"].shape == (IN_DIM + 1, IN_DIM + 1) + assert collector.S_cov_dict["biased"].shape == (OUT_DIM, OUT_DIM) + assert collector.A_cov_dict["unbiased"].shape == (OUT_DIM, OUT_DIM) + + n, s = 2, 3 + x = torch.randn(n, s, IN_DIM, device=device) + mask = torch.ones(n, s, dtype=torch.bool, device=device) + + with collector.with_batch(mask): + out = model(x) + out.sum().backward() + + # Forward hook accumulated A^T A over the augmented activation + a = x[mask] + a_aug = torch.cat([a, a.new_ones(a.shape[0], 1)], dim=1) + torch.testing.assert_close(collector.A_cov_dict["biased"], a_aug.mT @ a_aug) + # Bias-bias corner counts the number of valid positions + torch.testing.assert_close( + collector.A_cov_dict["biased"][-1, -1], + torch.tensor(float(n * s), device=device), + ) + + # The covariance dims must factorize the stored gradient size [O, I+1], + # so apply_hessian's view(-1, O, I+1) succeeds (issue #277). + grad_shape = collector.shapes()["biased"] + assert collector.S_cov_dict["biased"].shape[1] * collector.A_cov_dict[ + "biased" + ].shape[1] == math.prod(grad_shape) + + +def test_lambda_collector_include_bias(tmp_path): + """LambdaCollector must transform the augmented activation [a; 1].""" + device = get_device(0) + model = TinyBiasModel().to(device) + + # Save identity eigenvectors with the augmented activation dimension + eigen_a = { + "biased": torch.eye(IN_DIM + 1, dtype=torch.float32), + "unbiased": torch.eye(OUT_DIM, dtype=torch.float32), + } + eigen_g = { + "biased": torch.eye(OUT_DIM, dtype=torch.float32), + "unbiased": torch.eye(IN_DIM, dtype=torch.float32), + } + os.makedirs(tmp_path / "eigen_activation_sharded") + os.makedirs(tmp_path / "eigen_gradient_sharded") + save_file(eigen_a, str(tmp_path / "eigen_activation_sharded/shard_0.safetensors")) + save_file(eigen_g, str(tmp_path / "eigen_gradient_sharded/shard_0.safetensors")) + + collector = LambdaCollector( + model=model, + path=str(tmp_path), + processor=GradientProcessor(include_bias=True), + ) + + n, s = 2, 3 + x = torch.randn(n, s, IN_DIM, device=device) + mask = torch.ones(n, s, dtype=torch.bool, device=device) + + with collector.with_batch(mask): + out = model(x) + out.sum().backward() + + # Eigenvalue corrections match the stored [O, I+1] gradient layout + assert collector.eigenvalue_corrections["biased"].shape == (OUT_DIM, IN_DIM + 1) + assert collector.eigenvalue_corrections["unbiased"].shape == (IN_DIM, OUT_DIM) + + +@pytest.mark.parametrize("world_size", [1, 2, 3, 4, 7]) +@pytest.mark.parametrize("dim", [1, 7, 64, 129, 513]) +def test_shard_bounds_partitions_dim(dim, world_size): + """Shards tile [0, dim) contiguously; rank 0 takes the remainder rows.""" + if dim < world_size: + pytest.skip("fewer rows than ranks") + + base, remainder = divmod(dim, world_size) + prev_end = 0 + for rank in range(world_size): + start, end = shard_bounds(dim, rank, world_size) + assert start == prev_end + assert end - start == base + (remainder if rank == 0 else 0) + prev_end = end + assert prev_end == dim + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/ekfac_tests/test_uneven_sharding.py b/tests/ekfac_tests/test_uneven_sharding.py new file mode 100644 index 00000000..1770ff23 --- /dev/null +++ b/tests/ekfac_tests/test_uneven_sharding.py @@ -0,0 +1,101 @@ +"""Numerical tests for ShardedMul with unevenly sharded dimensions. + +With include_bias=True the activation dimension becomes I+1, which is +generally not divisible by the world size. Rank 0 takes the remainder rows +(see shard_bounds). These tests check every sharded op against its dense +single-process reference under a 2-process gloo group on CPU. +""" + +import socket + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from bergson.hessians.sharded_computation import ShardedMul, shard_bounds + +WORLD_SIZE = 2 +DIM = 5 # odd on purpose: shards are 3 and 2 rows +N, S, O = 2, 3, 4 + + +def _shard_worker(rank, world_size, port, result_dict): + """Run all sharded ops and store rank 0's results for comparison.""" + try: + dist.init_process_group( + "gloo", + init_method=f"tcp://localhost:{port}", + rank=rank, + world_size=world_size, + ) + sharder = ShardedMul() + + # Same seeded data on every rank + g = torch.Generator().manual_seed(0) + matrix = torch.randn(DIM, DIM, generator=g) + vector = torch.randn(N, S, DIM, generator=g) + grads = torch.randn(N, O, DIM, generator=g) + lambda_full = torch.randn(O, DIM, generator=g).abs() + + start, end = shard_bounds(DIM, rank, world_size) + matrix_shard = matrix[start:end].contiguous() + + results = {} + results["matmul"] = sharder._matmul(vector, matrix_shard) + results["transpose_matmul"] = sharder._transpose_matmul(grads, matrix_shard) + + o_start, o_end = shard_bounds(O, rank, world_size) + lambda_shard = lambda_full[o_start:o_end].contiguous() + + hadamard = grads.clone() + sharder._hadamard(hadamard, lambda_shard, lambda_damp_factor=0.1) + results["hadamard"] = hadamard + + eigfn = grads.clone() + sharder._apply_eigfn(eigfn, lambda_shard, fn=torch.rsqrt) + results["apply_eigfn"] = eigfn + + if rank == 0: + result_dict.update(results) + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + +def test_sharded_ops_match_dense_with_uneven_shards(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + port = s.getsockname()[1] + + manager = mp.Manager() + result_dict = manager.dict() + mp.spawn( + _shard_worker, + args=(WORLD_SIZE, port, result_dict), + nprocs=WORLD_SIZE, + join=True, + ) + + # Dense references with the same seeded data + g = torch.Generator().manual_seed(0) + matrix = torch.randn(DIM, DIM, generator=g) + vector = torch.randn(N, S, DIM, generator=g) + grads = torch.randn(N, O, DIM, generator=g) + lambda_full = torch.randn(O, DIM, generator=g).abs() + + torch.testing.assert_close(result_dict["matmul"], vector @ matrix) + torch.testing.assert_close(result_dict["transpose_matmul"], grads @ matrix.T) + + inverse_lambda = ( + lambda_full + 0.1 * lambda_full.mean() + ).reciprocal() # _hadamard dense path + torch.testing.assert_close(result_dict["hadamard"], grads * inverse_lambda) + + torch.testing.assert_close( + result_dict["apply_eigfn"], grads * torch.rsqrt(lambda_full) + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])