Skip to content

[Refactor] Dynamic value-estimator registry across all loss modules#3780

Open
vmoens wants to merge 1 commit into
pytorch:mainfrom
vmoens:feat/value-estimator-registry
Open

[Refactor] Dynamic value-estimator registry across all loss modules#3780
vmoens wants to merge 1 commit into
pytorch:mainfrom
vmoens:feat/value-estimator-registry

Conversation

@vmoens
Copy link
Copy Markdown
Collaborator

@vmoens vmoens commented May 19, 2026

Context

Today every loss in torchrl.objectives ships its own make_value_estimator body with a hand-coded if/elif chain over ValueEstimators.{TD0,TD1,GAE,TDLambda,VTrace}. The chain knows the class names, the default hyperparameters, and any per-estimator construction quirks (e.g. V-Trace needs the actor). Adding a new estimator therefore means touching ~15 loss files plus their tests.

This PR replaces that with a registry: estimators self-register via a decorator at class-definition time, and every loss dispatches through a single helper.

This is the registry half of #3748 — pulled out into its own PR per review feedback so that the MAPPO/IPPO/MultiAgentGAE/ValueNorm work in #3748 stays scoped to its features, and the registry refactor can be reviewed independently and extended to all losses (the original PR only converted PPO / A2C / Reinforce, leaving the headline benefit unfulfilled for the other ~12 loss classes).

What's new

torchrl.objectives.utils

  • _VALUE_ESTIMATOR_REGISTRY: dict[ValueEstimators, _ValueEstimatorRegistryEntry]
  • @register_value_estimator(value_type, *, default_kwargs=...) — class decorator that estimators apply at definition time.
  • get_value_estimator_entry(value_type) — accepts both ValueEstimators enum members and lowercase string aliases (`"gae"`, `"vtrace"`, ...). Handy for hydra/yaml configs.
  • build_value_estimator(loss, value_type, **hp) — low-level: looks up the class via the registry, merges entry.default_kwargs with the caller's hp, then calls cls.for_loss(loss, **merged).
  • dispatch_value_estimator(loss, value_type, *, supported, tensor_keys, **hp) — high-level helper most losses use. It:
    1. validates value_type against the loss's supported set;
    2. seeds gamma from loss.gamma if present;
    3. builds the estimator via the registry;
    4. assigns loss._value_estimator / loss.value_type;
    5. calls set_keys(**tensor_keys).
  • default_value_kwargs() is now a 1-line shim over the registry (back-compat preserved).

torchrl.objectives.value.advantages

  • ValueEstimatorBase grows a for_loss(cls, loss_module, **hp) classmethod. Default behaviour: pick loss.critic_network, fall back to loss.value_network. If the caller (i.e. the loss) passes value_network=<module> in hp, that wins.
  • VTrace overrides for_loss to also pluck the actor and handle the functional-actor deep-copy.
  • TD0Estimator / TD1Estimator / TDLambdaEstimator / GAE / VTrace each carry an @register_value_estimator(...) decorator with their default kwargs.

Per-loss conversion

Every loss now declares a SUPPORTED_VALUE_ESTIMATORS tuple, and its make_value_estimator body collapses to a single dispatch_value_estimator(...) call:

Loss family Supported set
PPOLoss, A2CLoss, ReinforceLoss TD0, TD1, TDLambda, GAE, VTrace
DQNLoss, IQLLoss, DiscreteIQLLoss, DDPGLoss, TD3Loss, TD3BCLoss, SACLoss, DiscreteSACLoss, CQLLoss, DiscreteCQLLoss, REDQLoss, REDQLoss_deprecated, DreamerActorLoss, DreamerV3ActorLoss, CrossQLoss TD0, TD1, TDLambda
QMixerLoss TD0

The instance/class fast-path for power users who pass a ValueEstimatorBase directly stays at the top of every body.

Why a partial conversion would have been bad

The version in #3748 only converted PPO/A2C/Reinforce. That left the if/elif chains intact in 12 other losses, and the PR still had to hand-edit 8 test files to add new estimators to the "raises NotImplementedError" lists. The registry's whole pitch ("no edits in existing losses when adding a new estimator") only holds if every loss participates. This PR finishes the conversion so the next estimator addition — including MAGAE in #3748 — really does require zero edits to existing losses.

Verification

  • New TestValueEstimatorRegistry suite (8 tests) in test/objectives/test_loss_module.py:
    • all built-ins are registered
    • string alias resolution (\"gae\"ValueEstimators.GAE)
    • error paths for unknown alias / non-enum type
    • third-party estimator registration via the decorator
    • dispatch_value_estimator rejects estimators outside the supported set
    • value_network=... override path
    • default_value_kwargs() matches the registry
  • Existing test suites pass unchanged:
    • pytest test/objectives/test_ppo.py test/objectives/test_dqn.py test/objectives/test_ddpg.py test/objectives/test_sac.py test/objectives/test_iql.py test/objectives/test_cql.py test/objectives/test_dreamer.py — 4521 passed.

Stats

  • ~700 lines deleted, ~600 added across 19 files.
  • All registry indirection is in torchrl/objectives/utils.py (one file, ~120 lines).

Out of scope (for #3748 to add on top)

Made with Cursor

Replaces the hard-coded if/elif dispatch in every loss's
`make_value_estimator` with a registry that estimators self-register
into. Adding a new estimator now requires only the
`@register_value_estimator` decorator on the class — no edits in any
existing loss file.

What changed:

- `torchrl/objectives/utils.py`:
  - new `_VALUE_ESTIMATOR_REGISTRY` dict;
  - `register_value_estimator(value_type, default_kwargs=...)` decorator;
  - `get_value_estimator_entry(value_type)` lookup that accepts both
    `ValueEstimators` enum members and lowercase string aliases (e.g.
    `"gae"`) — handy for hydra/yaml configs;
  - `build_value_estimator(loss, value_type, **hp)` low-level builder;
  - `dispatch_value_estimator(loss, value_type, *, supported,
    tensor_keys, **hp)` — the high-level helper most losses now use.
    It validates `value_type` against the loss's `supported` set,
    builds the estimator via the registry, and applies `set_keys`;
  - `default_value_kwargs()` is now a thin shim that reads from the
    registry (back-compat preserved).

- `torchrl/objectives/value/advantages.py`: `ValueEstimatorBase` grows
  a `for_loss(cls, loss_module, **hp)` classmethod that the registry
  uses to wire each estimator against a loss. Default picks
  `loss.critic_network` (falling back to `loss.value_network`), but
  the loss can pass `value_network=<the module>` through
  `dispatch_value_estimator` to override. `VTrace` overrides
  `for_loss` to also wire the actor and handle the functional-actor
  deep-copy. All five built-in estimators (TD0/TD1/TDLambda/GAE/VTrace)
  carry a `@register_value_estimator` decorator with their default
  kwargs.

- Every loss declares a `SUPPORTED_VALUE_ESTIMATORS` tuple and its
  `make_value_estimator` body collapses to a single
  `dispatch_value_estimator(...)` call:
  - `ppo.py`, `a2c.py`, `reinforce.py`: all five estimators supported.
  - `dqn.py` (DQN + QMixer), `iql.py`, `ddpg.py`, `td3.py`, `td3_bc.py`,
    `sac.py`, `cql.py`, `redq.py`, `dreamer.py`, `dreamer_v3.py`,
    `crossq.py`, `deprecated.py`, `multiagent/qmixer.py`: `{TD0, TD1,
    TDLambda}` (QMixer only TD0).

  The instance/class fast-path for power users who pass a
  `ValueEstimatorBase` directly stays at the top of every body.

Tests: new `TestValueEstimatorRegistry` suite (8 tests) in
`test/objectives/test_loss_module.py` covers all built-ins being
registered, string-alias resolution, error paths for bad keys, custom
estimator registration, `dispatch_value_estimator`'s unsupported-set
check, the `value_network` override path, and back-compat between
`default_value_kwargs()` and the registry.

Net: ~700 lines deleted, ~600 added.
Co-authored-by: Cursor <cursoragent@cursor.com>
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented May 19, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3780

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit 1e03f13 with merge base 258dfad (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 19, 2026
@github-actions github-actions Bot added Objectives Refactoring Refactoring of an existing feature labels May 19, 2026
vmoens added a commit to theap06/rl that referenced this pull request May 19, 2026
Adds first-class objectives for cooperative continuous-control MARL,
plus the supporting infrastructure they need.

## What's new

- ``torchrl.objectives.multiagent.MAPPOLoss`` -- centralised-critic,
  decentralised-actor PPO (Yu et al. 2022, https://arxiv.org/abs/2103.01955).
  Subclasses ``ClipPPOLoss``; defaults the value estimator to
  ``MultiAgentGAE``, defaults ``normalize_advantage_exclude_dims=(-2,)``,
  and optionally accepts a ``ValueNorm`` for the critic-stability trick
  from the paper (Table 13).

- ``torchrl.objectives.multiagent.IPPOLoss`` -- independent-learner
  counterpart (de Witt et al. 2020, https://arxiv.org/abs/2011.09533).
  Each agent has its own local critic; no centralised state required.

- ``torchrl.objectives.value.MultiAgentGAE`` -- ``GAE`` subclass that
  broadcasts team-shared ``reward`` / ``done`` / ``terminated`` (shape
  ``[*B, T, 1]``) across the agent dim before the vec-GAE recursion, so
  users don't have to manually replicate signals. Per-agent rewards
  pass through unchanged (competitive settings). New
  ``ValueEstimators.MAGAE`` enum entry.

- ``torchrl.modules.ValueNorm`` / ``PopArtValueNorm`` /
  ``RunningValueNorm`` -- abstract ``ValueNorm`` (an ``nn.Module``)
  with two implementations: PopArt-style EMA (van Hasselt et al. 2019,
  https://arxiv.org/abs/1809.04474) and exact Welford running stats.
  Plugs into ``MAPPOLoss(value_norm=...)`` and normalises both the
  critic target and the prediction so the MSE / smooth-L1 distance
  stays on a fixed scale as reward scales drift. Composes correctly
  with the parent ``ClipPPOLoss`` features (``clip_value``,
  ``separate_losses``, ``log_explained_variance``) via the new
  ``_critic_loss_inputs`` hook.

## Design notes

**Two classes, no centralized boolean flag.** The MAPPO / IPPO
structural difference is small but the construction recipes differ in
ways that matter (centralised critic vs. per-agent critic), so we
expose them as separate named classes; the docstring on each spells
out the recipe explicitly.

**MAGAE dispatch in plain PPO / A2C / Reinforce.** Adding
``ValueEstimators.MAGAE`` to the enum would break every parent test
that parametrises over ``list(ValueEstimators)`` unless every
``make_value_estimator`` knows the new value. We dispatch MAGAE
through ``MultiAgentGAE`` in those losses (~5 lines per file). When
the registry from pytorch#3780 lands, the explicit branches collapse to a
single ``@register_value_estimator`` decorator.

**ValueNorm placement.** Lives under ``torchrl/modules/`` rather than
``torchrl/objectives/utils/`` because it's a stateful learnable
component that participates in ``.to(device)`` / ``state_dict()``.

## Verification

- ``pytest test/objectives/test_mappo.py`` -- 25/25 passing.
  Includes regression tests for:
    * ``value_norm`` registered exactly once in ``state_dict()``
      (no duplicate ``_value_norm_module.*`` keys);
    * ``state_dict()`` save / load round-trips the running stats;
    * ``clip_value`` + ``value_norm`` still produces
      ``value_clip_fraction`` in the output;
    * ``separate_losses=True`` + ``value_norm`` correctly detaches
      so critic-loss grads do not flow into actor params;
    * no spurious annotation warnings on instantiation.
- ``pytest test/objectives/`` -- 7280 passing, 2319 skipped
  (no regressions).
- ``examples/multiagent/mappo_vmas.py --algo mappo --frames 200_000``
  provides a minimal end-to-end VMAS Navigation smoke recipe.

Co-authored-by: Cursor <cursoragent@cursor.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. Objectives Refactoring Refactoring of an existing feature

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant