Skip to content

Unify scalar optimizers.#210

Merged
skyw merged 4 commits into
mainfrom
skyw/unify_scalar_optimizer
Jun 5, 2026
Merged

Unify scalar optimizers.#210
skyw merged 4 commits into
mainfrom
skyw/unify_scalar_optimizer

Conversation

@skyw

@skyw skyw commented May 28, 2026

Copy link
Copy Markdown
Contributor

Provide unified optimizer class for different scalar update functions.

Keep the detailed docstring in update function and only have very light docstring for the wrapper class.

Signed-off-by: Hao Wu <skyw@nvidia.com>
@copy-pr-bot

copy-pr-bot Bot commented May 28, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Signed-off-by: Hao Wu <skyw@nvidia.com>
@skyw

skyw commented May 28, 2026

Copy link
Copy Markdown
Contributor Author

@greptile

@greptile-apps

greptile-apps Bot commented May 28, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR consolidates four previously independent scalar optimizer classes (Lion, LaProp, Signum, SimplifiedAdEMAMix) behind a shared _ScalarOptimizerBase in a new base.py, eliminating duplicated step loops, state initialization logic, and validation helpers.

  • _ScalarOptimizerBase handles lazy state allocation, the step counter, update_kwarg_names dispatch, and pre/post_step_inplace hooks; concrete optimizers become constructor-only wrappers that supply an update_fn and their specific defaults.
  • calculate_lion_update gains a step kwarg (immediately deleted) for call-site uniformity; LaProp.frob_normalize is expressed via the new hook pair rather than inline in the step loop.
  • The two per-optimizer test files are deleted and replaced with mixin-based shared suites (_CommonScalarOptimizerTests, _HasBetasTests, _HasEpsTests) that each concrete optimizer test class composes, giving consistent cross-optimizer coverage alongside optimizer-specific tests.

Confidence Score: 5/5

Safe to merge — the refactoring is a faithful mechanical consolidation with no logic changes to any optimizer's update math.

The step loop, state initialization, and validation were mechanically lifted into the shared base, and the diff confirms the behaviour is preserved: same update equation ordering (pre-norm capture → weight decay → update fn → post-norm rescale for frob_normalize), same error messages, same defaults. The only net-new behaviour is that Lion now tracks a step counter it ignores, which is intentional and documented. The mixin-based test consolidation preserves all meaningful assertions from the deleted files and adds cross-optimizer smoke coverage for Signum and SimplifiedAdEMAMix.

No files require special attention.

Important Files Changed

Filename Overview
emerging_optimizers/scalar_optimizers/base.py New file: shared _ScalarOptimizerBase with lazy state init, pre/post_step_inplace hooks, and SingleMomentumOptimizer/TwoMomentsOptimizer subclasses; logic matches what was in the deleted per-optimizer files.
emerging_optimizers/scalar_optimizers/init.py Replaces two thin module imports with all four optimizer class definitions (Lion, Signum, LaProp, SimplifiedAdEMAMix) now expressed as constructor-only wrappers over the shared base; frob-normalize handled via overridden hooks on LaProp.
emerging_optimizers/scalar_optimizers/update_functions/lion.py Added step: int kwarg (immediately deleted) for signature uniformity with other update functions; no logic change.
tests/test_scalar_optimizers.py Old per-optimizer test files deleted; replaced with mixin-based _CommonScalarOptimizerTests, _HasBetasTests, and _HasEpsTests giving each optimizer a consistent shared suite plus optimizer-specific tests for LaProp, Lion, Signum, and SimplifiedAdEMAMix.

Reviews (3): Last reviewed commit: "add one more test and adjust logging lev..." | Re-trigger Greptile

Comment thread emerging_optimizers/scalar_optimizers/__init__.py
Comment thread emerging_optimizers/scalar_optimizers/base.py
@skyw

skyw commented May 28, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 552d23b

@github-actions

github-actions Bot commented May 28, 2026

Copy link
Copy Markdown

Test Results

   50 files   -  4    120 suites  +4   1m 28s ⏱️ +5s
1 114 tests +24  1 114 ✅ +24  0 💤 ±0  0 ❌ ±0 
2 479 runs  +48  2 479 ✅ +48  0 💤 ±0  0 ❌ ±0 

Results for commit fa638f5. ± Comparison against base commit 3e86f89.

This pull request removes 31 and adds 55 tests. Note that renamed tests count towards both.
__main__.LaPropOptimizerTest ‑ test_init_group_skip_non_grad_params0 (True)
__main__.LaPropOptimizerTest ‑ test_init_group_skip_non_grad_params1 (False)
__main__.LaPropOptimizerTest ‑ test_no_grad_no_update_params_unchanged0 (shape=(3, 3))
__main__.LaPropOptimizerTest ‑ test_no_grad_no_update_params_unchanged1 (shape=(15, 31))
__main__.LaPropOptimizerTest ‑ test_no_grad_no_update_params_unchanged2 (shape=(127, 255))
__main__.LaPropOptimizerTest ‑ test_smoke0 (shape=(3, 3))
__main__.LaPropOptimizerTest ‑ test_smoke1 (shape=(15, 31))
__main__.LaPropOptimizerTest ‑ test_smoke2 (shape=(127, 255))
__main__.LaPropOptimizerTest ‑ test_state_initialization0 (shape=(3, 3))
__main__.LaPropOptimizerTest ‑ test_state_initialization1 (shape=(15, 31))
…
__main__.LaPropOptimizerTest ‑ test_closure_unsupported
__main__.LaPropOptimizerTest ‑ test_frob_normalize_with_nonzero_weight_decay_logs_error
__main__.LaPropOptimizerTest ‑ test_init_group_skip_non_grad_params
__main__.LaPropOptimizerTest ‑ test_init_group_skip_non_grad_params (skip_non_grad_params=False)
__main__.LaPropOptimizerTest ‑ test_init_group_skip_non_grad_params (skip_non_grad_params=True)
__main__.LaPropOptimizerTest ‑ test_no_grad_no_update_params_unchanged
__main__.LaPropOptimizerTest ‑ test_param_groups_large_lr_moves_more
__main__.LaPropOptimizerTest ‑ test_smoke
__main__.LaPropOptimizerTest ‑ test_state_evolves_correctly0 (shape=(3, 3))
__main__.LaPropOptimizerTest ‑ test_state_evolves_correctly1 (shape=(15, 31))
…

♻️ This comment has been updated with latest results.

Signed-off-by: Hao Wu <skyw@nvidia.com>
@skyw skyw requested a review from mkhona-nvidia May 28, 2026 19:41
@skyw

skyw commented May 28, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 753808b

@codecov

codecov Bot commented May 28, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 98.21429% with 2 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
emerging_optimizers/scalar_optimizers/base.py 96.92% 1 Missing and 1 partial ⚠️

📢 Thoughts on this report? Let us know!

Signed-off-by: Hao Wu <skyw@nvidia.com>
@skyw

skyw commented May 28, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test fa638f5

state = self.state[p]
if len(state) == 0:
for key in self.state_keys:
state[key] = torch.zeros_like(p.data)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe add warning that states are always the same shape as p?

many optims like Adafactor do not have that

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I consider that is implied by scalar optimizer, everything is element wise so states can only have same shape.
I think a note in the docstring would be enough?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure sounds good

Comment thread emerging_optimizers/scalar_optimizers/base.py
@skyw skyw merged commit b995d2e into main Jun 5, 2026
25 checks passed
@skyw skyw deleted the skyw/unify_scalar_optimizer branch June 5, 2026 03:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants