Fix K-FAC covariance shapes when include_bias=True #294

Open
luciaquirke wants to merge 1 commit into main from fix/kfac-include-bias

Collaborator

luciaquirke commented Jun 5, 2026

Closes #277.

Running bergson build with include_bias=True stores per-layer gradients of shape [O, I+1] (the bias gradient is appended as an extra "activation" column), but the K-FAC covariance path computed A = aᵀa on the raw activation, giving A of shape [I, I]. The per-layer view(-1, O, I) in apply_hessian then failed on the flat gradients of size N·O·(I+1):

RuntimeError: shape '[-1, 128, 128]' is invalid for input of size 16512

Reproduce with bergson ekfac --include_bias true --method kfac --model EleutherAI/pythia-14m.
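
The numbers in the error can be checked by hand: 16512 = 128 · 129, which is consistent with a single gradient with O = 128 and I + 1 = 129 (N = 1 and O = I = 128 are inferred from the message, not stated in this PR). A minimal sketch of the same mismatch, using NumPy's reshape as a stand-in for torch's view:

```python
import numpy as np

O, I = 128, 128   # assumption: inferred from the error message
N = 1             # assumption: one gradient in the flat buffer

# Flat gradients laid out as the build step stores them with the bias column.
flat = np.zeros(N * O * (I + 1), dtype=np.float32)
assert flat.size == 16512

# The bias-free layout fails: 16512 is not a multiple of 128 * 128.
try:
    flat.reshape(-1, O, I)
    old_layout_ok = True
except ValueError:
    old_layout_ok = False

# The augmented layout [O, I+1] fits exactly.
per_layer = flat.reshape(-1, O, I + 1)
print(old_layout_ok, per_layer.shape)
```

The second reshape only succeeds because the covariance and gradient sides agree on I + 1 columns, which is the invariant the fix restores.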

Fix

  • collect_hessians now passes GradientProcessor(include_bias=index_cfg.include_bias) to the Hessian collectors. Previously, the collect_bias flags in target_info were always False during Hessian fitting, regardless of the build config.
  • _init_covariance_dict sizes the activation covariance as [I+1, I+1] when collect_bias is set.
  • CovarianceCollector (kfac), TraceCovarianceCollector (tkfac), and ShampooCollector augment the activation with a ones column, matching the build-time gradient layout.
  • LambdaCollector augments the activation before rotating by eigen_a (now [I+1, I+1]), so EKFAC eigenvalue corrections come out [O, I+1].
  • Sharded computation now supports dimensions that are not evenly divisible by the world size, via shard_bounds (rank 0 absorbs the remainder). This is included in this PR because, with the bias column, I+1 is almost never divisible by the world size.
  • Fixed a timing bug in step2_fit_hessian: approximate_hessians ran outside its _timed block, so the reported time was always 0.0s.
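
To illustrate why the ones column makes the shapes line up, here is a rough NumPy sketch of the augmentation described in the list above. It is not bergson's code: the toy sizes, variable names, and the use of NumPy instead of torch are assumptions, and the last step follows the standard EKFAC formulation rather than the actual LambdaCollector.

```python
import numpy as np

rng = np.random.default_rng(0)
N, I, O = 4, 3, 2                      # hypothetical toy sizes

a = rng.standard_normal((N, I))        # layer inputs (activations)
g = rng.standard_normal((N, O))        # gradients w.r.t. layer outputs

# Append a ones column so the bias acts like one more input feature.
a_aug = np.concatenate([a, np.ones((N, 1))], axis=1)      # [N, I+1]

# Per-example gradients: outer(g, a_aug) stacks [dW | db] into [O, I+1].
per_example = np.einsum("no,ni->noi", g, a_aug)           # [N, O, I+1]
assert np.allclose(per_example[:, :, :I], np.einsum("no,ni->noi", g, a))
assert np.allclose(per_example[:, :, I], g)               # last column is db

# Covariance factors computed on the augmented activation.
A = a_aug.T @ a_aug                                       # [I+1, I+1]
G = g.T @ g                                               # [O, O]

# Standard EKFAC-style correction: rotate into both eigenbases, then
# average the squared entries (assumed stand-in for LambdaCollector).
_, eigen_a = np.linalg.eigh(A)
_, eigen_g = np.linalg.eigh(G)
rotated = eigen_g.T @ per_example @ eigen_a               # [N, O, I+1]
lam = (rotated ** 2).mean(axis=0)                         # [O, I+1]

print(A.shape, lam.shape)
```

With the augmentation in place, both A and the eigenvalue corrections carry the extra bias dimension, so they match the [O, I+1] gradient layout written at build time.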

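The uneven sharding mentioned in the list above could look roughly like the following. This is a hypothetical helper written for illustration (the name shard_bounds_sketch and its signature are assumptions, not the function added in this PR); it only encodes the stated rule that rank 0 absorbs the remainder.

```python
def shard_bounds_sketch(n: int, world_size: int, rank: int) -> tuple[int, int]:
    """Return the half-open [start, end) slice of range(n) owned by `rank`.

    Hypothetical rule: every rank gets n // world_size items, and rank 0
    additionally takes the n % world_size leftover items.
    """
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    base, rem = divmod(n, world_size)
    if rank == 0:
        return 0, base + rem
    start = rem + rank * base
    return start, start + base


# I = 128 with a bias column gives 129 rows, which 8 ranks cannot split evenly.
bounds = [shard_bounds_sketch(129, 8, r) for r in range(8)]
print(bounds)
# [(0, 17), (17, 33), (33, 49), (49, 65), (65, 81), (81, 97), (97, 113), (113, 129)]
```

Giving the whole remainder to one rank keeps the index arithmetic simple, at the cost of a slightly larger shard on rank 0.
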
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
luciaquirke requested a review from LouisYRYJ on June 5, 2026, 05:13

Development

Successfully merging this pull request may close these issues.

K-FAC Hessian + include_bias=True produces incompatible gradient/covariance shapes
